Я пытаюсь реализовать модель из здесь с помощью API TensorFlow Estimator.Однако у меня возникли некоторые проблемы с обработкой переменных, которые обновляются вручную.В принципе, проблема в том, что, насколько я могу судить, в любое время, когда model_fn вызывается, переменные повторно инициализируются.
def setup_variables(batch_size, params):
# Mask describing ended sessions, true if session ended
ended_sessions_mask = tf.get_variable(
'ended_sessions_mask',
shape=(batch_size,),
initializer=tf.zeros_initializer(),
trainable=False,
dtype=tf.bool)
# Mask describing ended users, true if not more user events
ended_users_mask = tf.get_variable(
'ended_users_mask',
shape=(batch_size,),
initializer=tf.zeros_initializer(),
trainable=False,
dtype=tf.bool)
def model_fn(features, labels, mode, params):
(ended_sessions_mask,
ending_sessions_mask) = setup_variables(batch_size, params)
# Ended sessions where the user did not change
ended_sessions_same_user_mask = tf.logical_and(
ended_sessions_mask,
tf.logical_not(ended_users_mask)
)
# Get user_hidden_states to update
# The hidden states to update are the ones where a session ended
# but the user has stayed the same
# The other hidden states are 0
user_hidden_states = tf.map_fn(
lambda x: tf.cond(
x[1],
true_fn=lambda: tf.nn.embedding_lookup(user_embeddings, x[0]),
false_fn=lambda: tf.zeros(params['user_rnn_units'])
),
[
features['UserEmbeddingId'],
ended_sessions_same_user_mask
],
dtype=tf.float32,
name='get_user_hidden_states_to_update')
...
# Compute new mask for ended sessions
ended_sessions_mask = tf.cast(
tf.where(
tf.equal(features['ProductId'], -1),
tf.ones(tf.shape(ended_sessions_mask)),
tf.zeros(tf.shape(ended_sessions_mask)),
name='compute_ended_sessions'),
tf.bool)
# Compute new mask for ended users
ended_users_mask = tf.cast(
tf.where(
tf.equal(features['UserId'], -1),
tf.ones(tf.shape(ended_users_mask)),
tf.zeros(tf.shape(ended_users_mask)),
name='compute_ended_users'),
tf.bool)
В принципе последовательность функций модели должна быть следующей:
- Обновление пользовательских вложений на основе масок, маски которых вычисляются на предыдущем шаге
- Применение модели, вычисление потерь и т. Д.
- Вычисление новых масок, которые будут использоваться вследующий шаг.
Т.е. маски описывают завершенные сеансы и пользователей с предыдущего шага.
Насколько я понимаю, это должно быть возможно при использовании get_variable
, с тех порПеременная создается только в том случае, если ее не было раньше.Но каждый раз, когда я называю model_fn, маски повторно инициализируются нулями.Я ожидаю, что маски будут иметь значения, которые были рассчитаны последними, но это не так.