Как сохранить промежуточные веса при обучении нейронной сети - PullRequest
0 голосов
/ 08 ноября 2019

Я тренирую нейронную сеть, используя pytorch, и я хочу сохранять веса на каждой итерации. Другими словами, я хочу создать список, содержащий все веса, которые нейронная сеть имела во время обучения.

Я сделал следующее:

for i, (images, labels) in enumerate(train_loader):

     (.....code that is used to train the model here.....)

     weight = model.fc2.weight.detach().numpy()
     weights_list.append(weight)

Когда я затем печатаю записи списка weights_list, я замечаю, что они все одинаковы, что не может быть правдой, потому что у меня естьнапечатали веса во время обучения, и они меняются (и сеть действительно учится, поэтому они должны). Я предполагаю, что каждая запись списка на самом деле является указателем на вес сети в момент проверки списка. Итак:

1) Правильно ли мое предположение? 2) Как я могу решить эту проблему?

Спасибо!

1 Ответ

0 голосов
/ 08 ноября 2019

Встроена функция сохранения и загрузки весов. Для сохранения в файл вы можете использовать

torch.save('checkpoint.pt', model.state_dict())

, а для загрузки вы можете использовать

model.load_state_dict(torch.load('checkpoint.pt'))

Этосказал, что преобразование в numpy не обязательно создает копию. Например, если у вас есть пустой массив y и вы хотите создать копию, вы можете использовать

x = numpy.copy(y)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...