Поезд CNN после автоэнкодера - PullRequest
0 голосов
/ 21 сентября 2018

У меня есть обученный автоэнкодер, который я хочу использовать, чтобы уменьшить размерность изображения, а затем обучить CNN, используя закодированное изображение.Как я тренирую свой CNN, используя закодированные изображения?Я хочу использовать генератор соответствия для возврата закодированного изображения вместе с соответствующей меткой.

def custom_generator(generator):
    for data, labels in generator:
        data=encoder.predict(data)
        yield data, labels
model.fit_generator(custom_generator(train_generator), steps_per_epoch=num_train_steps, epochs=25,validation_data=custom_generator(validation_generator), validation_steps=num_valid_steps)

Это ошибка, которую я получаю:

     Epoch 1/25
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    <ipython-input-33-54d2c6697155> in <module>()
    ----> 1 model.fit_generator(custom_generator(train_generator), steps_per_epoch=num_train_steps, epochs=25,validation_data=custom_generator(validation_generator), validation_steps=num_valid_steps)

    /usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
         89                 warnings.warn('Update your `' + object_name +
         90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
    ---> 91             return func(*args, **kwargs)
         92         wrapper._original_function = func
         93         return wrapper

    /usr/local/lib/python3.6/dist-packages/keras/models.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
       1313                                         use_multiprocessing=use_multiprocessing,
       1314                                         shuffle=shuffle,
    -> 1315                                         initial_epoch=initial_epoch)
  1316 
   1317     @interfaces.legacy_generator_methods_support

/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   2192                 batch_index = 0
   2193                 while steps_done < steps_per_epoch:
-> 2194                     generator_output = next(output_generator)
   2195 
   2196                     if not hasattr(generator_output, '__len__'):

/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in get(self)
    791             success, value = self.queue.get()
    792             if not success:
--> 793                 six.reraise(value.__class__, value, value.__traceback__)

/usr/local/lib/python3.6/dist-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None
/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py in _data_generator_task(self)
    656                             # => Serialize calls to
    657                             # infinite iterator/generator's next() function
--> 658                             generator_output = next(self._generator)
    659                             self.queue.put((True, generator_output))
    660                         else:

<ipython-input-32-81cd29d5c219> in custom_generator(generator)
      1 def custom_generator(generator):
----> 2   for data, labels in generator:
      3     data=encoder.predict(data)
      4     yield data, labels

ValueError: too many values to unpack (expected 2)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...