Как я могу сообщить Keras фазу обучения, когда я использую train_on_batch для обучения модели? - PullRequest
0 голосов
/ 09 января 2019

В моей модели есть выпадающие слои, поэтому я хочу, чтобы keras вычислил фазы обучения и тестирования, чтобы запустить или проигнорировать выпадающие слои, и я обнаружил, что K.set_learning_phase может оказать мне такую ​​услугу, но как я могу добавить его в обучение? а тестовые процессы? Мой код такой:

def discriminator(self):
    x_A = Input(shape=self.shape)
    x_B = Input(shape=self.shape)
    x = concatenate([x_A, x_B], axis=-1)
    self.model = Sequential()
    self.model.add(Dropout(0.5, input_shape=self.shape_double))
    self.model.add(LSTM(200, return_sequences=True, kernel_constraint=unit_norm()))
    self.model.add(Dropout(0.5))
    self.model.add(LSTM(200, return_sequences=True, kernel_constraint=unit_norm()))
    self.model.add(Dropout(0.5))
    self.model.add(Flatten())
    self.model.add(Dense(8, activation="softmax", kernel_constraint=unit_norm())

    label=self.model(x)

    return Model([x_A,x_B], label)
...
def train(self, epoch, batch_size):
    for epoch in range(epochs):
        for batch,train_A,train_B,train_label in enumerate(Load_train(batch_size)):
            Dloss = self.discriminator.train_on_batch([train_A,train_B],train_label)
            ...
def test(self,test_A,test_B,test_label):
    predicted_label_dist = self.discriminator.predict([test_A,test_B])
    ...

Любые предложения будут оценены. Спасибо.

1 Ответ

0 голосов
/ 09 января 2019

Керас самостоятельно определяет подходящий этап обучения по умолчанию, когда вы называете подгонку или прогнозирование. Следовательно, ваш отсев будет применяться только во время обучения, но не во время тестирования. Однако, если вы все же хотите настроить фазу обучения самостоятельно, то есть переписать поведение по умолчанию, вы можете сделать это следующим образом (из документации keras):

keras.backend.set_learning_phase(value) 

Где:

значение: значение фазы обучения, 0 или 1 (целые числа).

просто добавьте этот код в функцию обучения и тестирования.

...