Как внести изменения на месте в тензор, локальный для метода в методе, который не является локальным, в TensorFlow? - PullRequest
0 голосов
/ 25 декабря 2018

Мы знаем, что в python данные передаются по имени через методы.Скажем, у меня есть список a , который является локальным для метода m1 (), и я хотел передать его другому методу и внести некоторые изменения в него в другом методе и сохранить эти изменения, затемдовольно просто и может быть сделано следующим образом:

def m1(a):
   a.append(5)
def m2():
   a = [1, 2, 3, 4]
   print('Before: ', a) # Output= Before: [1, 2, 3, 4]
   m1(a)
   print('After: ', a) # Output= After: [1, 2, 3, 4, 5]
m2()

Как сделать то же самое, если a был тензор?Я хочу сделать что-то вроде

def m1(t1):
  t2 = tf.constant([[[7, 4], [8, 4]], [[2, 10], [15, 11]]])
  tf.concat([t1, t2], axis = -1)


def m2():
  t1 = tf.constant([[[1, 2], [2, 3]], [[4, 4], [5, 3]]])
  se = tf.Session()
  print('Before: ', se.run(t1)) # Output = Before: [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
  m1(t1)
  print('After: ', se.run(t1))  #Actual Output = After : [[[1, 2], [2, 3]], [[4, 4], [5, 3]]] | Desired Output = After : [[[1, 2, 7, 4], [2, 3, 8, 4]], [[4, 4, 2, 10], [5, 3, 15, 11]]]

m2()

1 Ответ

0 голосов
/ 25 декабря 2018

tf.concat фактически возвращает каскадный тензор и не делает его на месте, так как тензор потока в основном работает над добавлением новых узлов в граф.Итак, этот новый тензор добавлен к графику.

Этот код работает:

import tensorflow as tf

def m1(t1):
  t2 = tf.constant([[[7, 4], [8, 4]], [[2, 10], [15, 11]]])
  return tf.concat([t1, t2], axis = -1)

def m2():
  t1 = tf.constant([[[1, 2], [2, 3]], [[4, 4], [5, 3]]])
  se = tf.Session()
  print('Before: ', se.run(t1)) # Output = Before: [[[1, 2], [2, 3]], [[4, 4], [5, 3]]]
  t1 = m1(t1)
  print('After: ', se.run(t1))  #Actual Output = After : [[[1, 2], [2, 3]], [[4, 4], [5, 3]]] | Desired Output = After : [[[1, 2, 7, 4], [2, 3, 8, 4]], [[4, 4, 2, 10], [5, 3, 15, 11]]]

m2()

Он дает следующий вывод:

('Before: ', array([[[1, 2],
        [2, 3]],

       [[4, 4],
        [5, 3]]], dtype=int32))
('After: ', array([[[ 1,  2,  7,  4],
        [ 2,  3,  8,  4]],

       [[ 4,  4,  2, 10],
        [ 5,  3, 15, 11]]], dtype=int32))

См. Это tf.concat

...