Ошибка использования fit_generator с сиамской сетью - PullRequest
0 голосов
/ 30 декабря 2018

Я пытаюсь адаптировать Сиамский пример Keras MNIST для использования генератора.

На примере имеем:

model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
          batch_size=128,
          epochs=epochs,
          validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

Пытаясь выяснить форму, которую должен вернуть генератор, я сделал:

np.array([tr_pairs[:, 0], tr_pairs[:, 1]]).shape

и получил

(2, 108400, 28, 28)

Мой генератор затем возвращает это:

(data, labels) = my_generator
data.shape
(2, 6, 300, 300, 3)
labels.shape
(6,)

Итак, это два массива (для входов NN) с 6 изображениями (batch_size) размером 300x300x3 (RGB).

Ниже приведено использование fit_generator():

...
input_shape = (300, 300, 3)
...
model.fit_generator(kbg.generate(set='train'), 
                    steps_per_epoch=training_steps,
                    epochs=1,
                    verbose=1,
                    callbacks=[],
                    validation_data=kbg.generate(set='test'),
                    validation_steps=validation_steps,
                    use_multiprocessing=False,
                    workers=0)  

Я полагаю, что я кормлю NN той же формы, но получаю следующую ошибку:

ValueError: Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 2 array(s), but instead gotthe following list of 1 arrays: [array([[[[[0.49803922, 0.48235294, 0.55686275],
          [0.63137255, 0.61176471, 0.64313725],
          [0.8627451 , 0.84313725, 0.84313725],
          ...,
          [0.58823529, 0.64705882, 0.631...

Что не так?

1 Ответ

0 голосов
/ 30 декабря 2018

Поскольку модель имеет два входных слоя , генератор должен выдать список из двух массивов в качестве входных выборок, соответствующих двум входным слоям, например:

def my_generator(args):
    # ...
    yield [first_pair, second_pair], labels

, где first_pair и second_pair оба имеют форму (n_samples, 300, 300, 3).

...