У меня есть следующие настройки:
B = Batchsize
N = Number of Objects
T = Number of Targets
L = Length of feature embedding per target
Для каждого объекта я хочу присутствовать на цели. Модель решает, к какой цели следует обратиться, беря argmax вектора attention_weights
с shape=[B,N,T]
:
pick = tf.math.argmax(attention_weights, axis=2)
Таким образом, pick
имеет форму [B,N]
и каждыйзапись является индексом. Теперь я хотел бы использовать эти индексы для доступа к нужным целевым функциям
target_features.set_shape(target_features, [B, D, L])
features_picked = tf.some_function(target_features, pick)
Мой вопрос: что использовать для tf.some_function
? Это связано с tf.gather? У меня есть проблема с выяснением, как использовать это в этом случае.
Большое спасибо заранее за любую помощь!
PS: я использую tf. version = '1.130,1'