Пример TF 2.0 @ tf.function - PullRequest
       5

Пример TF 2.0 @ tf.function

0 голосов
/ 22 марта 2019

В документации по тензорному потоку в секции autograph имеется следующий фрагмент кода

@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if tf.equal(step % 10, 0):
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())

У меня небольшой вопрос по поводу переменной step, это целое число, а не тензор, автограф поддерживает встроенный тип Python, такой как целое число. Поэтому tf.equal(step%10,0) можно изменить просто на step%10 == 0, верно?

Ответы [ 2 ]

2 голосов
/ 22 марта 2019

Да, вы правы.Целочисленная переменная step остается переменной Python, даже если она преобразована в представление графа.Вы можете увидеть результат преобразования, вызвав tf.autograph.to_code(train.python_function).

Не сообщая весь код, а только часть, связанную с переменной step, вы увидите, что

  def loop_body(loop_vars, loss_1, step_1):
    with ag__.function_scope('loop_body'):
      x, y = loop_vars
      step_1 += 1

все ещеОперация Python (в противном случае это будет step_1.assign_add(1), если шаг 1 был tf.Tensor).

Для получения дополнительной информации об автографе и функции tf.function я предлагаю прочитать статью https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/, которая легко объясняет, чтопроисходит, когда функция преобразуется.

0 голосов
/ 01 апреля 2019

Хотя это не видно в сгенерированном коде, переменная шага фактически будет автоматически помещена в Tensor с помощью цикла for, который преобразуется в TF while_loop.

Вы можете проверить это, добавив оператор печати:

    loss = train_one_step(model, optimizer, x, y)
    print(step)
    if tf.equal(step % 10, 0):
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...