По какой-то загадочной причине он будет работать правильно, если поместить K.cast()
в lambda
:
input_A = Input(batch_shape=(128,1), name='A_input', dtype='int32')
input_B = Input(batch_shape=(128,1), name='B_input', dtype='int32')
input_A_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_A)
input_B_ = Lambda(lambda x: K.reshape(x, (-1,)))(input_B)
embedded_text_A = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_A_)
embedded_text_B = Lambda(lambda x: K.gather(M1, K.cast(x, dtype='int32')))(input_B_)
Следовательно, слой Lambda
делает странное преобразование dtype внутри.
Полагаю, это какая-то ошибка, и моя гипотеза состоит в том, что неявное преобразование происходит внутри *1000* __call__
(который унаследован от Layer.__call__
) . Я не могу отследить это, но я предполагаю, что ошибка «неявного преобразования» где-то в Layer.__call__
, но до строки 451 , где на самом деле вызывается Lambda.call
.