Учитывается ли первое измерение (со значением?) Для аргументов
осей?
Да, это считается.
Дело в том, что первое измерение в слое Input
в приведенном выше примере - это размер партии, а в K.ones()
- нет. В результате оси [3, 3] для Input равны осям [2, 2] в K.ones()
. В вашем коде следующие два batch_dot
равны:
q = tf.keras.layers.Input(shape=(36, 8, 24)) # shape =(?, 36,8,24)
k = tf.keras.layers.Input(shape=(36, 8, 24)) # shape =(?, 36,8,24)
print(tf.keras.backend.batch_dot(q, k, axes=[3, 3]))
q = tf.keras.backend.ones(shape=(36, 8, 24))
k = tf.keras.backend.ones(shape=(36, 8, 24))
print(tf.keras.backend.batch_dot(q, k, axes=[2, 2]))
Обратите внимание, что в K.ones()
, если форма была символической, мы не можем вернуть переменную, а вместо этого вернем динамический тензор. Что это значит? См. Следующий пример для лучшего понимания:
a = tf.keras.layers.Input(shape=(30,))
c = tf.keras.backend.ones(shape=tf.shape(a))
print(c) # shape=(?, 30)
d = tf.keras.backend.ones(shape=(30, 40))
print(d) # shape=(30,40)
Размеры вывода из операции batch_dot () неожиданны
K.batch_dot(q,k,axes=[1,1])
<tf.Tensor 'MatMul_11:0' shape=(?, 36, 24, 24) dtype=float32>
K.batch_dot(q,k,axes=[2,2])
<tf.Tensor 'MatMul_12:0' shape=(?, 36, 24, 24) dtype=float32>
С какой стати это происходит, когда оси разные?
Чтобы ответить на этот вопрос, мы должны знать о базовой реализации batch_dot
. Если ранг входных тензоров не равен 2, то наша batch_dot
ведет себя как операция tf.matmul
, когда один из входных тензоров сопряженно транспонирован. В результате, когда наши входные тензоры имеют ранг 3 и мы устанавливаем ось, равную 0 или 1, они вычисляют те же самые вещи, но когда устанавливают оси на 2, это вычисляет что-то другое:
a = np.array([[[1, 2, 3],
[3, 2, 1]]]) # rank 3
b = np.array([[[1, 3, 3],
[2, 2, 0]]]) # rank 3
a = tf.constant(a, dtype=tf.float32)
b = tf.constant(b, dtype=tf.float32)
c = tf.matmul(a, b, adjoint_a=True, adjoint_b=False) # when axes is [0,0] or [1,1]
d = tf.matmul(a, b, adjoint_a=False, adjoint_b=True) # when axes is [2,2]
print(c.shape) # shape=(1,3,3)
print(d.shape) # shape=(1,2,2)
То же самое произошло в вашем примере:
a = np.array([[[1, 2, 3],
[3, 2, 1]]])
b = np.array([[[1, 3, 3],
[2, 2, 0]]])
q = tf.placeholder(dtype=tf.float32, shape=(None, 2, 3))
k = tf.placeholder(dtype=tf.float32, shape=(None, 2, 3))
res1 = tf.keras.backend.batch_dot(q, k, axes=0)
res2 = tf.keras.backend.batch_dot(q, k, axes=1)
res3 = tf.keras.backend.batch_dot(q, k, axes=2)
with tf.Session() as sess:
feed_dic = {q: a, k: b}
print(sess.run(res1, feed_dict=feed_dic))
print(20 * '-')
print(sess.run(res2, feed_dict=feed_dic))
print(20 * '-')
print(sess.run(res3, feed_dict=feed_dic))