Как я могу исправить AttributeError при загрузке контрольной точки? - PullRequest
0 голосов
/ 20 июня 2019

Я работаю над проектом 2 курса с Udacity (Искусственный интеллект с программированием на Python).

Я обучил модель и сохранил ее в checkpoint.pth, и я хочу загрузить checkpoint.pth такЯ могу перестроить модель.

Я написал код для сохранения checkpoint.pth, а также для загрузки контрольной точки.

model.class_to_idx = image_datasets['train_dir'].class_to_idx

model.cpu()

checkpoint = {'input_size': 25088,
              'output_size': 102,
              'hidden_layers': 4096,
              'epochs': epochs,
              'optimizer': optimizer.state_dict(),
              'state_dict': model.state_dict(),
              'class_to_index' : model.class_to_idx
             }


torch.save(checkpoint, 'checkpoint.pth')

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)

    model = checkpoint.Network(checkpoint['input_size'],
                               checkpoint['output_size'],
                               checkpoint['hidden_layers'],
                               checkpoint['epochs'],
                               checkpoint['optimizer'],
                               checkpoint['class_to_index']
                              )
    model.load_state_dict(checkpoint['state_dict'])

    return model

model = load_checkpoint('checkpoint.pth')

При загрузке checkpoint.pth я получаю сообщение об ошибке:

AttributeError: 'dict' object has no attribute 'Network'

Я хочу успешно загрузить контрольную точку.

Спасибо

1 Ответ

0 голосов
/ 20 июня 2019

ОБНОВЛЕНИЕ: с полной видимостью кода, я думаю, что проблемы в реализации. torch.load загрузит информацию из dict, которая была десериализована в файл. Он загружается как исходный объект dict, поэтому в функции следует ожидать checkpoint == checkpoint (оригинальное определение).

В данном случае, я думаю, что вы на самом деле хотите вызвать загрузку файла, сохраненного как checkpoint.pth, и первый вызов может не понадобиться.

def load_checkpoint(filepath):
    model = torch.load(filepath)
    return model

Другая возможность состоит в том, что вложенный объект должен быть тем, что этот объект называется, и тогда это будет просто небольшая корректировка:

def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = torch.load_state_dict(checkpoint['state_dict'])
    return model

Наиболее вероятная проблема заключается в том, что вы вызываете класс Network, который не содержится в объекте словаря контрольных точек.

Я не могу говорить о реальном уроке или других нюансах урока, самое простое решение - просто вызвать определение класса Network с переменными, уже имеющимися в словаре контрольных точек, например:

model = Network(checkpoint['input_size'],
                checkpoint['output_size'],
                checkpoint['hidden_layers'],
                checkpoint['epochs'],
                checkpoint['optimizer'],
                checkpoint['class_to_index'])
model.load_state_dict(checkpoint['state_dict'])

return model

Диктовка контрольной точки может иметь только те значения, которые вы ожидаете ('input_size', 'output_size' и т. Д.) Но это только самая очевидная проблема, которую я вижу.

...