vaidation_data должен быть кортежем в keras fit_generator - PullRequest
2 голосов
/ 28 марта 2019

Я пытаюсь переопределить keras.util.Sequence с моим классом SequenceGenerator(Sequence) и передать его fit_generator, но fit_generator выдал ошибку ValueError

Вот мой пользовательский класс

import os import numpy as np из keras.utils Импорт последовательности из batchGenerator import BatchGenerator

из настроек import batch_size, train_folder, test_folder

class SequenceGenerator(Sequence):
    def __init__(self, batches_folder):
        self.batch_generator = BatchGenerator(folder_name=batches_folder)      
        self.names = [f for f in os.listdir(batches_folder) if f.lower().endswith('.jpg')]

    def __len__(self):
        return int(np.ceil(len(self.names) / float(batch_size)))

    def __getitem__(self, idx):
        print('Getting a bacth{0}'.format(idx))
        [X_batch, Y_batch] = self.batch_generator.load_batch_from_disk(idx)
        return X_batch, Y_batch


def train_seq_genenrator():
    return SequenceGenerator(train_folder)


def test_seq_generator():
    return SequenceGenerator(test_folder)

и в блокнот jupyter, Iимпортировал следующее

from sequenceGenerator import train_seq_genenrator, test_seq_generator

наконец, вот вам вызов fit_generator

history = new_model.fit_generator(train_seq_genenrator()
                        , steps_per_epoch=num_train_samples // batch_size                        
                        , validation_data=test_seq_generator()
                        , validation_steps=num_test_samples // batch_size
                        , epochs=epochs
                        , shuffle=True)

Я получил следующую ошибку:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-63-347bef86c8c0> in <module>()
      4                         , validation_steps=num_test_samples // batch_size
      5                         , epochs=epochs
----> 6                         , shuffle=True)

~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1759         use_multiprocessing=use_multiprocessing,
   1760         shuffle=shuffle,
-> 1761         initial_epoch=initial_epoch)
   1762 
   1763   def evaluate_generator(self,

~\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\keras\engine\training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    121             '`validation_data` should be a tuple '
    122             '`(val_x, val_y, val_sample_weight)` '
--> 123             'or `(val_x, val_y)`. Found: ' + str(validation_data))
    124       val_x, val_y, val_sample_weights = model._standardize_user_data(
    125           val_x, val_y, val_sample_weight)

ValueError: `validation_data` should be a tuple `(val_x, val_y, val_sample_weight)` or `(val_x, val_y)`. Found: <sequenceGenerator.SequenceGenerator object at 0x000001DCB58259B0>

Не знаю, почему этослучается, однако это найдено в https://keras.io/models/sequential/

validation_data : Это может быть либо

  • генератор или объект Sequence для данных проверки
  • кортеж (x_val, y_val)
  • кортеж (x_val, y_val, val_sample_weights)
...