Сбой Pytorch при вводе в режиме eval - PullRequest
0 голосов
/ 24 марта 2020

Моя модель работает отлично, но когда я переключаю ее в режим оценки, ей не нравятся типы данных входных выборок:

Traceback (most recent call last):
  File "model.py", line 558, in <module>
    main_function(train_sequicity=args.train)
  File "model.py", line 542, in main_function
    out = model(user, bspan, response_, degree)
  File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 336, in forward
    self.params['bspan_size'])
  File "model.py", line 283, in _greedy_decode_output
    out = decoder(input_, encoder_output)
  File "/home/memduh/git/project/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "model.py", line 142, in forward
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence lenght axis
RuntimeError: Expected object of scalar type Long but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'

Это происходит в той части кода, где конкатенация случается. Это архитектура, подобная преобразователю pytorch, только что модифицированная, чтобы иметь два декодера:

      def forward(self, tgt, memory):
          """ Call decoder
          the decoder should be called repeatedly

          Args:
              tgt: input to transformer_decoder, shape: (seq, batch)
              memory: output from the encoder

          Returns:
              output from linear layer, (vocab size), pre softmax

          """
          go_tokens = torch.zeros((1, tgt.size(1)), dtype=torch.int64) + 3  # GO_2 token has index 3

          tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence lenght axis

+
          mask = tgt.eq(0).transpose(0,1)  # 0 corresponds to <pad>
          tgt = self.embedding(tgt) * self.ninp
          tgt = self.pos_encoder(tgt)
          tgt_mask = self._generate_square_subsequent_mask(tgt.size(0))
          output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=mask)
          output = self.linear(output)
          return output

Бит конкатенации в середине кодового блока - это место, где возникает проблема. Странно то, что он отлично работает и тренируется, а потери снижаются в режиме поезда. Эта проблема возникает только в режиме eval. В чем может быть проблема?

1 Ответ

1 голос
/ 24 марта 2020

Ошибки кажутся очевидными: tgt - это Float, но ожидалось, что оно будет Long. Почему?

В своем коде вы определяете, что go_tokens равно torch.int64 (то есть Long):

def forward(self, tgt, memory):
    go_tokens = torch.zeros((1, tgt.size(1)), dtype=torch.int64) + 3  # GO_2 token has index 3
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence lenght axis
    # [...]

Вы можете избежать этой ошибки, сказав, что go_tokens должен иметь тот же тип данных, что и tgt:

def forward(self, tgt, memory):
    go_tokens = torch.zeros((1, tgt.size(1)), dtype=tgt.dtype) + 3  # GO_2 token has index 3
    tgt = torch.cat([go_tokens, tgt], dim=0)  # concat GO_2 token along sequence lenght axis
    # [...]

Теперь, если остальная часть кода полагается на tgt, равную torch.int64, то вы должны определить, почему tgt равно torch.int64 во время обучения и torch.float32 во время теста, в противном случае будет выдана другая ошибка.

...