Что именно tf.expand_dims делает с вектором и почему результаты могут быть добавлены вместе, даже если формы матрицы различны? - PullRequest
0 голосов
/ 13 января 2019

Я добавляю два вектора, которые, как я думал, были «преобразованы» вместе, и в результате получаю 2d матрицу. Я ожидаю некоторый тип ошибки здесь, но не получил это. Я думаю, что понимаю, что происходит, он рассматривал их так, как будто было еще два набора каждого вектора по горизонтали и вертикали, но я не понимаю, почему результаты a и b не отличаются. И если они не предназначены, почему это вообще работает?

import tensorflow as tf
import numpy as np

start_vec = np.array((83,69,45))
a = tf.expand_dims(start_vec, 0)
b = tf.expand_dims(start_vec, 1)
ab_sum = a + b
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    a = sess.run(a)
    b = sess.run(b)
    ab_sum = sess.run(ab_sum)

print(a)
print(b)
print(ab_sum)

=============================================== ==

[[83 69 45]]

[[83]
 [69]
 [45]]

[[166 152 128]
 [152 138 114]
 [128 114  90]]

1 Ответ

0 голосов
/ 13 января 2019

На самом деле, этот вопрос более широко использует вещание характеристик тензорного потока, который так же, как NumPy ( Broadcasting ). Broadcasting избавляет от требования, чтобы форма операции между тензорами была одинаковой. Конечно, он также должен соответствовать определенным условиям.

Общие правила вещания:

При работе с двумя массивами, NumPy сравнивает их формы поэлементно. Начинается с трейлинга размеры, и продвигается вперед. Два измерения совместимы когда

1. они равны или

2. один из них 1

Простой пример - одномерные тензоры, умноженные на скаляры.

import tensorflow as tf

start_vec = tf.constant((83,69,45))
b = start_vec * 2

with tf.Session() as sess:
    print(sess.run(b))

[166 138  90]

Возвращаясь к вопросу, функция tf.expand_dims() состоит в том, чтобы вставить измерение в форму тензора в указанной позиции axis. Ваша исходная форма данных (3,). Вы получите форму a=tf.expand_dims(start_vec, 0), равную (1,3), когда ваш сет axis=0. Вы получите форму b=tf.expand_dims(start_vec, 1), равную (3,1), когда ваш сет axis=1.

.

Сравнивая правила broadcasting, вы можете увидеть, что они удовлетворяют второму условию. Таким образом, их фактическая операция составляет

83,83,83     83,69,45
69,69,69  +  83,69,45
45,45,45     83,69,45
...