Как я могу сделать бэкпроп через PyTorch LSTM, созданный из примитивных тензорных операций (например, не `nn.LSTM`) - PullRequest
0 голосов
/ 07 декабря 2018

РЕДАКТИРОВАТЬ : проблема с моей реализацией заключалась в том, чтобы пытаться извлечь мой output, IE, вектор с одним «горячим», прямо из скрытого состояния.Вместо этого я добавил сверху плотный слой, и он отлично работает.


Я пытаюсь сделать LSTM из более примитивных операций в PyTorch и использую функции torch.autograd для ошибок backprop.Я бы хотел, чтобы он был «онлайн», так как h и c накапливают свое состояние с течением времени, и на каждом временном шаге есть 1 символ и 1 выход.

Это символуровень rnn, так:

  • мой словарь состоит из 30 символов (строчная буква z и некоторые знаки препинания)

  • inp - это однообразный вектордлина 30.

  • h и c имеют длину 30 + 100. Первые 30 из h - это мой "вывод"

  • моя потеря сравнивает символ target с горячим кодированием с этими первыми 30 индексами h.

  • Я накапливаю loss за 10 шагов, а затемне знаю, как правильно сделать это обратно.Ниже приведена (неудачная) попытка.

TL; DR.Как правильно сделать бэкпроп для этого LSTM?

    def ff(inp, h, c):
        xh = torch.cat((inp, h), 0)
        f = (xh @ Wf + bf).sigmoid()
        i = (xh @ Wi + bi).sigmoid()
        g = (xh @ Wg + bg).tanh() # C-bar, in some literature
        c = f * c + i * g
        o = (xh @ Wo + bo).sigmoid()
        h = o * c.tanh()
        return h, c

    loss = torch.zeros(1)
    def bp(out, target, lr):
        global Wf, Wi, Wg, Wo
        global bf, bi, bg, bo
        global h, c
        global loss

        # Accumulate loss every step
        loss += (-target * out[:out_n].softmax(dim=0).log()).sum()

        # Every 10 chars, run backprop
        if i % 10 == 0:
            loss.backward()

            with torch.no_grad():
                for param in [Wf, Wi, Wg, Wo, bf, bi, bg, bo]:
                    param -= lr * param.grad
                    param.grad.zero_()

            h.detach_()
            c.detach_()
            loss.detach_()
            loss = torch.zeros(1)

        return loss
...