У меня есть модель Keras и данные, которые я загружаю в Pandas фрейм данных. В целях тестирования и отладки мой набор данных довольно скромный. Фактически, я могу обработать и загрузить весь набор данных в память и обучить модель без проблем.
Позже я собираюсь перейти к обучению модели на гораздо большем наборе данных, поэтому я написал следующее генератор:
augment_image = keras.preprocessing.image.ImageDataGenerator(
rotation_range = 20,
zoom_range = 0.1,
width_shift_range = 0.1,
height_shift_range = 0.1,
shear_range = 0.1,
horizontal_flip = True,
vertical_flip = True,
fill_mode = "nearest"
)
def data_generator(df, batch_size, augment = None):
df = df.sample(frac = 1).reset_index(drop = True) # Shuffle
num_rows = len(df.index)
batch_num = 0
while True:
index_low = min(batch_num * batch_size, num_rows - 1)
index_high = min(index_low + batch_size, num_rows)
batch_num += 1
if batch_num * batch_size > num_rows - 1:
batch_num = 0
subframe = df.iloc[index_low:index_high]
images = load_images(subframe, path, image_type)
if augment is not None:
images = augment.flow(np.array(images))
targets = subframe["target"]
yield (np.array(images), np.array(targets))
Функция «load_images» просто берет список имен файлов (то есть подкадр [«image_type»]) и загружает фактические изображения, связанные с этими именами файлов. Я знаю, что это (и все, что не является уникальным для генератора) работает, потому что, как я упоминал ранее, я могу обучать модель, просто загружая весь набор данных в память. (То есть обработайте и загрузите весь набор данных - фактические изображения и все - в одну переменную и передайте его в model.fit.)
Но когда я пытаюсь использовать вышеуказанный генератор для передачи данные для model.fit, вот так ...
history = model.fit(
data_generator(train_set, batch_size=32, augment=augment_image),
verbose = 2,
epochs = 1,
steps_per_epoch = len(train_set.index) // 32,
validation_data = data_generator(test_set, batch_size=32),
validation_steps = len(test_set.index) // 32,
callbacks = [checkpoint, early_stopping, tensorboard]
)
... он зависает минут 15, прежде чем окончательно завершиться с тем, что кажется абсурдным MemoryError:
Traceback (most recent call last):
File "datagen_test.py", line 193, in <module>
callbacks = [checkpoint, early_stopping, tensorboard]
File "C:\Users\Username\Anaconda3\envs\tensorflow2\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 819, in fit
use_multiprocessing=use_multiprocessing)
File "C:\Users\Username\Anaconda3\envs\tensorflow2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 235, in fit
use_multiprocessing=use_multiprocessing)
File "C:\Users\Username\Anaconda3\envs\tensorflow2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 593, in _process_training_inputs
use_multiprocessing=use_multiprocessing)
File "C:\Users\Username\Anaconda3\envs\tensorflow2\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 706, in _process_inputs
use_multiprocessing=use_multiprocessing)
File "C:\Users\Username\Anaconda3\envs\tensorflow2\lib\site-packages\tensorflow_core\python\keras\engine\data_adapter.py", line 747, in __init__
peek, x = self._peek_and_restore(x)
File "C:\Users\Username\Anaconda3\envs\tensorflow2\lib\site-packages\tensorflow_core\python\keras\engine\data_adapter.py", line 850, in _peek_and_restore
peek = next(x)
File "datagen_test.py", line 84, in data_generator
yield (np.array(images), np.array(targets))
File "C:\Users\Username\Anaconda3\envs\tensorflow2\lib\site-packages\keras_preprocessing\image\iterator.py", line 104, in __next__
return self.next(*args, **kwargs)
File "C:\Users\Username\Anaconda3\envs\tensorflow2\lib\site-packages\keras_preprocessing\image\iterator.py", line 116, in next
return self._get_batches_of_transformed_samples(index_array)
File "C:\Users\Username\Anaconda3\envs\tensorflow2\lib\site-packages\keras_preprocessing\image\numpy_array_iterator.py", line 148, in _get_batches_of_transformed_samples
dtype=self.dtype)
MemoryError: Unable to allocate 18.4 MiB for an array with shape (32, 224, 224, 3) and data type float32
Это кажется довольно небольшим объемом памяти, который невозможно выделить. (И снова у меня достаточно памяти, чтобы просто загрузить весь набор данных и обучить модель таким образом.) Что я делаю не так?