Можно ли вызывать `tape.watch (x)`, когда `x` уже является` tf.Variable` в TensorFlow? - PullRequest
0 голосов
/ 01 февраля 2019

Рассмотрим следующую функцию

def foo(x):
  with tf.GradientTape() as tape:
    tape.watch(x)

    y = x**2 + x + 4

  return tape.gradient(y, x)

Вызов tape.watch(x) необходим, если функция вызывается скажем как foo(tf.constant(3.14)), но не когда она передается в переменную напрямую, например foo(tf.Variable(3.14)).

Теперь мой вопрос: безопасен ли вызов tape.watch(x) даже в случае, когда tf.Variable передается напрямую?Или произойдет какая-то странность из-за того, что переменная уже отслеживается автоматически, а затем снова просматривается вручную?Как правильно написать такие общие функции, которые могут принимать как tf.Tensor, так и tf.Variable?

1 Ответ

0 голосов
/ 01 февраля 2019

Это должно быть безопасно.С одной стороны, документация tf.GradientTape.watch гласит:

Гарантирует, что tensor отслеживается этой лентой.

"«Гарантирует», по-видимому, подразумевает, что он обязательно проследит, если это не так.Фактически, документация не дает никаких указаний на то, что использование его дважды над одним и тем же объектом должно быть проблемой (хотя это не повредит, если они сделают это явным).

Но в любом случае мы можем копатьв исходный код, чтобы проверить.В конце концов, вызов watch для переменной (ответ заканчивается тем же, если это не переменная, но путь немного расходится) сводится к методу WatchVariable класса GradientTape вC ++:

void WatchVariable(PyObject* v) {
  tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
  if (handle == nullptr) {
    return;
  }
  tensorflow::int64 id = FastTensorId(handle.get());

  if (!PyErr_Occurred()) {
    this->Watch(id);
  }

  tensorflow::mutex_lock l(watched_variables_mu_);
  auto insert_result = watched_variables_.emplace(id, v);

  if (insert_result.second) {
    // Only increment the reference count if we aren't already watching this
    // variable.
    Py_INCREF(v);
  }
}

Вторая половина метода показывает, что отслеживаемая переменная добавляется к watched_variables_, то есть std::set, поэтому повторное добавление чего-либо ничего не даст.Это на самом деле проверяется позже, чтобы убедиться в правильности подсчета ссылок Python.Первая половина в основном называет Watch:

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

tensor_tape_ - это карта (в частности, tensorflow::gtl:FlatMap, почти такая же, как стандартКарта C ++), поэтому, если tensor_id уже существует, это не будет иметь никакого эффекта.

Так что, хотя это явно не указано, все предполагают, что с ним не должно быть проблем.

...