best_state изменяется с моделью во время обучения в pytorch - PullRequest
1 голос
/ 10 июня 2019

Я хочу сохранить лучшую модель, а затем загрузить ее во время теста.Поэтому я использовал следующий метод:

def train():  
    #training steps …  
    if acc > best_acc:  
        best_state = model.state_dict()  
        best_acc = acc
    return best_state  

Затем в основной функции я использовал:

model.load_state_dict(best_state)  

для возобновления модели.

Однако я обнаружил, что best_state всегда совпадает с последним состоянием во время тренировки, а не с лучшим состоянием.Кто-нибудь знает причину и как ее избежать?

Кстати, я знаю, что могу использовать torch.save(the_model.state_dict(), PATH), а затем загрузить модель по the_model.load_state_dict(torch.load(PATH)).Однако я не хочу сохранять параметры в файл, так как функции train и test находятся в одном файле.

Ответы [ 2 ]

1 голос
/ 12 июня 2019

model.state_dict() is OrderedDict

from collections import OrderedDict

Вы можете использовать:

from copy import deepcopy

Для решения проблемы

Вместо:

best_state = model.state_dict() 

Вы должны использовать:

best_state = copy.deepcopy(model.state_dict())

Глубокая (не мелкая) копия заставляет изменяемый экземпляр OrderedDict не изменять best_state, как это происходит.

Вы можете проверить мой другой ответ о сохранении состояния в PyTorch.

0 голосов
/ 12 июня 2019

При сохранении состояния модели вы должны сохранить в сети следующие вещи

1) Оптимизатор состояния и 2) Состояние модели dict

Вы можете определить один метод в вашей модели класса следующим образом

def save_state(state,filename):
    torch.save(state,filename)

''» Когда вы сохраняете состояние, сделайте следующее: '' '

Model model //for example  
model.save_state({'state_dict':model.state_dict(), 'optimizer': optimizer.state_dict()}) 

Сохраненная модель будет сохранена как model.pth.tar (для примера)

Теперь во время загрузки выполните следующие шаги,

checkpoint = torch.load('model.pth.tar')         

model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

Надеюсь, это поможет вам.

...