Тестирование проблемы Pytorch CNN: RuntimeError: shape '[64, 1]' недопустим для ввода размером 1920 - PullRequest
0 голосов
/ 13 марта 2019

Тестирование сети:

 def test(args, model, device, test_loader):
           model.eval()
           total = 0
           test_loss = 0
           correct = 0
           with torch.no_grad():
           for batch_idx, batch in enumerate(test_loader):
               data = batch['image']
               target = batch['key_points']
               data, target = data.to(device), target.to(device)
               # data, target = Variable(data), Variable(target)
           # for data, target in test_loader:
           #     data, target = data.to(device), target.to(device)
           #     data, target = Variable(data), Variable(target)
                 data = data.unsqueeze(1).float() 

                 print('TESTING 1: Data and target shape: ', data.shape, ' 
                       ', target.shape)
                 output = model(data)
                 target = target.view(target.shape[0], -1)  

                 print('TESTING 2: Data and target shape: ', data.shape, ' 
                       ', target.shape)
                 # test_loss += F.nll_loss(output, target, 
                 reduction='sum').item()  # sum up batch loss orginal
                 test_loss = F.nll_loss(output, torch.max(target, 1)[1])
                  # test_loss += F.nll_loss(output, 
                  torch.max(target.float(), 1)[1], reduction='sum').item()

                   print('TESTING 3: Data and target shape: ', data.shape, 
                       ' ', target.shape)
                   # pred = output.max(1, keepdim=True)[1]  
                   pred = output.argmax(dim=1, keepdim=True)  # original

                   print('TESTING 4: Data and target shape: ', data.shape, 
                       ' ', target.shape)

                   correct += pred.eq(target.view_as(pred)).sum().item()   
                   print('TESTING last: Data and target shape: ', 
                   data.shape, ' ', target.shape)

                   test_loss /= len(test_loader.dataset)
                   print('\nTest set: Average loss: {:.4f}, Accuracy: 
                        {}/{} ({:.0f}%)\n'.format(
                         test_loss , correct, len(test_loader.dataset),
                         100. * correct / len(test_loader.dataset)))

Я могу запустить цикл обучения, что нормально, однако, когда я пытаюсь запустить цикл тестирования, обучение будет проходить только 1 эпоху, и я получаю следующее сообщение.

TESTING 1: Data and target shape:  torch.Size([64, 1, 96, 96])  
torch.Size([64, 15, 2]) TESTING 2: Data and target shape: 
torch.Size([64, 1, 96, 96])   torch.Size([64, 30]) 
TESTING 3: Data and target shape:  torch.Size([64, 1, 96, 96])  
torch.Size([64, 30])
TESTING 4: Data and target shape:  torch.Size([64, 1, 96, 96])  
 torch.Size([64, 30]) Traceback (most recent call last):   File
 "/home/keith/PycharmProjects/FacialLandMarks/WorkOut.py", line 468, in <module>
    main()   File "/home/keith/PycharmProjects/FacialLandMarks/WorkOut.py", line 463, in main
    test(args, model, device, test_loader)   File "/home/keith/PycharmProjects/FacialLandMarks/WorkOut.py", line 380, in test
    correct += pred.eq(target.view_as(pred)).sum().item()   
 RuntimeError: shape '[64, 1]' is invalid for input of size 1920
 Process finished with exit code 1

Мне не удалось найти какие-либо полезные ресурсы в Интернете или в ряде книг, которые я приобрел, по этой конкретной проблеме, на этапе тестирования обнаружения лицевых точек. Я думаю, что проблема отображается в сообщении об ошибке (помечено комментарием) или с моими данными. Я использую отдельный набор тестовых изображений с обучением ключевых точек CSV-файла. Как вы можете видеть, я распечатал форму данных на этапе тестирования. Любая помощь или полезные ссылки всегда приветствуются. Спасибо

Когда я использую отладчик pycharm и вижу точки останова, когда попадаю в строку кода ниже:

pred = output.max(1, keepdim=True)[1]

Тогда пред тензор изменится на это:

tensor([[29],
        [29],
        [29],
        [29],
        [29],
        [29],
        [29],
        [29],

         ...
        [29],
        [29],
        [29],
        [29],
        [29],
        [29],
        [29]], device='cuda:0')

1 Ответ

0 голосов
/ 12 апреля 2019

Предположительно, форма 'prered' - [64,1], тогда как форма target - [64,30].Теперь, если вы звоните target.view_as(pred), вы пытаетесь просмотреть target в той же форме, что и pred, но target имеет 64 * 30 = 1920 записей, тогда как pred имеет только 64, так что вот гдеошибка исходит от.

Вы уверены, что ваши цели верны?Потому что вы предсказываете одномерный выход, но пытаетесь сравнить его с 30-мерной целью.

...