Как получить текущий global_step в конвейере данных - PullRequest
1 голос
/ 27 марта 2020

Я пытаюсь создать фильтр, который зависит от текущей global_step тренировки, но я не могу сделать это правильно.

Во-первых, я не могу использовать tf.train.get_or_create_global_step() в приведенном ниже коде, потому что он выбросит

ValueError: Variable global_step already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

Вот почему я попытался извлечь область с помощью tf.get_default_graph().get_name_scope(), и в этом контексте я смог " получить " глобальный шаг:

def filter_examples(example):
    scope = tf.get_default_graph().get_name_scope()

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        current_step = tf.train.get_or_create_global_step()

    subtokens_by_step = tf.floor(current_step / curriculum_step_update)
    max_subtokens = min_subtokens + curriculum_step_size * tf.cast(subtokens_by_step, dtype=tf.int32)

    return tf.size(example['targets']) <= max_subtokens


dataset = dataset.filter(filter_examples)

Проблема в том, что он не работает так, как я ожидал. Из того, что я наблюдаю, current_step в приведенном выше коде, кажется, все время равно 0 (я не знаю, просто исходя из моих наблюдений, я предполагаю, что).

Единственное, что кажется чтобы изменить ситуацию, и это звучит странно, возобновить обучение. Я думаю, что также на основе наблюдений, в этом случае current_step будет фактическим текущим шагом обучения в этой точке. Но само значение не будет обновляться по мере продолжения обучения.

Если есть способ получить фактическое значение текущего шага и использовать его в моем фильтре, как указано выше?


Среда

Tensorflow 1.12.1

Ответы [ 2 ]

0 голосов
/ 01 апреля 2020

Как мы обсуждали в комментариях, наличие и обновление собственного счетчика может быть альтернативой использованию переменной global_step. Переменная counter может быть обновлена ​​следующим образом:

op = tf.assign_add(counter, 1)
with tf.control_dependencies(op): 
    # Some operation here before which the counter should be updated

Использование tf.control_dependencies позволяет «прикрепить» обновление counter к пути в вычислительном графе. Затем вы можете использовать переменную counter везде, где вам это нужно.

0 голосов
/ 31 марта 2020

Если вы используете переменные внутри наборов данных, вам нужно переинициализировать итераторы в tf 1.x.

iterator = tf.compat.v1.make_initializable_iterator(dataset)
init = iterator.initializer
tensors = iterator.get_next()

with tf.compat.v1.Session() as sess:
    for epoch in range(num_epochs):
        sess.run(init)
        for example in range(num_examples):
            tensor_vals = sess.run(tensors)
...