Как использовать файл модели контрольной точки в pytorch для проверки набора данных CIFAR-10? - PullRequest
0 голосов
/ 23 октября 2018
model = SqueezeNext()
model = model.to(device)

def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'):
# Note: Input model & optimizer should be pre-defined.  This routine only updates their states.
start_epoch = 0
if os.path.isfile(filename):
    print("=> loading checkpoint '{}'".format(filename))
    checkpoint = torch.load(filename)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    losslogger = checkpoint['losslogger']
    print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
else:
    print("=> no checkpoint found at '{}'".format(filename))


return model, optimizer, start_epoch, losslogger

model, optimizer, start_epoch, losslogger = load_checkpoint(model, optimizer, losslogger)

TypeError: Traceback (последний последний вызов) в () 41 test_loader = torch.utils.data.DataLoader (test_dataset, batch_size = 80, num_workers = 8, shuffle = False) 42---> 43 модель = SqueezeNext () 44 модель = model.to (устройство) 45 def load_checkpoint (модель, оптимизатор, losslogger, имя файла = 'SqNxt_23_1x_Cifar.ckpt'): ошибка типа: init () отсутствует3 обязательных позиционных аргумента: 'width_x', 'blocks' и 'num_classes'

Я думаю, что я не реализую это правильно !!

...