Я пытаюсь реализовать нейронную сеть, используя функцию keras.fit_generator ().Я реализовал функцию генератора, которая выдает x, y, где x - данные, а y - основную правду для результата, и оба являются (156, 156, 156) массивами-пустышками.Однако, когда я пытаюсь передать данные, я получаю сообщение об ошибке «Выход генератора должен быть кортежем (x, y, sample_weight)
или (x, y)
. Найдено: tf.Tensor».
Когда я проверяю, что я получаюкогда я создаю свой набор данных с использованием функции генератора и перебираю его, и я получаю действительно tf.Tensor.Однако я не мог понять, как заставить его возвращать кортеж вместо tf.Tensor с формой (1, 2, 156, 156, 156).Что я должен сделать, чтобы получить кортеж (x, y)?
Для простоты я использовал следующую функцию генератора, которая должна давать (x, y):
def tuple_fun():
for _ in range(10):
x = np.random.rand(156,156,156)
y = np.random.rand(156,156,156)
yield tuple((x, y))
Я сгенерировал набор данных со следующим фрагментом кода:
def dataset_generator(batch_size):
dataset = tf.data.Dataset.from_generator(lambda: tuple_fun(),
output_types=tf.int8,
output_shapes = (2, 156, 156, 156)).batch(batch_size)
return dataset
Затем я попытался передать данные в нейронную сеть через keras.fit_generator () следующим образом:
test_batch_size = 1
test_dataset = dataset_generator(test_batch_size)
iterator = iter(test_dataset)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv3D(input_shape=(156, 156, 156, 1),
filters=5, padding="same", kernel_size=3,
activation='relu',
kernel_initializer=tf.keras.initializers.TruncatedNormal()))
model.compile(loss='mean_squared_error', optimizer=tf.keras.optimizers.Adam(decay=0.002))
model.fit_generator(iterator, steps_per_epoch=1, epochs=1)
И яполучил следующую ошибку:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-97-e1055a8161e6> in <module>
26
27 model.compile(loss='mean_squared_error', optimizer=tf.keras.optimizers.Adam(decay=0.002))
---> 28 model.fit_generator(iterator, steps_per_epoch=1, epochs=1)
~/anaconda3/envs/condaEnv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1513 shuffle=shuffle,
1514 initial_epoch=initial_epoch,
-> 1515 steps_name='steps_per_epoch')
1516
1517 def evaluate_generator(self,
~/anaconda3/envs/condaEnv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
211 step = 0
212 while step < target_steps:
--> 213 batch_data = _get_next_batch(generator, mode)
214 if batch_data is None:
215 if is_dataset:
~/anaconda3/envs/condaEnv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py in _get_next_batch(generator, mode)
363 raise ValueError('Output of generator should be '
364 'a tuple `(x, y, sample_weight)` '
--> 365 'or `(x, y)`. Found: ' + str(generator_output))
366
367 if len(generator_output) < 1 or len(generator_output) > 3:
ValueError: Output of generator should be a tuple `(x, y, sample_weight)` or `(x, y)`. Found: tf.Tensor(
[[[[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
...
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]]
[[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
...
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]]]], shape=(1, 2, 156, 156, 156), dtype=int8)
Как я могу решить эту проблему?Это был мой первый вопрос в stackoverflow, поэтому я надеюсь, что не сделал ничего плохого.Заранее спасибо за любую помощь!