Обучение с учетом квантования в TensorFlow версии 2 и сворачивании BatchNorm - PullRequest
1 голос
/ 27 марта 2020

Мне интересно, какие текущие доступные опции имитируют сворачивание BatchNorm во время обучения с учетом квантования в Tensorflow 2. Tensorflow 1 имеет функцию tf.contrib.quantize.create_training_graph, которая вставляет слои FakeQuantification в график и заботится о моделировании сворачивания нормализации партии (согласно к этой белой книге ).

В Tensorflow 2 есть учебник о том, как использовать квантование в недавно принятом tf.keras API, но они ничего не упоминают о пакетной нормализации. Я попробовал следующий простой пример со слоем BatchNorm:

import tensorflow_model_optimization as tfmo

model = tf.keras.Sequential([
      l.Conv2D(32, 5, padding='same', activation='relu', input_shape=input_shape),
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Conv2D(64, 5, padding='same', activation='relu'),
      l.BatchNormalization(),    # BN!
      l.MaxPooling2D((2, 2), (2, 2), padding='same'),
      l.Flatten(),
      l.Dense(1024, activation='relu'),
      l.Dropout(0.4),
      l.Dense(num_classes),
      l.Softmax(),
])
model = tfmo.quantization.keras.quantize_model(model)

Однако он дает следующее исключение:

RuntimeError: Layer batch_normalization:<class 'tensorflow.python.keras.layers.normalization.BatchNormalization'> is not supported. You can quantize this layer by passing a `tfmot.quantization.keras.QuantizeConfig` instance to the `quantize_annotate_layer` API.

, которое указывает, что TF не знает, что с ним делать.

Я также видел эту связанную топику c, где они применяют tf.contrib.quantize.create_training_graph к модели, построенной на основе кераса. Однако они не используют слои BatchNorm, поэтому я не уверен, что это сработает.

Итак, каковы варианты использования этой функции сворачивания BatchNorm в TF2? Можно ли это сделать с помощью API-интерфейса keras или я должен вернуться к API-интерфейсу TensorFlow 1 и определить график по-старому?

1 Ответ

0 голосов
/ 25 апреля 2020

Если вы добавите BatchNormalization до активации, у вас не будет проблем с квантованием. Примечание. Квантование поддерживается в BatchNormalization только в том случае, если уровень находится точно после слоя Conv2D. https://www.tensorflow.org/model_optimization/guide/quantization/training

# Change
l.Conv2D(64, 5, padding='same', activation='relu'),
l.BatchNormalization(),    # BN!
# with this
l.Conv2D(64, 5, padding='same'),
l.BatchNormalization(),
l.Activation('relu'),

#Other way of declaring the same
o = (Conv2D(512, (3, 3), padding='valid' , data_format=IMAGE_ORDERING))(o)
o = (BatchNormalization())(o)
o = Activation('relu')(o)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...