Я попытался запустить программу с открытым github
https://github.com/jyh2986/Active-Shift-TF
Так как среда отличается следующим образом, я изменил некоторые параметры.
[среда] CentOS 17.4 tenorflow-gpu 2.0.0 CUDA 10.0 CUDNN7.4.2 gcc8.3.2
Используя tf_upgrade_v2, можно запустить test_forward_ASL.py, и я могу выполнить test_forward_ASL.py Однако, когда я запускаю test_gradient_ASL.py
Я впервые встретил следующую ошибку. Предупреждение] ПРЕДУПРЕЖДЕНИЕ: tenorflow: From test_gradient_ASL-TF20.py:45: compute_gradient_error (fromensorflow. python .ops.gradient_checker) устарело и будет удалено в следующей версии. Инструкции по обновлению: используйте tf.test.compute_gradient в 2.0, в котором улучшена поддержка функций. Обратите внимание, что эти две версии используются по-разному, поэтому требуется изменение кода.
[ошибка компиляции] для строки кода 45, 73, ошибка атрибута err = Gradge_checker.compute_gradient_error (a, arr.shape, result, result. get_shape (). as_list (), x_init_value = arr)
err = grad_checker.compute_gradient_error (c, shift.shape, результат, result.get_shape (). as_list (), x_init_value = shift)
AttributeError: Tensor.graph не имеет смысла, когда включено активное выполнение. Таким образом, я изменил функцию градиент_checker.compute_gradient_error, как показано ниже.
strides = [1, 1, stride_h, stride_w]
paddings = [0, 0, pad_h, pad_w]
strides_tf = tf.constant(strides, dtype=tf.int64)
paddings_tf = tf.constant(paddings, dtype=tf.int64)
theoretical, numerical = tf.test.compute_gradient(active_shift2d_op.active_shift2d_op, [a, c, strides_tf, paddings_tf])
Однако я встретил еще одну ошибку
Traceback (последний вызов был последним): Файл "", строка 54, в active_shift2d_op tenorflow. python .eager.core._FallbackException: ожидание значения int64_t для шагов attr, получило тензор потока. python .framework.ops.EagerTensor
TypeError: Ожидаемый список для аргумента "шагов" к "active_shift2d_op" Op, Op, not.
файл библиотеки сделан c ++, поэтому он имеет тип данных int64_t, однако тензор потока не имеет тип int64_t. Как я могу устранить ошибку, удалив разницу в типах?
Заранее спасибо ~