Это должно быть безопасно.С одной стороны, документация tf.GradientTape.watch
гласит:
Гарантирует, что tensor
отслеживается этой лентой.
"«Гарантирует», по-видимому, подразумевает, что он обязательно проследит, если это не так.Фактически, документация не дает никаких указаний на то, что использование его дважды над одним и тем же объектом должно быть проблемой (хотя это не повредит, если они сделают это явным).
Но в любом случае мы можем копатьв исходный код, чтобы проверить.В конце концов, вызов watch
для переменной (ответ заканчивается тем же, если это не переменная, но путь немного расходится) сводится к методу WatchVariable
класса GradientTape
вC ++:
void WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
return;
}
tensorflow::int64 id = FastTensorId(handle.get());
if (!PyErr_Occurred()) {
this->Watch(id);
}
tensorflow::mutex_lock l(watched_variables_mu_);
auto insert_result = watched_variables_.emplace(id, v);
if (insert_result.second) {
// Only increment the reference count if we aren't already watching this
// variable.
Py_INCREF(v);
}
}
Вторая половина метода показывает, что отслеживаемая переменная добавляется к watched_variables_
, то есть std::set
, поэтому повторное добавление чего-либо ничего не даст.Это на самом деле проверяется позже, чтобы убедиться в правильности подсчета ссылок Python.Первая половина в основном называет Watch
:
template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
tensor_tape_
- это карта (в частности, tensorflow::gtl:FlatMap
, почти такая же, как стандартКарта C ++), поэтому, если tensor_id
уже существует, это не будет иметь никакого эффекта.
Так что, хотя это явно не указано, все предполагают, что с ним не должно быть проблем.