Как заставить генератор генерировать набор данных из кортежей (x, y), где x и y являются массивами? - PullRequest
0 голосов
/ 12 мая 2019

Я пытаюсь реализовать нейронную сеть, используя функцию keras.fit_generator ().Я реализовал функцию генератора, которая выдает x, y, где x - данные, а y - основную правду для результата, и оба являются (156, 156, 156) массивами-пустышками.Однако, когда я пытаюсь передать данные, я получаю сообщение об ошибке «Выход генератора должен быть кортежем (x, y, sample_weight) или (x, y). Найдено: tf.Tensor».

Когда я проверяю, что я получаюкогда я создаю свой набор данных с использованием функции генератора и перебираю его, и я получаю действительно tf.Tensor.Однако я не мог понять, как заставить его возвращать кортеж вместо tf.Tensor с формой (1, 2, 156, 156, 156).Что я должен сделать, чтобы получить кортеж (x, y)?

Для простоты я использовал следующую функцию генератора, которая должна давать (x, y):

def tuple_fun():
    for _ in range(10):
        x = np.random.rand(156,156,156)
        y = np.random.rand(156,156,156)
        yield tuple((x, y))

Я сгенерировал набор данных со следующим фрагментом кода:

def dataset_generator(batch_size):
    dataset = tf.data.Dataset.from_generator(lambda: tuple_fun(), 
                                               output_types=tf.int8, 
                                               output_shapes = (2, 156, 156, 156)).batch(batch_size)
    return dataset

Затем я попытался передать данные в нейронную сеть через keras.fit_generator () следующим образом:

test_batch_size = 1

test_dataset = dataset_generator(test_batch_size)

iterator = iter(test_dataset)

model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv3D(input_shape=(156, 156, 156, 1),
            filters=5, padding="same", kernel_size=3,
            activation='relu',
            kernel_initializer=tf.keras.initializers.TruncatedNormal()))

model.compile(loss='mean_squared_error', optimizer=tf.keras.optimizers.Adam(decay=0.002))
model.fit_generator(iterator, steps_per_epoch=1, epochs=1)

И яполучил следующую ошибку:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-97-e1055a8161e6> in <module>
     26 
     27 model.compile(loss='mean_squared_error', optimizer=tf.keras.optimizers.Adam(decay=0.002))
---> 28 model.fit_generator(iterator, steps_per_epoch=1, epochs=1)

~/anaconda3/envs/condaEnv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1513         shuffle=shuffle,
   1514         initial_epoch=initial_epoch,
-> 1515         steps_name='steps_per_epoch')
   1516 
   1517   def evaluate_generator(self,

~/anaconda3/envs/condaEnv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
    211     step = 0
    212     while step < target_steps:
--> 213       batch_data = _get_next_batch(generator, mode)
    214       if batch_data is None:
    215         if is_dataset:

~/anaconda3/envs/condaEnv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py in _get_next_batch(generator, mode)
    363       raise ValueError('Output of generator should be '
    364                        'a tuple `(x, y, sample_weight)` '
--> 365                        'or `(x, y)`. Found: ' + str(generator_output))
    366 
    367   if len(generator_output) < 1 or len(generator_output) > 3:

ValueError: Output of generator should be a tuple `(x, y, sample_weight)` or `(x, y)`. Found: tf.Tensor(
[[[[[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   ...

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]]


  [[[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   ...

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]

   [[0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    ...
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]
    [0 0 0 ... 0 0 0]]]]], shape=(1, 2, 156, 156, 156), dtype=int8)

Как я могу решить эту проблему?Это был мой первый вопрос в stackoverflow, поэтому я надеюсь, что не сделал ничего плохого.Заранее спасибо за любую помощь!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...