Сброс графика по умолчанию при выходе из tf.Session () в модульных тестах - PullRequest
1 голос
/ 22 июня 2019
  • В конце каждого модульного теста я вызываю tf.reset_default_graph(), чтобы очистить график по умолчанию.
  • Однако, если модульный тест не пройден, график не очищается. Это также приводит к сбою следующего модульного теста.

Как очистить график при выходе из контекста tf.Session()?

Пример (pytest):

import tensorflow as tf


def test_1():
    x = tf.get_variable('x', initializer=1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(4 / 0)
        print(sess.run(x))


def test_2():
    x = tf.get_variable('x', initializer=1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(x))

Ответы [ 3 ]

2 голосов
/ 23 июня 2019

Я предлагаю использовать инструменты pytest предлагает:

@pytest.fixture(autouse=True)
def reset():
    yield
    tf.reset_default_graph()

Прибор будет автоматически вызываться до и после каждого теста (флаг autouse), код до / после выполнения yieldдо / после теста.Таким образом, тесты из вашего вопроса будут работать без каких-либо изменений, и вы будете следовать принципу СУХОЙ, отказываясь писать дублированный код в каждом тесте.Другой пример:

@pytest.fixture(autouse=True)
def init_graph():
    with tf.Graph().as_default():
        yield

создаст новый график для каждого теста перед его выполнением.

Приспособления в pytest очень мощные и могут полностью исключить повторения кода при правильном использовании.Например, тесты из вашего вопроса эквивалентны:

@pytest.fixture
def x():
    return tf.get_variable('x', initializer=1)


@pytest.fixture
def session(x):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        yield sess


@pytest.fixture(autouse=True)
def init_graph():
    with tf.Graph().as_default():
        yield


def test_1(session, x):
    print(4 / 0)
    print(session.run(x))


def test_2(session, x):
    print(session.run(x))

Если вы хотите узнать больше, начните с приспособлений pytest: явных, модульных, масштабируемых .

1 голос
/ 23 июня 2019

Прямое решение состоит в том, чтобы использовать предложение try ... finally (на самом деле может быть лучше поместить предложение в код, который запускает модульные тесты, а не непосредственно в модульные тесты):

def test_1():
    x = tf.get_variable('x', initializer=1)
    try:
       with tf.Session() as sess:
           sess.run(tf.global_variables_initializer())
           print(4 / 0)
           print(sess.run(x))
    finally:
       tf.reset_default_graph()


def test_2():
    x = tf.get_variable('x', initializer=1)
    try:
        with tf.Session() as sess:
           sess.run(tf.global_variables_initializer())
            print(sess.run(x))
    finally:
        tf.reset_default_graph()

Другое чистое решение - использовать один график для каждого модульного теста, как показано в предыдущем ответе. Вот альтернативное решение, основанное на этой идее с немного упрощенным синтаксисом:

def test_1():
    with tf.Graph().as_default(), tf.Session() as sess:
        x = tf.get_variable('x', initializer=1)

        sess.run(tf.global_variables_initializer())
        print(4 / 0)
        print(sess.run(x))


def test_2():
    with tf.Graph().as_default(), tf.Session() as sess:
        x = tf.get_variable('x', initializer=1)

        sess.run(tf.global_variables_initializer())
        print(sess.run(x))

Подобно первому решению, оператор with может также помещаться вокруг кода, который запускает модульные тесты, а не повторяться в каждом модульном тесте.

1 голос
/ 22 июня 2019

Хотелось бы что-нибудь подобное?

import tensorflow as tf


def test_1():
    G = tf.Graph()
    with G.as_default():
        x = tf.get_variable('x', initializer=1)
        with tf.Session() as sess:
            sess.run(tf.initializers.global_variables())
            print(sess.run(x))
            print(4 / 0)


def test_2():
    G = tf.Graph()
    with G.as_default():
        x = tf.get_variable('x', initializer=1)
        with tf.Session() as sess:
            sess.run(tf.initializers.global_variables())
            print(sess.run(x))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...