У меня есть тензор параметров с формой (?,368,5)
, а также тензор запроса с формой (?,368)
.Тензор запросов хранит индексы для сортировки первого тензора.
Требуемый вывод имеет форму: (?,368,5)
.Поскольку это необходимо для функции потерь в нейронной сети, используемые операции должны оставаться дифференцируемыми.Кроме того, во время выполнения размер первой оси ?
соответствует размеру пакета.
До сих пор я экспериментировал с tf.gather
и tf.gather_nd
, однако tf.gather(params,query)
дает тензор с формой (?,368,368,5)
,
Тензор запроса достигается путем выполнения:
query = tf.nn.top_k(params[:, :, 0], k=params.shape[1], sorted=True).indices
В целом, я пытаюсь отсортировать тензор параметров по первому элементу на третьей оси (для вида расстояния фаски).Наконец, стоит упомянуть, что я работаю с фреймворком Keras
.