Когда вы сохраняете модель с помощью torch.save(model, PATH)
, весь объект сериализуется с помощью pickle
, который не сохраняет сам класс, а указывает путь к файлу, содержащему этот класс, следовательно, при загрузке Модель точно такой же директории и файловой структуры требуется, чтобы найти правильный класс. При запуске сценария Python модуль этого файла __main__
, поэтому, если вы хотите загрузить этот модуль, ваш класс NN
должен быть определен в сценарии, который вы запускаете.
Это очень негибкий, поэтому рекомендуется не сохранять всю модель, а просто сохранить словарь состояний, в котором сохраняются только параметры модели.
# Save the state dictionary of the model
torch.save(model.state_dict(), PATH)
После этого словарь состояний может быть загружен. и применяется к вашей модели.
from dnn_predict import NN
# Create the model (will have randomly initialised parameters)
model = NN()
# Load the previously saved state dictionary
state_dict = torch.load(PATH)
# Apply the state dictionary to the model
model.load_state_dict(state_dict)
Подробнее о словаре состояния и сохранении / загрузке моделей: PyTorch - сохранение и загрузка моделей