Отслеживание устаревшего предупреждения в pytorch - PullRequest
2 голосов
/ 05 февраля 2020

Я тренирую yolov3 на своих данных, используя этот код здесь: https://github.com/cfotache/pytorch_custom_yolo_training/

Но я получаю это раздражающее предупреждение об устаревании

Warning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (expandTensors at /pytorch/aten/src/ATen/native/IndexingUtils.h:20)

Я пытался использовать python3 -W ignore train.py Я попытался добавить:

import warnings
warnings.filterwarnings('ignore')

, но предупреждение все еще остается постоянным.

Я нашел этот фрагмент кода здесь в stackoverflow, который печатает этот стек в предупреждениях,

import traceback
import warnings
import sys

def warn_with_traceback(message, category, filename, lineno, file=None, line=None):

    log = file if hasattr(file,'write') else sys.stderr
    traceback.print_stack(file=log)
    log.write(warnings.formatwarning(message, category, filename, lineno, line))

warnings.showwarning = warn_with_traceback

и вот что я получаю:

  File "/content/pytorch_custom_yolo_training/train.py", line 102, in <module>
   loss = model(imgs, targets)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/pytorch_custom_yolo_training/models.py", line 267, in forward
    x, *losses = module[0](x, targets)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/content/pytorch_custom_yolo_training/models.py", line 203, in forward
    loss_x = self.mse_loss(x[mask], tx[mask])
  File "/usr/lib/python3.6/warnings.py", line 99, in _showwarnmsg
    msg.file, msg.line)
  File "/content/pytorch_custom_yolo_training/train.py", line 29, in warn_with_traceback
    traceback.print_stack(file=log)
  /pytorch/aten/src/ATen/native/IndexingUtils.h:20: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.

Переходя к файлам и функциям, упомянутым в стеке, я не нахожу никаких uint8. Что я могу решить проблему или даже прекратить получать эти предупреждения?

1 Ответ

0 голосов
/ 05 февраля 2020

Нашел проблему. строка: loss_x = self.mse_loss(x[mask], tx[mask]) переменная mask была ByteTensor, которая устарела. Просто заменил его на BoolTensor

...