Не удается правильно восстановить слои tf.keras.layers.BatchNormalization из контрольной точки - PullRequest
0 голосов
/ 14 мая 2019

У меня проблемы с восстановлением слоев нормализации партии обученной модели с контрольной точки во время теста.

Проверка параметров слоев нормализации партии после восстановления показывает, что они загружены неправильно, но имеют значения инициализации (гамма = 1, бета = 0). Это дает мне плохие результаты во время теста. Обучение и проверка работоспособности, как и ожидалось, и параметры нормализации партии обновляются, поэтому я думаю, что проблема должна заключаться в сохранении и / или восстановлении модели.

Модель определяется с использованием API подклассов:

class convUnit(tf.keras.layers.Layer):
    def __init__(self,numFilters):
        super(convUnit, self).__init__()
        self.conv = tf.keras.layers.Conv2D(numFilters,kernel_size=3,padding='same')
        self.bn = tf.keras.layers.BatchNormalization(trainable=True) 
        self.relu = tf.keras.layers.Activation('relu')

    def call(self,inputs):
        x = self.conv(inputs)
        x = self.bn(x)
        x = self.relu(x)
        return x

class Network(tf.keras.Model):  #### USING SUBCLASSING API
     def __init__(self):
          super(Network, self).__init__()
          self.conv1 = convUnit(32)
          self.mp1 = tf.keras.layers.MaxPooling2D(pool_size = 2,padding = 'same')

          self.conv2 = convUnit(32)
          self.mp2 = tf.keras.layers.MaxPooling2D(pool_size = 2,padding = 'same')

          self.conv3 = convUnit(32)
          self.mp3 = tf.keras.layers.MaxPooling2D(pool_size = 2,padding = 'same')

          self.conv4 = convUnit(32)
          self.mp4 = tf.keras.layers.MaxPooling2D(pool_size = 2,padding = 'same')

          self.flat = tf.keras.layers.Flatten()
          self.drop1 = tf.keras.layers.Dropout(rate = 0.5)
          self.dense1 = tf.keras.layers.Dense(units = 128)


     def call(self, inputs,training):
          x = self.conv1(inputs)
          x = self.mp1(x)

          x = self.conv2(x)
          x = self.mp2(x)

          x = self.conv3(x)
          x = self.mp3(x)

          x = self.conv4(x)
          x = self.mp4(x)

          x = self.flat(x)
          x = self.drop1(x,training = training)  ###### training argument passing to drop object
          x = self.dense1(x)
          x = tf.math.l2_normalize(x,axis=1)
          return x

Для сохранения модели я использую:

model = Network()
modelCheckpoint = tf.train.Checkpoint(optimizer=optimizer, model=model,optimizer_step=tf.train.get_or_create_global_step())
lastModelSaver = tf.train.CheckpointManager(modelCheckpoint, directory=lastModelDir, max_to_keep=1)
bestModelSaver = tf.train.CheckpointManager(modelCheckpoint, directory=bestModelDir, max_to_keep=1)

...

if ep_mean_val_loss < best_val_loss:
        ######## BEST MODEL CHECKPOINT WRITTING
        bestModelSaver.save()

А для восстановления модели:

model = Network()
modelCheckpoint = tf.train.Checkpoint(model=model)
bestModelSaver = tf.train.CheckpointManager(modelCheckpoint, directory=bestModelDir, max_to_keep=1)

modelCheckpoint.restore(bestModelSaver.latest_checkpoint) ##### restore last model latest checkpoint

Я использую tenorflow 1.13 в режиме активного исполнения. Я новичок в использовании API подклассов моделей и нетерпеливого режима выполнения в tenorflow, так что, возможно, я что-то упустил.

Любая помощь будет высоко ценится.

...