То, что я хочу сделать, это использовать DataParallel в моем пользовательском классе RNN.
Похоже, я неправильно инициализировал hidden_0 ...
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size, n_layers=1):
super(RNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.n_layers = n_layers
self.encoder = nn.Embedding(input_size, hidden_size)
self.gru = nn.GRU(hidden_size, hidden_size, n_layers,batch_first = True)
self.decoder = nn.Linear(hidden_size, output_size)
self.init_hidden(batch_size)
def forward(self, input):
input = self.encoder(input)
output, self.hidden = self.gru(input,self.hidden)
output = self.decoder(output.contiguous().view(-1,self.hidden_size))
output = output.contiguous().view(batch_size,num_steps,N_CHARACTERS)
#print (output.size())10,50,67
return output
def init_hidden(self,batch_size):
self.hidden = Variable(T.zeros(self.n_layers, batch_size, self.hidden_size).cuda())
И я вызываю сетьтаким образом:
decoder = T.nn.DataParallel(RNN(N_CHARACTERS, HIDDEN_SIZE, N_CHARACTERS), dim=1).cuda()
Затем начните тренировку:
for epoch in range(EPOCH_):
hidden = decoder.init_hidden()
Но я получаю ошибку, и у меня нет идеального способа ее исправить ...
Объект 'DataParallel' не имеет атрибута 'init_hidden'
Спасибо за помощь!