TensorFlow, периодическое индексирование (первое измерение) и сортировка - PullRequest
0 голосов
/ 30 мая 2018

У меня есть тензор параметров с формой (?,368,5), а также тензор запроса с формой (?,368).Тензор запросов хранит индексы для сортировки первого тензора.

Требуемый вывод имеет форму: (?,368,5).Поскольку это необходимо для функции потерь в нейронной сети, используемые операции должны оставаться дифференцируемыми.Кроме того, во время выполнения размер первой оси ? соответствует размеру пакета.

До сих пор я экспериментировал с tf.gather и tf.gather_nd, однако tf.gather(params,query) дает тензор с формой (?,368,368,5),

Тензор запроса достигается путем выполнения:

query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices

В целом, я пытаюсь отсортировать тензор параметров по первому элементу на третьей оси (для вида расстояния фаски).Наконец, стоит упомянуть, что я работаю с фреймворком Keras.

1 Ответ

0 голосов
/ 30 мая 2018

Вам нужно добавить индексы первого измерения к query, чтобы использовать его с tf.gather_nd.Вот способ сделать это:

import tensorflow as tf
import numpy as np

np.random.seed(100)

with tf.Graph().as_default(), tf.Session() as sess:
    params = tf.placeholder(tf.float32, [None, 368, 5])
    query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
    n = tf.shape(params)[0]
    # Make tensor of indices for the first dimension
    ii = tf.tile(tf.range(n)[:, tf.newaxis], (1, params.shape[1]))
    # Stack indices
    idx = tf.stack([ii, query], axis=-1)
    # Gather reordered tensor
    result = tf.gather_nd(params, idx)
    # Test
    out = sess.run(result, feed_dict={params: np.random.rand(10, 368, 5)})
    # Check the order is correct
    print(np.all(np.diff(out[:, :, 0], axis=1) <= 0))
    # True
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...