Я обнаружил странную ошибку при создании собственного слоя keras. Я пытаюсь создать собственный слой, который очень похож на слой GRU, но принимает дополнительные входные данные в дополнение к sampled_z, чтобы выполнить принудительное использование учителем в вариационном автоэнкодере.
Я успешно создаю модель VAE, как показано ниже, где terminal_GRU означает кастомный слой ГРУ.
Model: "VAE"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoder_input (InputLayer) (None, 80, 69) 0
__________________________________________________________________________________________________
conv1d_1 (Conv1D) (None, 80, 9) 5598 encoder_input[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D) (None, 80, 9) 738 conv1d_1[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D) (None, 80, 10) 910 conv1d_2[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten) (None, 800) 0 conv1d_3[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 180) 144180 flatten_1[0][0]
__________________________________________________________________________________________________
z_mean (Dense) (None, 180) 32580 dense_1[0][0]
__________________________________________________________________________________________________
z_log_var (Dense) (None, 180) 32580 dense_1[0][0]
__________________________________________________________________________________________________
z_sampling (Lambda) (None, 180) 0 z_mean[0][0]
z_log_var[0][0]
__________________________________________________________________________________________________
decoder (Model) (None, 80, 69) 1760451 z_sampling[0][0]
encoder_input[0][0]
==================================================================================================
Total params: 1,977,037
Trainable params: 1,977,037
Non-trainable params: 0
__________________________________________________________________________________________________
и модель декодера выглядит как
Model: "decoder"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
encoded_input (InputLayer) (None, 180) 0
__________________________________________________________________________________________________
reapeat_context (RepeatVector) (None, 80, 180) 0 encoded_input[0][0]
__________________________________________________________________________________________________
decoder_GRU1 (GRU) [(None, 80, 400), (N 697200 reapeat_context[0][0]
__________________________________________________________________________________________________
decoder_GRU2 (GRU) [(None, 80, 400), (N 961200 decoder_GRU1[0][0]
decoder_GRU1[0][1]
__________________________________________________________________________________________________
true_seq_input (InputLayer) (None, 80, 69) 0
__________________________________________________________________________________________________
terminal_GRU (TGRU) [(None, 80, 69), (No 102051 decoder_GRU2[0][0]
true_seq_input[0][0]
==================================================================================================
Total params: 1,760,451
Trainable params: 1,760,451
Non-trainable params: 0
__________________________________________________________________________________________________
Однако, когда я попытался использовать метод fit_generator () для обучения этой модели, я обнаружил InvalidArgumentError следующим образом:
InvalidArgumentError: Input 'pred' passed float expected bool while building NodeDef 'decoder/terminal_GRU/PartitionedCall/cond/switch_pred/_1362' using Op<name=Switch; signature=data:T, pred:bool -> output_false:T, output_true:T; attr=T:type> [Op:__inference_keras_scratch_graph_7791]
Кто-нибудь может сказать мне, почему возникает эта ошибка? Меня расстраивает, я не могу понять, почему возникает эта ошибка ...