Я пытаюсь создать RNN с нуля с помощью pytorch, и я следую этому руководству , чтобы создать его.
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicRNN(nn.Module):
def __init__(self, n_inputs, n_neurons):
super(BasicRNN, self).__init__()
self.Wx = torch.randn(n_inputs, n_neurons) # n_inputs X n_neurons
self.Wy = torch.randn(n_neurons, n_neurons) # n_neurons X n_neurons
self.b = torch.zeros(1, n_neurons) # 1 X n_neurons
def forward(self, X0, X1):
self.Y0 = torch.tanh(torch.mm(X0, self.Wx) + self.b) # batch_size X n_neurons
self.Y1 = torch.tanh(torch.mm(self.Y0, self.Wy) +
torch.mm(X1, self.Wx) + self.b) # batch_size X n_neurons
return self.Y0, self.Y1
class CleanBasicRNN(nn.Module):
def __init__(self, batch_size, n_inputs, n_neurons):
super(CleanBasicRNN, self).__init__()
self.rnn = BasicRNN(n_inputs, n_neurons)
self.hx = torch.randn(batch_size, n_neurons) # initialize hidden state
def forward(self, X):
output = []
# for each time step
for i in range(2):
self.hx = self.rnn(X[i], self.hx)
output.append(self.hx)
return output, self.hx
FIXED_BATCH_SIZE = 4 # our batch size is fixed for now
N_INPUT = 3
N_NEURONS = 5
X_batch = torch.tensor([[[0,1,2], [3,4,5],
[6,7,8], [9,0,1]],
[[9,8,7], [0,0,0],
[6,5,4], [3,2,1]]
], dtype = torch.float) # X0 and X1
model = CleanBasicRNN(FIXED_BATCH_SIZE,N_INPUT,N_NEURONS)
a1,a2 = model(X_batch)
Выполнение этого кода возвращает эту ошибку
RuntimeError: несоответствие размера, m1: [4 x 5], m2: [3 x 5] в /pytorch/..
После некоторого рытья я обнаружил, что эта ошибка возникает при передаче скрытого состояний к модели BasicRNN
N_INPUT = 3 # number of features in input
N_NEURONS = 5 # number of units in layer
X0_batch = torch.tensor([[0,1,2], [3,4,5],
[6,7,8], [9,0,1]],
dtype = torch.float) #t=0 => 4 X 3
X1_batch = torch.tensor([[9,8,7], [0,0,0],
[6,5,4], [3,2,1]],
dtype = torch.float) #t=1 => 4 X 3
test_model = BasicRNN(N_INPUT,N_NEURONS)
a1,a2 = test_model(X0_batch,X1_batch)
a1,a2 = test_model(X0_batch,torch.randn(1,N_NEURONS)) # THIS LINE GIVES ERROR
Что происходит в скрытых состояниях и как я могу решить эту проблему?