Производный класс Pytorch nn.Module не может быть загружен при импорте модуля в Python - PullRequest
0 голосов
/ 04 мая 2020

Использование Python 3.6 с Pytorch 1.3.1. Я заметил, что некоторые сохраненные nn.Modules не могут быть загружены, когда весь модуль импортируется в другой модуль. Чтобы привести пример, вот шаблон минимального рабочего примера.

#!/usr/bin/env python3
#encoding:utf-8
# file 'dnn_predict.py'

from torch import nn
class NN(nn.Module):##NN network
    # Initialisation and other class methods

networks=[torch.load(f=os.path.join(resource_directory, 'nn-classify-cpu_{fold}.pkl'.format(fold=fold))) for fold in range(5)]
...
if __name__=='__main__':
    # Some testing snippets
    pass

Весь файл прекрасно работает, когда я запускаю его непосредственно в оболочке. Тем не менее, когда я хочу использовать класс и загрузить нейронную сеть в другой файл, используя этот код, это не удается.

#!/usr/bin/env python3
#encoding:utf-8
from dnn_predict import *

Ошибка: AttributeError: Can't get attribute 'NN' on <module '__main__'>

В Pytorch загрузка сохраненных переменных или импорт модулей происходит иначе, чем в других распространенных библиотеках Python? Некоторая помощь или указатель на причину root будут по достоинству оценены.

1 Ответ

0 голосов
/ 04 мая 2020

Когда вы сохраняете модель с помощью torch.save(model, PATH), весь объект сериализуется с помощью pickle, который не сохраняет сам класс, а указывает путь к файлу, содержащему этот класс, следовательно, при загрузке Модель точно такой же директории и файловой структуры требуется, чтобы найти правильный класс. При запуске сценария Python модуль этого файла __main__, поэтому, если вы хотите загрузить этот модуль, ваш класс NN должен быть определен в сценарии, который вы запускаете.

Это очень негибкий, поэтому рекомендуется не сохранять всю модель, а просто сохранить словарь состояний, в котором сохраняются только параметры модели.

# Save the state dictionary of the model
torch.save(model.state_dict(), PATH)

После этого словарь состояний может быть загружен. и применяется к вашей модели.

from dnn_predict import NN

# Create the model (will have randomly initialised parameters)
model = NN()

# Load the previously saved state dictionary
state_dict = torch.load(PATH)

# Apply the state dictionary to the model
model.load_state_dict(state_dict)

Подробнее о словаре состояния и сохранении / загрузке моделей: PyTorch - сохранение и загрузка моделей

...