Понимание batch_dot () в Keras с бэкэндом Tensorflow - PullRequest
0 голосов
/ 06 января 2019

Я пытаюсь понять этот фрагмент кода (из здесь ), который реализует внимание к точечному произведению с использованием умножения матриц между двумя тензорами. В частности, функция batch_dot () из бэкэнда Keras используется между двумя тензорами с переменным первым измерением. В этом случае batch_dot () работает иначе, чем когда указано первое измерение.

MWE:

Фиксированное первое измерение, работает как ожидалось

q = K.ones(shape=(36,8,24))
k = K.ones(shape=(36,8,24))
print(K.batch_dot(q,k,axes=[1,1]))

1010 * возвращается *

Tensor("MatMul_8:0", shape=(?, 36, 24, 24), dtype=float32)

и,

print(K.batch_dot(q,k,axes=[2,2]))

возвращает

Tensor("MatMul_9:0", shape=(?, 36, 8, 8), dtype=float32)

Однако, определив q и k следующим образом:

q = Input(shape=(36,8,24))
k = Input(shape=(36,8,24))
print(q)
print(k)

(переменное первое измерение)

Tensor("input_24:0", shape=(?, 36, 8, 24), dtype=float32)
Tensor("input_25:0", shape=(?, 36, 8, 24), dtype=float32)

Размеры вывода из операции batch_dot () неожиданны:

K.batch_dot(q,k,axes=[1,1])
<tf.Tensor 'MatMul_11:0' shape=(?, 36, 24, 24) dtype=float32>
K.batch_dot(q,k,axes=[2,2])
<tf.Tensor 'MatMul_12:0' shape=(?, 36, 24, 24) dtype=float32>

В соответствии с документацией аргументы axes указывают размеры, которые удаляются во время операции, однако я не могу связать это определение с выходами выше. Учитывается ли первое измерение (со значением ?) для аргументов axes?

Ответы [ 2 ]

0 голосов
/ 10 января 2019

Все будет ясно, если вы посмотрите на исходный код на https://github.com/tensorflow/tensorflow/blob/a6d8ffae097d0132989ae4688d224121ec6d8f35/tensorflow/python/keras/backend.py#L1437

мы можем пойти прямо line1507

if ndim(x) == 2 and ndim(y) == 2:
    if axes[0] == axes[1]:
      out = math_ops.reduce_sum(math_ops.multiply(x, y), axes[0])
    else:
      out = math_ops.reduce_sum(
          math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
else: 
    adj_x = None if axes[0] == ndim(x) - 1 else True
    adj_y = True if axes[1] == ndim(y) - 1 else None
    out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)

Как видно, он проверяет только adj_x и adj_y и не передает параметр axes в метод math_ops.matmul. Вот почему вы получаете тот же результат, когда axes равен [1,1] и [2,2].

Мы можем использовать следующий код для проверки:

q = K.ones(shape=range(1, 10))
k = K.ones(shape=range(1, 10))
for i in range(10): print(i, K.batch_dot(q,k,axes=[i,i]))

будет напечатано

0 Tensor("MatMul_7:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
1 Tensor("MatMul_8:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
2 Tensor("MatMul_9:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
3 Tensor("MatMul_10:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
4 Tensor("MatMul_11:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
5 Tensor("MatMul_12:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
6 Tensor("MatMul_13:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
7 Tensor("MatMul_14:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)
8 Tensor("MatMul_15:0", shape=(1, 2, 3, 4, 5, 6, 7, 8, 8), dtype=float32)
9 Tensor("MatMul_16:0", shape=(1, 2, 3, 4, 5, 6, 7, 9, 9), dtype=float32)

За исключением случаев, когда i равно 8, все остальные возвращают тот же результат.

0 голосов
/ 08 января 2019

Учитывается ли первое измерение (со значением?) Для аргументов осей?

Да, это считается.

Дело в том, что первое измерение в слое Input в приведенном выше примере - это размер партии, а в K.ones() - нет. В результате оси [3, 3] для Input равны осям [2, 2] в K.ones(). В вашем коде следующие два batch_dot равны:

q = tf.keras.layers.Input(shape=(36, 8, 24)) # shape =(?, 36,8,24)
k = tf.keras.layers.Input(shape=(36, 8, 24)) # shape =(?, 36,8,24)
print(tf.keras.backend.batch_dot(q, k, axes=[3, 3]))

q = tf.keras.backend.ones(shape=(36, 8, 24))
k = tf.keras.backend.ones(shape=(36, 8, 24))
print(tf.keras.backend.batch_dot(q, k, axes=[2, 2]))

Обратите внимание, что в K.ones(), если форма была символической, мы не можем вернуть переменную, а вместо этого вернем динамический тензор. Что это значит? См. Следующий пример для лучшего понимания:

a = tf.keras.layers.Input(shape=(30,))
c = tf.keras.backend.ones(shape=tf.shape(a))
print(c) # shape=(?, 30)
d = tf.keras.backend.ones(shape=(30, 40))
print(d) # shape=(30,40)

Размеры вывода из операции batch_dot () неожиданны

K.batch_dot(q,k,axes=[1,1])
<tf.Tensor 'MatMul_11:0' shape=(?, 36, 24, 24) dtype=float32>
K.batch_dot(q,k,axes=[2,2])
<tf.Tensor 'MatMul_12:0' shape=(?, 36, 24, 24) dtype=float32>

С какой стати это происходит, когда оси разные?

Чтобы ответить на этот вопрос, мы должны знать о базовой реализации batch_dot. Если ранг входных тензоров не равен 2, то наша batch_dot ведет себя как операция tf.matmul, когда один из входных тензоров сопряженно транспонирован. В результате, когда наши входные тензоры имеют ранг 3 и мы устанавливаем ось, равную 0 или 1, они вычисляют те же самые вещи, но когда устанавливают оси на 2, это вычисляет что-то другое:

a = np.array([[[1, 2, 3],
               [3, 2, 1]]])  # rank 3

b = np.array([[[1, 3, 3],
               [2, 2, 0]]])  # rank 3

a = tf.constant(a, dtype=tf.float32)
b = tf.constant(b, dtype=tf.float32)

c = tf.matmul(a, b, adjoint_a=True, adjoint_b=False)  # when axes is [0,0] or [1,1]
d = tf.matmul(a, b, adjoint_a=False, adjoint_b=True)  # when axes is [2,2]
print(c.shape)  # shape=(1,3,3)
print(d.shape)  # shape=(1,2,2)

То же самое произошло в вашем примере:

a = np.array([[[1, 2, 3],
               [3, 2, 1]]])

b = np.array([[[1, 3, 3],
               [2, 2, 0]]])

q = tf.placeholder(dtype=tf.float32, shape=(None, 2, 3))  
k = tf.placeholder(dtype=tf.float32, shape=(None, 2, 3))  
res1 = tf.keras.backend.batch_dot(q, k, axes=0)
res2 = tf.keras.backend.batch_dot(q, k, axes=1)
res3 = tf.keras.backend.batch_dot(q, k, axes=2)
with tf.Session() as sess:
    feed_dic = {q: a, k: b}
    print(sess.run(res1, feed_dict=feed_dic))
    print(20 * '-')
    print(sess.run(res2, feed_dict=feed_dic))
    print(20 * '-')
    print(sess.run(res3, feed_dict=feed_dic))
...