Как загрузить контрольную точку обучения Chainer из npz? - PullRequest
0 голосов
/ 24 марта 2020

Я использую Chainer для обучения (точной настройки) модели Re snet, а затем использую контрольную точку для оценки. Контрольная точка - это файл npz со следующей структурой:

File list in npz checkpoint

Когда я загружаю модель для оценки с chainer.serializers.load_npz(args.load, model) (где модель является стандартом re snet) Я получаю следующую ошибку: KeyError: 'rpn / loc / b не является файлом в архиве'.

Мне кажется, проблема в том, что файлы в модели не имеют ' Префикс обновления / оптимизатора / более быстрого / экстрактора '.

Как мне изменить имя файла в полученном npz, чтобы удалить префикс, или что еще я должен сделать, чтобы решить проблему?

Спасибо!

1 Ответ

0 голосов
/ 25 марта 2020

Когда вы загружаете снимок, сгенерированный расширением снимка, вам нужно сделать это из тренера.

chainer.serializers.load_npz(args.load, trainer) Тренер автоматически загрузит состояние программы обновления, оптимизатора и модели.

Вы также можете загрузить только модель вручную, открыв соответствующее поле в снимке и передав его в качестве аргумента функции model.serialize

npz_data = numpy.load(args.load)
snap = chainer.serializers.NpzDeserializer(npz_data)
model.serialize(snap['updater']['model:main'])

Это должно загружать только веса модели

...