Должны ли мы применять повтор, пакетное перемешивание к tf.data.Dataset при передаче его для соответствия функции? - PullRequest
0 голосов
/ 25 февраля 2020

Я до сих пор не прочитал документацию о tf.keras.Model.fit и tf.data.Dataset, при передаче tf.data.Dataset для соответствия функции, следует ли мне вызывать repeat и batch для объекта набора данных или я должен предоставить batch_size и epochs аргументы, чтобы соответствовать вместо? или оба? Должен ли я применить ту же обработку к набору проверки?

И пока я здесь, могу ли я shuffle набор данных до fit? (кажется, что это очевидно, да) Если это так, до, после вызова Dataset.batch и Dataset.repeat (при их вызове)?

Редактировать: При использовании аргумента batch_size и не позвонив Dataset.batch(batch_size) ранее, я получаю следующую ошибку:

ValueError: The `batch_size` argument must not be specified for the given input type.
Received input: <MapDataset shapes: ((<unknown>, <unknown>, <unknown>, <unknown>), (<unknown>, <unknown>, <unknown>)), 
types: ((tf.float32, tf.float32, tf.float32, tf.float32), (tf.float32, tf.float32, tf.float32))>, 
batch_size: 1

Спасибо

1 Ответ

1 голос
/ 26 февраля 2020

Есть разные способы сделать то, что вы хотите здесь, но я всегда использую:

batch_size = 32
ds = tf.Dataset()
ds = ds.shuffle(len_ds)
train_ds = ds.take(0.8*len_ds)
train_ds = train_ds.repeat().batch(batch_size)
validation_ds = ds.skip(0.8*len_ds)
validation_ds = train_ds.repeat().batch(batch_size)
model.fit(train_ds,
          steps_per_epoch = len_train_ds // batch_size,
          validation_data = validation_ds,
          validation_steps = len_validation_ds // batch_size,
          epochs = 5)

Таким образом, у вас есть доступ ко всем переменным и после подгонки модели, например, если вы хотите Чтобы визуализировать набор проверки, вы можете. На самом деле это невозможно с validation_split. Если вы удалите .batch(batch_size), вы должны удалить // batch_size s, но я бы оставил их, поскольку это прояснит, что происходит сейчас.

Вы всегда должны указывать эпохи.

Расчет длина вашего поезда / проверочных наборов требует от вас l oop над ними:

len_train_ds = 0
for i in train_ds:
  len_train_ds += 1

, если в форме tf.Dataset.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...