Как проверить вычисление градиента с помощью модульного теста - PullRequest
0 голосов
/ 06 октября 2019

Я пытаюсь протестировать пользовательский слой. Написание теста прямой связи было довольно простым, но я понятия не имею, как реализовать тест для градиентов.

Я обнаружил, что в пакете теста тензорного потока есть функция, называемая compute_gradient, но я не могу 'Не могу найти какой-либо ресурс о том, как его использовать. Документация в основном гласит, что она вычисляет градиенты (якобианскую матрицу), что мне и нужно, но когда я пытаюсь ее использовать, я получаю EagerTensor is not callable

Это код, который дает сбой:

class LayerGradientTest(tf.test.TestCase):
    def test_gradient(self):
        with self.test_session():
            input_tensor = [...]
            expected_output = [...]
            expected_gradients = [...]
            test_layer = MyLayer()
            output_tensor = test_layer(tf.Variable(input_tensor))
            grad_computed = tf.test.compute_gradient(output_tensor, expected_output)
            self.assertAllEqual(grad_computed, expected_gradients)

Я бы ожидал, что тест либо пройдет, либо провалится в утверждении, но я получу TypeError: 'tensorflow.python.framework.ops.EagerTensor' object is not callable от compute_gradient

Редактировать: Конечно, градиентам нужна функция потерь, яЯ идиот ... но все же вывод имеет глупую форму. Теперь я использую следующий код:

function = tf.losses.mean_squared_error
grad_computed = tf.test.compute_gradient(function, [output_tensor, expected_output])

Формами ввода для моего слоя являются (1, 2, 2, 3) и (1, 2, 2, 2), но градиенты являются zip-объектом4 матрицы 12x4, но поскольку у меня нет параметров в слое, я ожидал получить значения ошибок на входе. Пожалуйста, поправьте меня, если я снова что-то напутал. Просто чтобы прояснить, мой слой просто преобразовывает данные и поэтому сам по себе не имеет градиентов, но должен правильно распространять их назад.

1 Ответ

0 голосов
/ 11 октября 2019

Проверьте, включено ли активное выполнение, если нет, попробуйте следующий код при импорте

import tensorflow as tf
tf.enable_eager_execution()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...