sess.run () не запускается? - PullRequest
0 голосов
/ 03 мая 2018

Я здесь новичок, изучаю тензорный поток и сталкиваюсь с проблемой.

import model_method
fittt(model_method.build(self,...),...parameters...)

Выше приведено в main.py импортировании model_method.py. Функция fittt в main.py:

def fittt(model,...):
    model.fit(...)

build () в model_method.py:

def build(self,...):
    self.op_C,self.op_A = self.function_A(...)
    self.op_B = self.function_B(self.op_C,...)

fit () in model_method.py:

def fit(self,...):
    sess = tf.Session(graph=self.graph,config=config)
    BB,AA = sess.run([self.op_B,self.op_A],feed_dict)

Чтобы проверить запущенный процесс, я добавил pdb.set_trace () в начале function_A () и function_B () в model_method.py следующим образом:

def function_A(self,...):
    pdb.set_trace()
    ......

def function_B(self,...):
    pdb.set_trace()
    ......

Два pdb.set_trace () остановились только при вызове build () и не работали при вызове sess.run ([self.op_B, self.op_A], feed_dict). Таким образом, это означает, что sess.run () на самом деле не запускал function_A () и function_B (). Интересно, почему и хочешь знать, как заставить работать две функции?

1 Ответ

0 голосов
/ 03 мая 2018

Вызывая функцию model_method.build(), вы создаете график вычислений. В этом вызове выполняется каждая строка кода (следовательно, почему pdb остановлено).

Однако tf.Session.run(...) выполняет только те части вычислительного графа, которые необходимы для вычисления извлеченных значений (self.op_A, self.op_B в вашем примере). Функция не выполняет всю функцию build() снова.

Поэтому причина, по которой pdb.set_trace() не выполнялась при запуске sess.run(...), заключается в том, что они не являются действительными Tensor объектами и, следовательно, не являются частью вычислительного графа.

UPDATE

Примите во внимание следующее:

class My_Model:

  def __init__(self):
      self.np_input = np.random.normal(size=(10,2)) # 10x2

  def build(self):
      self._in = tf.placeholder(dtype=tf.float32, shape=[10, None]) # matrix 10xN
      W_exception = tf.random_normal(dtype=tf.float32, shape=[3,3]) # matrix 3x3
      W_success = tf.random_normal(dtype=tf.float32, shape=[2,3]) # matrix 2x3
      self.op_exception = tf.matmul(self._in, W_exception) # [10x2] x [3x3] = ERROR
      self.op_success = tf.matmul(self._in, W_success) # [10x2] x [2x3] = [10x3]
      print('Computational Graph Built')

  def fit_success(self):
      with tf.Session() as sess:
          res = sess.run(self.op_success, feed_dict={self._in : self.np_input})
          print('Result shape: {}'.format(res.shape))

  def fit_exception(self):
      with tf.Session() as sess:
          res = sess.run(self.op_exception, feed_dict={self._in : self.np_input})
          print('Result shape: {}'.format(res.shape))

и затем звоните:

m = My_Model()
m.build()
#> Computational Graph Built

m.fit_success()
#> Result shape: (10, 3)

m.fit_exception()
#> InvalidArgumentError: Matrix size-incompatible: In[0]: [10,2], In[1]: [3,3]

Итак, чтобы объяснить, что вы там видите. Сначала мы определим вычислительный граф в функции build(). _in - наш входной тензор; None означает, что размерность 1 определяется динамически, то есть, когда мы задаем тензор с указанными значениями.

Затем мы определили две матрицы W_exception и W_success, в которых указаны все размеры и их значения будут сгенерированы случайным образом.

Затем мы определяем две операции, умножение матрицы, каждая из которых возвращает тензор.

Мы вызвали функцию build() и создали вычислительный граф, функция print() также выполняется, но НЕ добавляется в граф. Здесь ничего не вычисляется. На самом деле, это даже не может быть, потому что значения _in не указаны.

Теперь, чтобы показать, что оцениваются только необходимые детали, необходимые для вычисления, мы вызываем функцию fit_success(), которая просто умножает входной тензор _in на тензор W_success (с правильными размерами). Мы получаем тензор правильной формы: [10x3]. Обратите внимание, что мы не получаем ошибки, что op_exception не может быть вычислено из-за несовпадающих размеров. Это потому, что нам не нужно оценивать op_success.

Наконец, я просто показываю, что исключение действительно выдается, когда мы пытаемся оценить op_exception с тем же входным тензором.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...