Я работаю с 388 3D-изображениями МРТ, которые слишком велики, чтобы вместить память, доступную при обучении модели CNN, поэтому я решил создать генератор, который принимает пакеты изображений в память для обучения за раз и объединяет это с пользовательским ImageDataGenerator для 3D-изображений (загружен для github). Я пытаюсь предсказать единый тестовый балл (от 1 до 30) с помощью изображения МРТ. У меня есть следующий код генератора, и я не уверен, что он правильный:
x = np.asarray(img)
y = np.asarray(scores)
def create_batch(x, y, batch_size):
x, y = shuffle(x, y)
x_split, x_val, y_split, y_val = train_test_split(x, y, test_size=.05, shuffle=True)
x_batch, x_test, y_batch, y_test = train_test_split(x_split, y_split, test_size=.05, shuffle=True)
x_train, y_train = [], []
num_batches = len(x_batch)//batch_size
for i in range(num_batches):
x_train.append([x_batch[0:batch_size]])
y_train.append([y_batch[0:batch_size]])
return x_train, y_train, x_val, y_val, x_batch, y_batch, x_test, y_test, num_batches
epochs = 1
model = build_model(input_size)
x_train, y_train, x_val, y_val, x_batch, y_batch, x_test, y_test, num_batches = create_batch(x, y, batch_size)
train_datagen = customImageDataGenerator(shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
val_datagen = customImageDataGenerator()
validation_set = val_datagen.flow(x_val, y_val, batch_size=batch_size, shuffle=False)
def generator(batch_size, epochs):
for e in range(epochs):
print('Epoch', e+1)
batches = 0
images_fitted = 0
for i in range(num_batches):
training_set = train_datagen.flow(x_train[i][0], y_train[i][0], batch_size=batch_size, shuffle=False)
images_fitted += len(x_train[i][0])
total_images = len(x_batch)
print('number of images used: %s/%s' % (images_fitted, total_images))
history = model.fit_generator(training_set,
steps_per_epoch = 1,
#callbacks = [earlystop],
validation_data = validation_set,
validation_steps = 1)
model.load_weights('jesse_weights_13layers.h5')
batches += 1
yield history
if batches >= num_batches:
break
return model
def train_load_weights():
history = generator(batch_size, epochs)
for e in range(epochs):
for i in range(num_batches):
print(next(history))
model.save_weights('jesse_weights_13layers.h5')
for i in range(1):
print('Run', i+1)
train_load_weights()
Я не уверен, был ли генератор построен правильно или модель правильно обучается и не знаю, как проверьте, есть ли это. Если у кого-то есть совет, я был бы признателен! Код запускается, и вот часть обучения:
Run 1
Epoch 1
number of images used: 8/349
Epoch 1/1
1/1 [==============================] - 156s 156s/step - loss: 8.0850 - accuracy: 0.0000e+00 - val_loss: 10.8686 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4B4E848>
number of images used: 16/349
Epoch 1/1
1/1 [==============================] - 154s 154s/step - loss: 4.3460 - accuracy: 0.0000e+00 - val_loss: 4.5994 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x0000026899A96708>
number of images used: 24/349
Epoch 1/1
1/1 [==============================] - 148s 148s/step - loss: 4.1174 - accuracy: 0.0000e+00 - val_loss: 4.6038 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F2F488>
number of images used: 32/349
Epoch 1/1
1/1 [==============================] - 151s 151s/step - loss: 4.2788 - accuracy: 0.0000e+00 - val_loss: 4.6029 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F34D08>
number of images used: 40/349
Epoch 1/1
1/1 [==============================] - 152s 152s/step - loss: 3.9328 - accuracy: 0.0000e+00 - val_loss: 4.6057 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F57848>
number of images used: 48/349
Epoch 1/1
1/1 [==============================] - 154s 154s/step - loss: 3.9423 - accuracy: 0.0000e+00 - val_loss: 4.6077 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F4D888>
number of images used: 56/349
Epoch 1/1
1/1 [==============================] - 160s 160s/step - loss: 3.7610 - accuracy: 0.0000e+00 - val_loss: 4.6078 - val_accuracy: 0.0000e+00
<keras.callbacks.callbacks.History object at 0x00000269A4F3E4C8>
number of images used: 64/349