Как суммировать значения по показателям в другом векторе, используя керас / тензор потока? - PullRequest
0 голосов
/ 11 апреля 2019

Я новичок здесь, и у меня есть вопрос к индексированию тензоров в Keras / Tensorflow:

У меня есть вектор длины N, который содержит индексы слов в словаре (индексы могут повторяться).Этот вектор представляет предложение, например (40, 25, 99, 26, 34, 99, 100, 100...) У меня также есть другой вектор или фактически матрица (так как это группа примеров) той же длины N, где каждому слову в исходном векторе назначен вес W_i.Я хочу суммировать веса для конкретного слова по всему предложению, чтобы я мог получить карту от индекса слова до суммы весов для этого слова в предложении, и я хочу сделать это векторизованным способом.Например, предполагая, что предложение равно (1, 2, 3, 4, 5, 3), а веса равны (0, 1, 0.5, 0.1, 0.6, 0.5), я хочу, чтобы результатом было некоторое отображение:

1->0
2->1
3->1
4->0.1
5->0.6

Как мне добиться чего-то подобного без необходимости перебиратькаждый элемент?Я думал о направлении разреженного тензора (так как возможный словарный запас очень большой), но я не знаю, как это эффективно реализовать.Кто-нибудь может помочь?Я в основном хочу реализовать сеть генератора указателей, и эта часть требуется при расчете вероятности копирования входного слова, а не его генерации.

1 Ответ

1 голос
/ 11 апреля 2019

Вам необходимо tf.bincount(), которое подсчитывает количество вхождений каждого значения в массиве целых чисел. Пример:

import tensorflow as tf
import numpy as np

indices_tf = tf.placeholder(shape=(None,None),dtype=tf.int32)
weights_tf = tf.placeholder(shape=(None,None),dtype=tf.float32)

# The returned index counts from 0
result = tf.bincount(indices_tf,weights_tf)

indices_data = np.array([1, 2, 3, 4, 5, 3])
weights_data = np.array([0, 1, 0.5, 0.1, 0.6, 0.5])

with tf.Session() as sess:
    print(sess.run(result, feed_dict={indices_tf:[indices_data],weights_tf:[weights_data]}))
    print(sess.run(result, feed_dict={indices_tf: [indices_data]*2, weights_tf: [weights_data]*2}))

# print
[0.  0.  1.  1.  0.1 0.6]
[0.  0.  2.  2.  0.2 1.2]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...