Как обрабатывать переменные в tf.estimator? - PullRequest
0 голосов
/ 18 апреля 2019

Я пытаюсь реализовать модель из здесь с помощью API TensorFlow Estimator.Однако у меня возникли некоторые проблемы с обработкой переменных, которые обновляются вручную.В принципе, проблема в том, что, насколько я могу судить, в любое время, когда model_fn вызывается, переменные повторно инициализируются.

def setup_variables(batch_size, params):
    # Mask describing ended sessions, true if session ended
    ended_sessions_mask = tf.get_variable(
        'ended_sessions_mask',
        shape=(batch_size,),
        initializer=tf.zeros_initializer(),
        trainable=False,
        dtype=tf.bool)

    # Mask describing ended users, true if not more user events
    ended_users_mask = tf.get_variable(
        'ended_users_mask',
        shape=(batch_size,),
        initializer=tf.zeros_initializer(),
        trainable=False,
        dtype=tf.bool)

def model_fn(features, labels, mode, params):

    (ended_sessions_mask,
        ending_sessions_mask) = setup_variables(batch_size, params)


    # Ended sessions where the user did not change
    ended_sessions_same_user_mask = tf.logical_and(
        ended_sessions_mask,
        tf.logical_not(ended_users_mask)
    )

    # Get user_hidden_states to update
    # The hidden states to update are the ones where a session ended
    # but the user has stayed the same
    # The other hidden states are 0
    user_hidden_states = tf.map_fn(
        lambda x: tf.cond(
            x[1],
            true_fn=lambda: tf.nn.embedding_lookup(user_embeddings, x[0]),
            false_fn=lambda: tf.zeros(params['user_rnn_units'])
        ),
        [
            features['UserEmbeddingId'],
            ended_sessions_same_user_mask
        ],
        dtype=tf.float32,
        name='get_user_hidden_states_to_update')

...


    # Compute new mask for ended sessions
    ended_sessions_mask = tf.cast(
        tf.where(
            tf.equal(features['ProductId'], -1),
            tf.ones(tf.shape(ended_sessions_mask)),
            tf.zeros(tf.shape(ended_sessions_mask)),
            name='compute_ended_sessions'),
        tf.bool)

    # Compute new mask for ended users
    ended_users_mask = tf.cast(
        tf.where(
            tf.equal(features['UserId'], -1),
            tf.ones(tf.shape(ended_users_mask)),
            tf.zeros(tf.shape(ended_users_mask)),
            name='compute_ended_users'),
        tf.bool)

В принципе последовательность функций модели должна быть следующей:

  • Обновление пользовательских вложений на основе масок, маски которых вычисляются на предыдущем шаге
  • Применение модели, вычисление потерь и т. Д.
  • Вычисление новых масок, которые будут использоваться вследующий шаг.

Т.е. маски описывают завершенные сеансы и пользователей с предыдущего шага.

Насколько я понимаю, это должно быть возможно при использовании get_variable, с тех порПеременная создается только в том случае, если ее не было раньше.Но каждый раз, когда я называю model_fn, маски повторно инициализируются нулями.Я ожидаю, что маски будут иметь значения, которые были рассчитаны последними, но это не так.

...