Выберите вес действия из модели тензорного потока - PullRequest
0 голосов
/ 08 июня 2018

У меня есть маленькая модель, используемая в контексте обучения с подкреплением.

Я могу ввести 2-й тензор состояний и получить 2-й тензор весов действий.

Допустим, я ввелдва состояния, и я получаю следующие веса действий:

[[0.1, 0.2],
 [0.3, 0.4]]

Теперь у меня есть еще один 2-мерный тензор, у которого есть номер действия, из которого я хочу получить веса:

[[1],
 [0]]

Какмогу ли я использовать этот тензор, чтобы получить вес действий?

В этом примере я хотел бы получить:

[[0.2],
 [0.3]]

Ответы [ 2 ]

0 голосов
/ 09 июня 2018

Простой способ сделать это - сжать размеры индексов, поэлементно умножить на соответствующий вектор с одной горячей точкой, а затем расширить размеры позже.

import tensorflow as tf

weights = tf.constant([[0.1, 0.2], [0.3, 0.4]])
indices = tf.constant([[1], [0]])
# Reduce from 2d (2, 1) to 1d (2,)
indices1d = tf.squeeze(indices)
# One-hot vector corresponding to the indices. shape (2, 2)
action_one_hot = tf.one_hot(indices=indices1d, depth=weights.shape[1])
# Element-wise multiplication and sum across axis 1 to pick the weight. Shape (2,)
action_taken_weight = tf.reduce_sum(action_one_hot * weights, axis=1)
# Expand the dimension back to have a 2d. Shape (2, 1)
action_taken_weight2d = tf.expand_dims(action_taken_weight, axis=1)

sess = tf.InteractiveSession()
print("weights\n", sess.run(weights))
print("indices\n", sess.run(indices))
print("indices1d\n", sess.run(indices1d))
print("action_one_hot\n", sess.run(action_one_hot))
print("action_taken_weight\n", sess.run(action_taken_weight))
print("action_taken_weight2d\n", sess.run(action_taken_weight2d))

Должен дать вам следующий вывод:

weights
 [[0.1 0.2]
 [0.3 0.4]]
indices
 [[1]
 [0]]
indices1d
 [1 0]
action_one_hot
 [[0. 1.]
 [1. 0.]]
action_taken_weight
 [0.2 0.3]
action_taken_weight2d
 [[0.2]
 [0.3]]

Примечание: вы также можете сделать action_taken_weight = tf.reshape(action_taken_weight, tf.shape(indices)) вместо expand_dims.

0 голосов
/ 08 июня 2018

Аналогично Tensorflow tf.gather с параметром оси , индексы здесь обрабатываются немного иначе:

a = tf.constant( [[0.1, 0.2], [0.3, 0.4]])
indices = tf.constant([[1],[0]])

# convert to full indices
full_indices = tf.stack([tf.range(indices.shape[0])[...,tf.newaxis], indices], axis=2)

# gather
result = tf.gather_nd(a,full_indices)

with tf.Session() as sess:
   print(sess.run(result))
#[[0.2]
#[0.3]]
...