В настоящее время я пытаюсь реализовать сеть, которая обучает триплетам изображений. Для этого я адаптировал генератор пар, который я нашел в Inte rnet:
def triplet_generator(triples, image_cache, datagens, batch_size=32):
while True:
# shuffle once per batch
indices = np.random.permutation(np.arange(len(triples)))
num_batches = len(triples) // batch_size
for bid in range(num_batches):
batch_indices = indices[bid * batch_size : (bid + 1) * batch_size]
batch = [triples[i] for i in batch_indices]
X1 = np.zeros((batch_size, 64, 64, 3))
X2 = np.zeros((batch_size, 64, 64, 3))
X3 = np.zeros((batch_size, 64, 64, 3))
for i, (image_filename_l, image_filename_m, image_filename_r) in enumerate(batch):
if datagens is None or len(datagens) == 0:
X1[i] = image_cache[image_filename_l]
X2[i] = image_cache[image_filename_m]
X3[i] = image_cache[image_filename_r]
else:
X1[i] = datagens[0].random_transform(image_cache[image_filename_l])
X2[i] = datagens[1].random_transform(image_cache[image_filename_m])
X3[i] = datagens[2].random_transform(image_cache[image_filename_r])
yield [np.array(X1), np.array(X2), np.array(X3)]
, который используется для обучения моей сети:
base_network = create_base_network(input_shape)
print(base_network.summary())
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
input_c = Input(shape=input_shape)
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
processed_c = base_network(input_c)
merged_vector = concatenate([processed_a, processed_b, processed_c], axis=-1, name='merged_layer')
model = Model([input_a, input_b, input_c], merged_vector)
checkpoint = ModelCheckpoint(filepath=BEST_MODEL_FILE, save_best_only=True)
# train
rms = RMSprop()
model.compile(loss=loss_desc_triplet, optimizer=rms, metrics=[accuracy])
history = model.fit_generator(train_pair_gen,
steps_per_epoch=num_train_steps,
epochs=epochs,
validation_data=val_pair_gen,
validation_steps=num_val_steps,
callbacks=[checkpoint])
Однако я получаю это сообщение об ошибке из этого:
Трассировка (последний последний вызов):
Файл "DeepLearningWithAugmentationWithTriplets.py", строка 256, в обратных вызовах = [контрольная точка])
Файл "lib / python3 .7 / site-packages / keras / legacy / interfaces.py", строка 91, в оболочке возвращает удовольствие c (* args, ** kwargs)
Файл "lib / python3 .7 / site-packages / keras / engine / training.py ", строка 1418, в fit_generator initial_epoch = initial_epoch)
Файл" lib / python3 .7 / site-packages / keras / engine / training_generator.py ", строка 217, в fit_generator class_weight = class_weight)
Файл" lib / python3 .7 / site-packages / keras / engine / training.py ", строка 1211, в train_on_batch class_weight = class_weight)
Файл "lib / python3 .7 / site-packages / keras / engine / training.py", строка 751, в _standardize_user_data exception_prefix = 'input')
Файл "lib / python3 .7 / site-packages / keras / engine / training_utils.py", строка 102, в standardize_input_data str (len (data)) + 'arrays:' + str (data) [: 200] + '...')
ValueError: Ошибка при проверке ввода модели: список Numpy массивов, передаваемых в вашу модель, не соответствует размеру, ожидаемому моделью. Ожидается увидеть 3 массива (ов), но вместо этого получен следующий список из 1 массива: [array ([[[[0.36862746, 0.36862746, 0.36862746], [0.36862746, 0.36862746, 0.36862746], [0.36862746, 0.36862746, 0.36862746],. .., [0.41176471, 0.41176471, 0.41176471 ...
Тем не менее, когда я печатаю следующий элемент генератора, это список из трех массивов:
train_pair_gen = triplet_generator(triples_data_train, image_cache, datagens, batch_size)
[X1, X2, X3] = next(train_pair_gen)
print(X1.shape, X2.shape, X3.shape) --> (32, 64, 64, 3) (32, 64, 64, 3) (32, 64, 64, 3)
print("###")
print(len(next(train_pair_gen))) --> 3
Что я делаю неправильно?
Определение сети:
_________________________________________________________________ Layer (type) Параметр выходной формы #
=============== ================================================== input_131 (InputLayer) (Нет, 64, 64, 3) 0
_________________________________________________________________ conv2d_100 (Conv2D) (Нет, 29, 29, 32) 4736
_________________________________________________________________ conv2d_101 (Conv2D) (Нет, 8, 8, 64) 73792
_________________________________________________________________ conv2d_102 (Conv2D) (нет, 1, 1, 128) 204928 * 10 41 * _________________________________________________________________ flatten_34 (Flatten) (Нет, 128) 0
======================================== =============================== Всего параметров: 283 456 Обучаемые параметры: 283 456 Необучаемые параметры: 0