`tf.data.Dataset.map` с переменными зависимостями - PullRequest
1 голос
/ 26 апреля 2019

Я хочу обучить модель с предварительной обработкой, зависящей (недифференцируемой) от параметров модели. Мое текущее решение состоит в том, чтобы использовать tf.compat.v1.data.make_initializable_iterator и повторно инициализировать каждую эпоху, но это имеет следующие проблемы:

  • обновления применяются только один раз в каждую эпоху. Я в порядке с использованием слегка устаревших значений (несколько проходов по сети), но я бы предпочел более быструю частоту обновления, чем эта, которая не сбрасывает процесс пакетирования;
  • Я бы предпочел сделать что-то в стиле 2.0, и, насколько я могу судить, 2.0 не имеет make_initializable_iterator; и
  • Мне бы хотелось, чтобы это было совместимо с нетерпеливым режимом.

Следующая проблема демонстрирует проблему в активном режиме.

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

n = 10
dataset = tf.data.Dataset.from_tensor_slices((tf.range(n),))

# create a simple keras model to implement the map function
vi = tf.keras.layers.Input(shape=(), dtype=tf.float32)
xi = tf.keras.layers.Input(shape=(), dtype=tf.float32)
out = tf.keras.layers.Add()([xi, vi])
model = tf.keras.models.Model(inputs=[xi, vi], outputs=out)

# create a variable-dependant tensor for input
v = tf.Variable(0., dtype=tf.float32)*2


def map_fn(x):
    return model([tf.cast(x, tf.float32), v2])


dataset = dataset.map(map_fn)
for d in dataset:
    print(d.numpy())  # 0, 1, 2, 3, ... as expected
v.assign(100.)
for d in dataset:
    print(d.numpy())  # 0, 1, 2, 3, ..., expected 200, 201, 202, ...
...