ОБНОВЛЕНИЕ: с полной видимостью кода, я думаю, что проблемы в реализации. torch.load загрузит информацию из dict, которая была десериализована в файл. Он загружается как исходный объект dict, поэтому в функции следует ожидать checkpoint == checkpoint (оригинальное определение).
В данном случае, я думаю, что вы на самом деле хотите вызвать загрузку файла, сохраненного как checkpoint.pth
, и первый вызов может не понадобиться.
def load_checkpoint(filepath):
model = torch.load(filepath)
return model
Другая возможность состоит в том, что вложенный объект должен быть тем, что этот объект называется, и тогда это будет просто небольшая корректировка:
def load_checkpoint(filepath):
checkpoint = torch.load(filepath)
model = torch.load_state_dict(checkpoint['state_dict'])
return model
Наиболее вероятная проблема заключается в том, что вы вызываете класс Network, который не содержится в объекте словаря контрольных точек.
Я не могу говорить о реальном уроке или других нюансах урока, самое простое решение - просто вызвать определение класса Network с переменными, уже имеющимися в словаре контрольных точек, например:
model = Network(checkpoint['input_size'],
checkpoint['output_size'],
checkpoint['hidden_layers'],
checkpoint['epochs'],
checkpoint['optimizer'],
checkpoint['class_to_index'])
model.load_state_dict(checkpoint['state_dict'])
return model
Диктовка контрольной точки может иметь только те значения, которые вы ожидаете ('input_size', 'output_size' и т. Д.) Но это только самая очевидная проблема, которую я вижу.