Я использую функциональную версию модели keras для обучения сети автоэнкодеров и использую функцию model.fit () из keras следующим образом:
model.fit(train_data, train_ground_truth_data, validation_data = (validation_data, validation_ground_truth_data), epochs=1000, batch_size=40, callbacks=callbacks)
Я смог тренировать модель таким образом. В это время я сразу конвертировал данные в тип данных float32, используя следующую команду:
train_data = train_data.astype ('float32') / 255.
Поскольку у меня большой объем данных, из-за которого ОЗУ не хватает памяти при преобразовании в float32 , я решил преобразовать его в пакетном режиме и использовать fit_generator для обучения моей модели. Следующий код загружает изображение и его основную правду в пакетном режиме.
def imageLoader(data, ground_truth_data, batch_size):
L = data.shape[0] # there are equal numbers of images in data and its ground_truth_data
#this line is just to make the generator infinite, keras needs that
while True:
batch_start = 0
batch_end = batch_size
while batch_start < L:
limit = min(batch_end, L)
#print(limit)
X = data[batch_start:limit, :].astype('float32') / 255.
Y = ground_truth_data[batch_start:limit, :].astype('float32') / 255.
yield (X,Y) #a tuple with two numpy arrays with batch_size samples
batch_start += batch_size
batch_end += batch_size
Я использую следующую строку кода для обучения, используя функцию fit_generator из keras.
model.fit_generator(imageLoader(train_data,train_ground_truth_data, batch_size), validation_data = imageLoader(validation_data, validation_ground_truth_data, batch_size), epochs=1000, steps_per_epoch = math.ceil(num_samples / batch_size), validation_steps = math.ceil(num_validation_samples / batch_size), callbacks=callbacks)
Я хотел проверить, что этот метод обучения с использованием fit_generator () будет работать так же, как функция fit (), я использовал небольшой набор данных и протестировал его, используя оба эти метода. Но проблема в том, что модель не тренируется, когда я использую fit_generator. Это было правильно, когда я использовал функцию model.fit_generator () с помощью функции imageLoader (). Потеря валидации и обучения постоянна после 2 эпох.
Разве это не правильный способ сделать это с помощью fit_generator ()?