Как объединить два измерения тензора, не принимая его форму - PullRequest
2 голосов
/ 07 октября 2019

Давайте предположим следующую функцию:

from tensorflow.python.keras import backend as K

def broadcast_sum(a, b):
    a = K.expand_dims(a, 1)
    b = K.expand_dims(b, 2)
    c = a + b
    cs = K.shape(c)
    return K.reshape(c, (cs[0], -1, cs[-1]))

Учитывая два тензора форм (1, 3, 2) и (1, 4, 2), он корректно возвращает:

>>> broadcast_sum(K.placeholder((1, 3, 2)), K.placeholder((1, 4, 2)))
>>> <tf.Tensor 'Reshape_2:0' shape=(1, 12, 2) dtype=float32>

Прямо сейчас эта функцияработает только с 3D-вводом (из-за строки reshape). Мой вопрос: как я могу заставить эту работу работать с любой формой (используя ту же функцию), не зная форму? Конечно, я предполагаю, что входы имеют одинаковую форму и, по крайней мере, 3D. Но как я могу иметь одну функцию, которая работает с 3D, 4D и т. Д.?

И я предполагаю, что это всегда второе измерение (слева), которое будет транслироваться функцией, а остальные измеренияидентичны между двумя входами. Вот фигуры, с которыми я хочу заставить одну и ту же функцию работать:

>>> broadcast_sum(K.placeholder((1, 3, 5, 2)), K.placeholder((1, 4, 5, 2)))
>>> <tf.Tensor 'Reshape_3:0' shape=(1, 60, 2) dtype=float32>

Конечно, возвращенный тензор сейчас неверен. Он должен иметь форму (1, 12, 5, 2).

[ОБНОВЛЕНИЕ]

Также учтите, что первое измерение (размер партии) может составлять None. Фактически, любое из измерений, кроме самого правого, может быть None.

1 Ответ

1 голос
/ 07 октября 2019

И я предполагаю, что функция всегда будет транслироваться во втором измерении (слева), а остальные измерения идентичны между двумя входами.

На основеэто, я повторно использую информацию формы от одного из входных данных.

from tensorflow.python.keras import backend as K
def broadcast_sum(a, b):
    final_shape = (a.shape[0], -1, *a.shape[2:])
    a = K.expand_dims(a, 1)
    b = K.expand_dims(b, 2)
    c = a + b
    return K.reshape(c, final_shape)


print(broadcast_sum(K.placeholder((1, 3, 2)), K.placeholder((1, 4, 2))))
print(broadcast_sum(K.placeholder((1, 3, 5, 2)), K.placeholder((1, 4, 5, 2))))

Tensor("Reshape:0", shape=(1, 4, 3, 2), dtype=float32)
Tensor("Reshape_1:0", shape=(1, 12, 5, 2), dtype=float32)
...