tenorflow - реализация опыта воспроизведения памяти с помощью API оценки - PullRequest
0 голосов
/ 21 ноября 2018

Я пытаюсь реализовать память воспроизведения опыта с помощью tf.estimator.Estimator API.Однако я не уверен, каков наилучший способ достижения результата, который работает как минимум для всех режимов (TRAIN, EVALUATE, PREDICT).Я попробовал следующее:

  • Реализация памяти с tf.Variable, что вызывает проблемы с пакетной обработкой и входным конвейером (я не могу ввести пользовательский опыт в фазе тестирования или прогнозирования)

и в настоящее время попробуйте:

  • Реализация памяти вне tf.Graph.Установите значения после каждого запуска с помощью tf.train.SessionRunHook.Загрузите опыт с tf.data.Dataset.from_generator() во время обучения и тестирования.Управляйте состоянием по своему усмотрению.

Я ошибаюсь в нескольких моментах и ​​начинаю полагать, что API tf.estimator.Estimator не предоставляет мне необходимых интерфейсов для простой записи этого.

Некоторый код (первый подход, который завершается с ошибкой batch_size, поскольку он фиксирован для нарезки опыта, я не могу использовать модель для прогнозирования или оценки):

 def model_fn(self, features, labels, mode, params):
    batch_size = features["matrix"].get_shape()[0].value

    # get prev_exp
    if mode == tf.estimator.ModeKeys.TRAIN:
        erm = tf.get_variable("erm", shape=[30000, 10], initializer=tf.constant_initializer(self.erm.initial_train_erm()), trainable=False)
        prev_exp = tf.slice(erm, [features["index"][0], 0], [batch_size, 10])

    # model
    pred = model(features["matrix"], prev_exp, params) 

Однако: этобыло бы лучше, если бы внутри этой функции было написано.Но затем я должен управлять ошибками за пределами графика, а также записать свой новый опыт с SessionRunHook.Есть ли лучший способ или я что-то упустил?

1 Ответ

0 голосов
/ 22 ноября 2018

Я решил свою проблему, внедрив ERM вне графика, вернув его обратно во входной конвейер с помощью tf.data.Dataset.from_generator () и выполнив обратную запись с помощью SessionRunHooks.Да, довольно утомительно, но это работает.

...