InvalidArgumentError при тренировке с TripletSemiHardLoss и последовательностью керас - PullRequest
0 голосов
/ 22 апреля 2020

Я тренирую модель, используя функцию TripletSemiHardLoss, предоставленную пакетом tenorflow_addons. Я загружаю свои данные с помощью пользовательского объекта Seras Sequence, код обучения следующий:

train_generator = LSTMGenerator(train_df, ONE_HOT, config)

X, y = train_generator[0]
print(X.shape)
print(y.shape)

history = model.fit(train_generator, epochs=100)


Результат выполнения этого кода:

(50, 7, 144, 180, 3) # Shape of X
(50,) # Shape of y

...

InvalidArgumentError:  Can not squeeze dim[0], expected a dimension of 1, got 50
     [[node loss_5/lambda_2_loss/TripletSemiHardLoss/weighted_loss/Squeeze (defined at /home/user/.local/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3009) ]] [Op:__inference_keras_scratch_graph_33244]

Я выбрал размер партии из 50, но по какой-то причине TripletSemiHardLoss, похоже, ожидает batch_size, равный 1. Не сталкивались ли некоторые из вас с подобного рода проблемами?

...