Построение рекуррентной нейронной сети с прямой связью в pytorch - PullRequest
0 голосов
/ 03 июля 2018

Я проходил этот учебник. У меня вопрос по поводу следующего кода класса:

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax()

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self):
        return Variable(torch.zeros(1, self.hidden_size))

Этот код был взят из Здесь . Там было упомянуто, что

Поскольку состояние сети хранится в графике, а не в слоях, вы можете просто создать nn.Linear и использовать его снова и снова для повторения.

Что я не понимаю, так это то, как можно просто увеличить размер входного объекта в nn.Linear и сказать, что это RNN. Что мне здесь не хватает?

1 Ответ

0 голосов
/ 03 июля 2018

Сеть является периодической, поскольку в этом примере вы оцениваете несколько временных шагов. Следующий код также взят из учебника pytorch, который вы связали с .

loss_fn = nn.MSELoss()

batch_size = 10
TIMESTEPS = 5

# Create some fake data
batch = torch.randn(batch_size, 50)
hidden = torch.zeros(batch_size, 20)
target = torch.zeros(batch_size, 10)
loss = 0
for t in range(TIMESTEPS):
    # yes! you can reuse the same network several times,
    # sum up the losses, and call backward!
    hidden, output = rnn(batch, hidden)
    loss += loss_fn(output, target)
loss.backward()

Таким образом, сама сеть не является рекуррентной, но в этом цикле вы используете ее как рекуррентную сеть, подавая скрытое состояние предыдущего шага вперед вместе с вашим пакетным вводом несколько раз.

Вы также можете использовать его как единовременное, просто увеличивая потери на каждом шаге и игнорируя скрытое состояние.

Поскольку состояние сети хранится в графике, а не в слоях, вы можете просто создать nn.Linear и использовать его снова и снова для повторения.

Это означает, что информация для вычисления градиента не содержится в самой модели, поэтому вы можете добавить несколько оценок модуля к графику и затем выполнить обратное распространение по всему графику. Это описано в предыдущих параграфах урока.

...