Установите для обучения значение False для данных проверки, используемых в функции генератора Keras Fit - PullRequest
0 голосов
/ 01 июля 2019

Как мне установить training на False для validation_data, используемого в Keras fit_generator? У меня есть Dropout слоев в моей модели, и я хочу, чтобы обучение было True во время обучения и False во время проверки и тестирования.

Ответы [ 2 ]

0 голосов
/ 01 июля 2019

Keras автоматически устанавливает learning_phase в False при выполнении проверки.Нет ничего лишнего, что вам нужно сделать.

Отбрасывающие узлы автоматически проверяют, находятся ли они в режиме обучения.

https://github.com/keras-team/keras/blob/master/keras/layers/core.py#L126

Если вы хотите убедиться, что Keras автоматическиизменяет флаг режима обучения, вы можете выполнить код ниже.Он добавляет лямбда-слой, который добавляет тензор печати к графику, который выводит разные сообщения в каждом случае.

from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K

def inspect(x):
  xp = K.in_train_phase(K.print_tensor(x, message='train x:'),
                        K.print_tensor(x, message='test x:'))
  return xp

def make_model():
  inp = Input(shape=(4,))
  h1 = Dense(2)(inp)
  h1p = Lambda(inspect)(h1)
  out = Dense(1)(h1p)
  model = Model(inp, out)
  model.compile('adam', 'mse')
  return model

model = make_model()
model.summary()


import numpy as np

X_train = np.random.rand(1, 4)
Y_train = np.random.rand(1, 1)
X_test = np.random.rand(1, 4)
Y_test = np.random.rand(1, 1)

model.fit(X_train, Y_train, validation_data=(X_test, Y_test))
0 голосов
/ 01 июля 2019

Я добавил аргумент is_training в свою функцию генератора данных.Если это True, я установил фазу обучения Keras на 1, в противном случае на 0 (см. Документацию на keras.io/backend/):

if is_training:
    K.set_learning_phase( 1 )
else:
    K.set_learning_phase( 0 )

Итак, для своего генератора тренировочных данных я использую is_training = True, и для моего генератора проверочных данных я использую is_training = False.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...