Моя модель работает отлично, но когда я переключаю ее в режим оценки, ей не нравятся типы данных входных выборок:
Traceback (most recent call last):
File "model.py", line 558, in <module>
main_function(train_sequicity=args.train)
File "model.py", line 542, in main_function
out = model(user, bspan, response_, degree)
File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "model.py", line 336, in forward
self.params['bspan_size'])
File "model.py", line 283, in _greedy_decode_output
out = decoder(input_, encoder_output)
File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "model.py", line 142, in forward
tgt = torch.cat([go_tokens, tgt], dim=0) # concat GO_2 token along sequence lenght axis
RuntimeError: Expected object of scalar type Long but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'
Это происходит в той части кода, где конкатенация случается. Это архитектура, подобная преобразователю pytorch, только что модифицированная, чтобы иметь два декодера:
def forward(self, tgt, memory):
""" Call decoder
the decoder should be called repeatedly
Args:
tgt: input to transformer_decoder, shape: (seq, batch)
memory: output from the encoder
Returns:
output from linear layer, (vocab size), pre softmax
"""
go_tokens = torch.zeros((1, tgt.size(1)), dtype=torch.int64) + 3 # GO_2 token has index 3
tgt = torch.cat([go_tokens, tgt], dim=0) # concat GO_2 token along sequence lenght axis
+
mask = tgt.eq(0).transpose(0,1) # 0 corresponds to <pad>
tgt = self.embedding(tgt) * self.ninp
tgt = self.pos_encoder(tgt)
tgt_mask = self._generate_square_subsequent_mask(tgt.size(0))
output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=mask)
output = self.linear(output)
return output
Бит конкатенации в середине кодового блока - это место, где возникает проблема. Странно то, что он отлично работает и тренируется, а потери снижаются в режиме поезда. Эта проблема возникает только в режиме eval. В чем может быть проблема?