Что на самом деле означает сброс в наборе данных Tensorflow 2? - PullRequest
0 голосов
/ 01 июня 2019

Я слежу за тензорным потоком 2 Кераса Документация .Моя модель выглядит следующим образом:

train_dataset = tf.data.Dataset.from_tensor_slices((np.array([_my_cus_func(i) for i in X_train]), y_train))
train_dataset = train_dataset.map(lambda vals,lab: _process_tensors(vals,lab), num_parallel_calls=4)
train_dataset = train_dataset.shuffle(buffer_size=10000)
train_dataset = train_dataset.batch(64,drop_remainder=True)
train_dataset = train_dataset.prefetch(1)
model=get_compiled_model()
model.fit(train_dataset, epochs=100)

В документации написано

Обратите внимание, что набор данных сбрасывается в конце каждой эпохи, поэтому его можно использовать повторно для следующей эпохи.

Если вы хотите запустить обучение только на определенном количестве пакетов из этого набора данных, вы можете передать аргумент steps_per_epoch, который указывает, сколько этапов обучения должна выполнить модель с использованием этого набора данных, прежде чем перейти к следующей эпохе..

Если вы сделаете это, набор данных не сбрасывается в конце каждой эпохи, вместо этого мы просто продолжаем рисовать следующие партии.В наборе данных со временем закончатся данные (если только это не набор данных с бесконечной петлей).

Что на самом деле означает сброс?Будет ли тензорный поток считывать данные из тензорных срезов после каждой эпохи?или он только перетасовывает и запускает map функцию?Я хочу, чтобы тензорный поток считывал данные с numpy после эпохи и запускал _my_cus_func.Я скорее могу передать _my_cus_func на dataset map or apply api, но мне удобнее делать это в списке Python или массиве numpy.

1 Ответ

0 голосов
/ 01 июня 2019

В этом контексте сброс означает начать итерацию по набору данных с нуля. В вашем конкретном случае в коде отсутствует функция repeat(). Итак, если вы укажете steps_per_epoch параметр, подобный этому

model.fit(train_dataset, steps_per_epoch=N, epochs=100)

Он будет перебирать набор данных для N шагов, если N меньше фактического количества примеров, он прекращает обучение. Если N больше, он завершит одну эпоху, но все равно завершится, когда закончатся данные. Если вы добавите повторить,

train_dataset = train_dataset.shuffle(buffer_size=10000).repeat()

Он начнет новый цикл над набором данных, когда будет достигнуто фактическое количество примеров, а не когда начнется новая эпоха.

...