Невозможно загрузить веса модели в TensorFlow 2 - PullRequest
2 голосов
/ 16 июня 2020

Я не могу загрузить веса модели после их сохранения в TensorFlow 2.2. Кажется, что веса сохраняются правильно (я думаю), однако я не могу загрузить предварительно обученную модель.

Мой текущий код:

segmentor = sequential_model_1()
discriminator = sequential_model_2()

def save_model(ckp_dir):
    # create directory, if it does not exist:
    utils.safe_mkdir(ckp_dir)

    # save weights
    segmentor.save_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'))
    discriminator.save_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'))

def load_pretrained_model(ckp_dir):
    try:
        segmentor.load_weights(os.path.join(ckp_dir, 'checkpoint-segmentor'), skip_mismatch=True)
        discriminator.load_weights(os.path.join(ckp_dir, 'checkpoint-discriminator'), skip_mismatch=True)
        print('Loading pre-trained model from: {0}'.format(ckp_dir))
    except ValueError:
        print('No pre-trained model available.')

Затем у меня есть тренировка l oop:

# training loop:
for epoch in range(num_epochs):

    for image, label in dataset:
        train_step()

    # save best model I find during training:
    if this_is_the_best_model_on_validation_set():
        save_model(ckp_dir='logs_dir')

А потом, по окончании обучения «для l oop», я хочу загрузить лучшую модель и провести с ней тест. Следовательно, я запускаю:

# load saved model and do a test:
load_pretrained_model(ckp_dir='logs_dir')
test()

Однако это приводит к ValueError. Я проверил каталог, в котором должны быть сохранены веса, и вот они!

Есть идеи, что не так с моим кодом? Я неправильно загружаю веса?

Спасибо!

1 Ответ

3 голосов
/ 22 июня 2020

Хорошо, вот ваша проблема - имеющийся у вас блок try-except скрывает реальную проблему. Удаление дает ValueError:

ValueError: When calling model.load_weights, skip_mismatch can only be set to True when by_name is True.

Есть два способа смягчить это - вы можете либо вызвать load_weights с помощью by_name=True, либо удалить skip_mismatch=True в зависимости от твои нужды. Любой случай работает для меня при тестировании вашего кода.

Еще одно соображение заключается в том, что при сохранении контрольных точек дискриминатора и сегментатора в каталоге журнала вы каждый раз перезаписываете файл checkpoint. Он содержит две строки, которые задают путь к файлам контрольных точек конкретной модели c. Поскольку вы сохраняете дискриминатор второй, каждый раз в этом файле будет отображаться дискриминатор без ссылки на сегментатор. Вы можете смягчить это, сохраняя вместо этого каждую модель в двух подкаталогах в каталоге журнала, т.е.

logs_dir/
    + discriminator/
        + checkpoint
        + ...
    + segmentor/
        + checkpoint
        + ...

Хотя в текущем состоянии ваш код будет работать в этом случае.

...