Об использовании RNN в pytorch - PullRequest
0 голосов
/ 17 апреля 2020

Я пытаюсь использовать RNN для двоичной классификации. Но когда моя модель тренируется, она застревает на loss.backward(). Вот моя модель:

class RNN2(nn.Module):
    def __init__(self, input_size, hidden_size, output_size=2, num_layers=1):
        super(RNN2, self).__init__()
        self.rnn = nn.RNN(input_size, hidden_size, num_layers)
        self.reg = nn.Linear(hidden_size, output_size)
        #self.softmax = nn.LogSoftmax(dim=1)

    def forward(self,x):
        x, hidden = self.rnn(x)
        return self.reg(x[:,2])

rnn = RNN2(13,10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=learning_rate)
for e in range(10):
    out = rnn(train_X)
    optimizer.zero_grad()
    print(out[0])
    print(out.shape)
    print(train_Y.shape)
    loss = criterion(out, train_Y)
    print(loss)
    loss.backward()
    print("1")
    optimizer.step()
    print("2")

Форма train_X равна 420000 * 3 * 13, а форма train_Y равна 420000. Может кто-нибудь сказать мне, почему он застревает на loss.backward(). Он не может печатать 1.

1 Ответ

0 голосов
/ 19 апреля 2020

Вы должны знать, что в RRN вычисление обратной функции для последовательности длиной 420000 происходит чрезвычайно медленно. Если вы запускаете свой код на компьютере с графическим процессором (или google colab) и добавляете следующие строки перед для l oop, выполнение кода завершается менее чем за две минуты.

rnn = rnn.cuda()
train_X = train_X.cuda()
train_Y = train_Y.cuda()

Обратите внимание, что по умолчанию второе входное измерение, переданное в RNN, будет считаться размером пакета. Поэтому, если 420000 - это количество пакетов, передайте batch_first=True конструктору RNN.

self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

Это значительно ускорит процесс (менее одной секунды в google colab). Однако, если это не так, попробуйте разделить последовательности на более мелкие части и увеличить размер пакета с 3 до большего значения.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...