tf. keras fit_generator () застревает на validation_data - PullRequest
1 голос
/ 13 июля 2020

Я использую форму DataGenerator tf.keras.Sequence для загрузки моих данных партиями. Генератор данных возвращает numpy массивов изображений и масок. Когда я вызываю fit_generator (), похоже, что модель подходит для данных поезда, но застревает на данных проверки. Если я установил Validation_data = None, а затем запустил его, ошибки не будет. Я использую tenorflow 1.14, tf.keras 2.2.4

Вот фрагмент кода:

model = create_model()
optimizer = Adam(lr = 0.001)
model.compile(loss=loss, optimizer=optimizer, metrics=[dice_coefficient])

train_gen = DataGenerator(X_train, batch_size=1,  predict=False, shuffle=True)
val_gen = DataGenerator(X_val, batch_size=1,  predict=False, shuffle=True)

model.fit_generator(train_gen, validation_data = val_gen, callbacks = [checkpoint, reduce_lr, stop], epochs=1,  verbose=1)    

Вот ошибка, которую я получаю:

Use tf.where in 2.0, which has the same broadcast rule as np.where
19/20 [===========================>..] - ETA: 7s - loss: 3.7508 - dice_coefficient: 0.1282 
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-18-9915a3b43b57> in <module>
      1 model.fit_generator(generator=train_gen, validation_data=val_gen, epochs=1, 
      2                     callbacks = [checkpoint, reduce_lr, stop],
----> 3                     shuffle=True, verbose=1)
      4 

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1431         shuffle=shuffle,
   1432         initial_epoch=initial_epoch,
-> 1433         steps_name='steps_per_epoch')
   1434 
   1435   def evaluate_generator(self,

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\engine\training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
    320           verbose=0,
    321           mode=ModeKeys.TEST,
--> 322           steps_name='validation_steps')
    323 
    324       if not isinstance(val_results, list):

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\engine\training_generator.py in model_iteration(model, data, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch, mode, batch_size, steps_name, **kwargs)
    262 
    263       is_deferred = not model._is_compiled
--> 264       batch_outs = batch_function(*batch_data)
    265       if not isinstance(batch_outs, list):
    266         batch_outs = [batch_outs]

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\engine\training.py in test_on_batch(self, x, y, sample_weight, reset_metrics)
   1245       self._update_sample_weight_modes(sample_weights=sample_weights)
   1246       self._make_test_function()
-> 1247       outputs = self.test_function(inputs)  # pylint: disable=not-callable
   1248 
   1249     if reset_metrics:

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
   3290 
   3291     fetched = self._callable_fn(*array_vals,
-> 3292                                 run_metadata=self.run_metadata)
   3293     self._call_fetch_callbacks(fetched[-len(self._fetches):])
   3294     output_structure = nest.pack_sequence_as(

~\Anaconda3\envs\tf_gpu\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
   1456         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1457                                                self._handle, args,
-> 1458                                                run_metadata_ptr)
   1459         if run_metadata:
   1460           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

InvalidArgumentError: You must feed a value for placeholder tensor 'reshape_target' with dtype float and shape [?,?,?]
     [[{{node reshape_target}}]]

1 Ответ

0 голосов
/ 13 июля 2020

В настоящее время нет возможности использовать генератор для ваших данных проверки. Взгляните на документацию Tensowflow .

В целях проверки я на самом деле не вижу цели полноценного генератора, поскольку основная цель - измерить возможность обобщения вашей сети после заданной эпохи. Ваша сеть будет эффективной, если вы добавите данные проверки в целом, разделите их на пакеты и ограничите количество шагов.

...