Обучающая модель Keras с fit_generator и TFrecords - PullRequest
0 голосов
/ 19 марта 2019

Я бы хотел потренировать свою ConvNet с Керасом. После некоторых уроков я написал что-то вроде этого.

Я не знаю, хорошо ли это, в частности, у меня есть некоторые сомнения по поводу использования генераторов для обучения моей модели.

Раньше я кормил процедуру train генераторами numpy-array, но я читал, что для улучшения производительности можно использовать tfrecords.

В первый раз я преобразовал тензоры в пустые массивы в функции create_dataset ниже (до их "выдачи"), но позже прочитал, что

Действительно, есть более эффективный способ использования набора данных без необходимости преобразовать тензоры в массивы NumPy.

поэтому я попытался отредактировать свой код в этом смысле, используя input_image=tf.keras.Input(tensor=x) и model.compile(optimizer=optimizer, loss=compute_loss, target_tensors=[y]).

Раньше я не использовал ни target_tensors внутри model.compile, ни tensor=x внутри tf.keras.Input (я только указал форму ввода).

import tensorflow as tf
import keras
import compute_loss #my loss function

dataset_train_path="dataset_train.tfrecords"
dataset_val_path="dataset_val.tfrecords"

filepath_checkpoint="weights-best.hdf5"


Adam=tf.keras.optimizers.Adam
optimizer = Adam(lr=0.00001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)

BATCH_SIZE=32
TRAINING_SIZE=5717
VALIDATION_SIZE=5823
TRAINING_STEPS=TRAINING_SIZE//BATCH_SIZE
VALIDATION_STEPS=VALIDATION_SIZE//BATCH_SIZE

"""-----------------Here I define my generator-----------------"""
def _parse_function(proto):

    keys_to_features = {'image': tf.FixedLenFeature([], tf.string),
                        'label': tf.FixedLenFeature([], tf.string)}


    parsed_features = tf.parse_single_example(proto, keys_to_features)

    parsed_features['image'] = tf.decode_raw(parsed_features['image'], tf.float16)
    parsed_features['label'] = tf.decode_raw(parsed_features['label'], tf.float16)
    return parsed_features['image'], parsed_features["label"]


def create_dataset(filepath, batch_size=BATCH_SIZE):

    dataset = tf.data.TFRecordDataset(filepath)

    dataset = dataset.map(_parse_function, num_parallel_calls=8)
    dataset = dataset.repeat()
    dataset = dataset.shuffle(100)
    dataset = dataset.batch(BATCH_SIZE)

    iterator = dataset.make_one_shot_iterator()
    image, label = iterator.get_next()

    image = tf.reshape(image, [BATCH_SIZE, 416, 416, 3])
    label = tf.reshape(label, [BATCH_SIZE, 75, 25])

    while True:
        yield image, label

"""-----------------Here I create my train/val generators-----------------"""
training_generator=create_dataset(dataset_train_path)
validation_generator=create_dataset(dataset_val_path)

"""-----------------Now I can define my model-----------------"""
x,y=next(training_generator);
def net():
    input_image=tf.keras.Input(tensor=x)
    inputs=tf.keras.layers.Conv2D(16,3,padding='same', activation='relu', name='conv_1')(input_image)
    inputs=tf.keras.layers.BatchNormalization(name='norm_1')(inputs)
    ...
    ...
    outputs = tf.keras.layers.Conv2D(75, 1, name='conv_13')(inputs)
    model = tf.keras.Model(inputs=input_image, outputs=outputs)
    return model

if __name__ == '__main__':
    model=net()
    model.compile(optimizer=optimizer, loss=compute_loss, target_tensors=[y])
    model.fit_generator(generator=training_generator,validation_data=validation_generator, epochs=1000, max_queue_size=1000, steps_per_epoch=TRAINING_STEPS, validation_steps=VALIDATION_STEPS, callbacks=callbacks_list)

Сейчас обучение идет быстро, но я подозреваю, что где-то есть ошибки. Не могли бы вы помочь мне?

РЕДАКТИРОВАТЬ: Если я помещаю непосредственно наборы данных в fit_generator, я получаю следующее:

>>> train(model, DataGenerator, filepath_checkpoint="weights-best-tiny-test.hdf5")
Epoch 1/5000
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "tiny.py", line 239, in train
    model.fit_generator(generator=training_generator,validation_data=validation_generator, epochs=5000, max_queue_size=100, steps_per_epoch=TRAINING_STEPS, validation_steps=VALIDATION_STEPS, callbacks=callbacks_list)
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1586, in fit_generator
    steps_name='steps_per_epoch')
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\engine\training_generator.py", line 211, in model_iteration
    batch_data = _get_next_batch(output_generator, mode)
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\engine\training_generator.py", line 323, in _get_next_batch
    generator_output = next(output_generator)
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 767, in get
    six.reraise(*sys.exc_info())
  File "C:\Program Files\Python35\lib\site-packages\six.py", line 693, in reraise
    raise value
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 743, in get
    inputs = self.queue.get(block=True).get()
  File "C:\Program Files\Python35\lib\multiprocessing\pool.py", line 644, in get
    raise self._value
  File "C:\Program Files\Python35\lib\multiprocessing\pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "C:\Program Files\Python35\lib\site-packages\tensorflow\python\keras\utils\data_utils.py", line 680, in next_sample
    return six.next(_SHARED_SEQUENCES[uid])
TypeError: 'Iterator' object is not an iterator
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...