Я пытаюсь обучить сеть pix2pix в Google Colab с использованием Tensorflow 2.0, и я использую tf.data.Dataset для импорта данных, которые являются изображениями (всего около 4900).Тем не менее, так как это много изображений для обучения, я разбил набор данных на партии с помощью функции .batch ().Однако я не знаю, как использовать во время тренировки только одну из партий, чтобы не загружать весь набор данных в ОЗУ.В общем, я следую приведенному здесь коду https://www.tensorflow.org/beta/tutorials/generative/pix2pix
Когда я передаю весь набор данных в своей функции Train для итерации, ОЗУ сходит с ума, а иногда и весь ноутбук падает.Я не уверен, как передать случайную партию для каждой эпохи, чтобы решить эту проблему.
Так я загружаю свои данные.В конце каждая партия имеет форму (64 256 256,3).
train_dataset = tf.data.Dataset.list_files(workpath + '/train/*.jpg')
train_dataset = train_dataset.shuffle(4900, seed=23).take(4900)
train_dataset = train_dataset.map(load_image_train,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(64)
Моя функция поезда выглядит следующим образом.
def train(dataset, epochs):
for epoch in range(epochs):
for input_image, target in dataset:
gen, disc = train_step(input_image, target)
print('Gen loss: {} Disc loss: {}'.format(gen, disc))
train(train_dataset, 60)
Если я передам весь набор train_datase вФункция поезда будет проходить через все партии в каждую эпоху.Я пытаюсь найти способ передать разные партии для каждой эпохи.Единственный способ сделать это - переставить пакеты и выбрать один из всего набора данных для каждой эпохи, но это на самом деле не решает проблему с памятью, так как мне все равно нужно передать весь набор данных.способ сделать это?