TensorFlow: как ограничить / обновить только некоторые строки переменной? - PullRequest
0 голосов
/ 07 ноября 2018

У меня есть переменная матрицы встраивания

embs = tf.get_variable(name="embs", shape=[total, hidden_size])

Я хочу ограничить норму единицей

embs_unit_norm_op = tf.assign(embs, tf.keras.constraints.unit_norm(axis=1)(embs))

Но матрица embs слишком велика, и только некоторые строки изменились на предыдущем шаге применения градиента. Я хочу сократить вычисления, только ограничивая / обновляя эти «активные» вложения.

У меня есть список номеров строк, например, e = [1, 3, 5], такой же, как идентификатор в горячем режиме, используемый при поиске встраивания. Как я могу ограничить / обновить только эти «активные» вложения?

1 Ответ

0 голосов
/ 07 ноября 2018

Вы можете попробовать tf.scatter_update или tf.scatter_nd_update

...