Что означает "model.trainable = False" в Керасе? - PullRequest
1 голос
/ 03 октября 2019

Я хочу заморозить предварительно обученную сеть в Керасе. Я нашел base.trainable = False в документации. Но я не понял, как это работает. С len(model.trainable_weights) я узнал, что у меня 30 тренировочных весов. Как это может быть? Сеть показывает общее количество обучаемых параметров: 16 812 353. После заморозки у меня 4 тренировочных веса. Может быть, я не понимаю разницу между параметрами и весами. К сожалению, я новичок в углубленном изучении. Может быть, кто-то может мне помочь.

1 Ответ

2 голосов
/ 03 октября 2019

A Keras Model по умолчанию обучаемо - у вас есть два способа заморозить все веса:

  1. model.trainable = False до компиляциимодель
  2. for layer in model.layers: layer.trainable = False - работает до и после компиляции

(1) необходимо выполнить до компиляции, поскольку Keras при компиляции обрабатывает model.trainable как логический флаг и выполняет (2) под капотом. После выполнения любого из вышеперечисленных действий вы должны увидеть:

print(model.trainable_weights)
# [] 

Относительно документов, вероятно, устаревших - см. Приведенный выше исходный код, актуальный.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...