Начало передачи Pytorch V3 приводит к ошибке - max () получил недопустимую комбинацию аргументов - PullRequest
0 голосов
/ 26 июня 2018

Программа для обучения обучению inception_v3 в pytorch, которую я использую, находится здесь: https://drive.google.com/file/d/1zn4z7nOp_wJne0En6zq4WJfwHVVftERT/view?usp=sharing

Я получаю следующую ошибку при запуске программы:

Epoch 0/24   
    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    <ipython-input-20-cc88ea5f8bd3> in <module>()
          1 model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
    ----> 2                        num_epochs=25)

    <ipython-input-17-812cf3c4576a> in train_model(model, criterion, optimizer, scheduler, num_epochs)
         33                     outputs = model(inputs)
         34                     print(outputs)
    ---> 35                     _, preds = torch.max(outputs, 1)
         36                     loss = criterion(outputs, labels)
         37 

    TypeError: max() received an invalid combination of arguments - got (tuple, int), but expected one of:
     * (Tensor input)
     * (Tensor input, Tensor other, Tensor out)
     * (Tensor input, int dim, bool keepdim, tuple of Tensors out)

Как это можно исправить? Спасибо

Ответы [ 3 ]

0 голосов
/ 18 сентября 2018

Я нашел проблему отсюда: начало pytorchv3 int line 125. Ошибка заключается в том, что при открытии мода поезда и открытии aux_logits return он возвращает (x, aux). Я решил это

  1. использование output, aux = model(input_var)

  2. если этап == «поезд»: выходы, aux = модель (входы) еще: выходов = модель (входы)

0 голосов
/ 18 декабря 2018

в этом случае я изменил код, как показано ниже, и он работал для меня, Учебник https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

model_ft = models.inception_v3(pretrained=True)
model_ft.aux_logits=False
0 голосов
/ 27 июня 2018

Строка должна быть такой, как показано ниже:

_, preds = torch.max(outputs.data, 1)

...