У меня есть функция, где мне нужно сделать Keras batch_dot с тензором размера (?,61,80)
с 2D-тензором размера (40,61)
.Размер ?
предназначен для размера партии в пользовательском слое.При использовании Keras repeat_elements
нам необходимо указать размер пакета, чтобы сделать его тензор (batch_size, 40,61)
.Однако repeat_elements
не работает с ?
размером пакета.
Код
M1 = K.expand_dims(M,axis=0)
BatchM = K.repeat_elements(x=M1,rep=batch_size,axis=0)
out1 = K.batch_dot(BatchM,Ash1,axes=[2,1])
Здесь M
- это 2D-тензор размера (40,61)
.BatchM
должен дать (batch_size,40,61)
, а Ash1
имеет размер (?,61,80)
.
Редактировать 1:
A= Input(shape=(61,80))
M= K.variable(np.random.rand(40,61))
n=1
import tensorflow as tf
M1 = K.expand_dims(M,axis=0)
BatchM = K.repeat_elements(x=M1,rep=tf.shape(A)[0],axis=0)
out1 = K.batch_dot(BatchM,Ash1,axes=[2,1])
Эта ошибка возврата показывает:
Traceback (most recent call last)
File "<ipython-input-7-edc5ef31181b>", line 3, in <module>
BatchM = K.repeat_elements(x=M1,rep=tf.shape(A)[0],axis=0)
File "/home/hanumant/.conda/envs/kerasenv/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2092, in repeat_elements
x_rep = [s for s in splits for _ in range(rep)]
File "/home/hanumant/.conda/envs/kerasenv/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2092, in <listcomp>
x_rep = [s for s in splits for _ in range(rep)]
TypeError: 'Tensor' object cannot be interpreted as an integer