Проверка в обучении Кераса - PullRequest
0 голосов
/ 07 мая 2020

Я новичок в Keras и машинном обучении в целом, и я тренирую такую ​​модель:

history = model.fit_generator(flight_generator(train_files_train, 4), steps_per_epoch=500, epochs=50)

Где flight_generator - это функция, которая подготавливает обучающие данные и форматирует их, а затем дает это обратно к модели, чтобы соответствовать. это отлично работает, поэтому теперь я хочу добавить некоторую проверку, и после долгих поисков в Интернете я все еще не знаю, как это реализовать.

Я бы предположил что-то вроде:

history = model.fit_generator(flight_generator(train_files_train, 4), steps_per_epoch=500, epochs=50, validation_data=flight_generator(train_files_cv, 4))

Но когда я запускаю код, он просто зависает в первую эпоху. Что мне не хватает?

РЕДАКТИРОВАТЬ:

Код для flight_generator:

def flight_generator(files, batch_size):

    while True:
          batch_inputs  = numpy.random.choice(a    = files, 
                                          size = batch_size)
          batch_input_X = []
          batch_input_Y = []
          c=0
          for batch_input in batch_inputs:
            # reshape into X=t and Y=t+1
            trainX, trainY = create_dataset(batch_input, look_back)
            # reshape input to be [samples, time steps, features]
            trainX = numpy.reshape(trainX, (trainX.shape[0], 1, trainX.shape[1]))

            if c is 0:
              batch_input_X = trainX
              batch_input_Y = trainY

            else:
              batch_input_X = numpy.concatenate((batch_input_X, trainX), axis = 0)
              batch_input_Y = numpy.concatenate((batch_input_Y, trainY), axis = 0)

            c += 1


          # Return a tuple of (input) to feed the network

          batch_x = numpy.array( batch_input_X )
          batch_y = numpy.array( batch_input_Y )


          yield(batch_x, batch_y)

Ответы [ 2 ]

0 голосов
/ 08 мая 2020

Думаю, вам следует использовать model.fit (........) Не пытайтесь использовать генератор, если он вам действительно не нужен. В любом коде, который я видел, model.fit () выполняет магию c

Пожалуйста, обратитесь к документации Keras для fit () https://keras.io/api/models/sequential/ И, пожалуйста, укажите оптимизатор и метрики

0 голосов
/ 07 мая 2020

Ваш validation_data должен быть в формате кортежа. Поэтому попробуйте изменить его:

history = model.fit_generator(flight_generator(train_files_train, 4), steps_per_epoch=500, epochs=50,batch_size=32,validation_data=(flight_generator(train_files_cv, 4)))

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...