Pytorch: возможно ли загрузить модель, когда параметры или размер изменились? - PullRequest
1 голос
/ 30 июня 2019

Модель Pytorch (график, веса и смещения) сохраняется с:

torch.save(self.state_dict(), file)

и загружается с:

self.load_state_dict(torch.load(file))

Но если параметры изменяются, модель выигрывает 't загрузить с ошибкой, например:

RuntimeError: Error(s) in loading state_dict for LeNet5:
    size mismatch for conv1.weight:

Можно ли загрузить модель с измененным размером?Что-то вроде заполнения остальных весов, как при инициализации (если весов больше), и обрезки, если весов меньше?

1 Ответ

1 голос
/ 30 июня 2019

Не существует автоматического способа сделать это - потому что вам нужно явно решить, что делать, когда вещи не совпадают.

Лично, когда мне нужно "принудительно" предварительно-тренированные веса на слегка измененной модели.Я считаю, что работать с самой state_dict - это наиболее удобный способ.

new_model = model( ... )  # construct the new model
new_sd = new_model.state_dict()  # take the "default" state_dict
pre_trained_sd = torch.load(file)  # load the old version pre-trained weights
# merge information from pre_trained_sd into new_sd
# ...
# after merging the state dict you can load it:
new_model.load_state_dict(new_sd)
...