Решение 1: более общее
Вы можете посмотреть ответ здесь , это в основном та же проблема, что и у вас, с разными размерами.
Описанное решение состоит в том, чтобы создать [?, 28, 28, 4]
-образный тензор indices
, где indices[i, x, y, :] = [i, x, y, self.label_predictions[i]]
, а затем использовать tf.gather_nd
:
self.masks_sigmoids = tf.gather_nd(self.final_conv, indices=indices)
Построение indices
не очень элегантно, как показано в этого ответа (с еще одним измерением для вас), но само по себе просто.
Решение 2: Немного более элегантно и адаптировано к вашей проблеме
Это решение очень похоже на первое, но позволяет избежать создания [x, y]
части indices
. Идея состоит в том, чтобы использовать возможности среза команды collect_nd, чтобы избежать записи [x, y]
в indices
для каждого (i, x, y)
путем транспонирования данных перед их сбором. Я выложу здесь весь код, включая способ создания indices
и способ тестирования:
import numpy as np
import tensorflow as tf
N_CHANNELS = 5
pl=tf.placeholder(dtype=tf.int32, shape=(None, 28, 28, N_CHANNELS))
# Indices we'll use. batch_size = 4 here.
label_predictions = tf.constant([0, 2, 0, 3])
# Indices of shape [?, 2], with indices[i] = [i, self.label_predictions[i]],
# which is easy to do with tf.range() and tf.stack()
indices = tf.stack([tf.range(tf.size(label_predictions)), label_predictions], axis=-1)
# [[0, 0], [1, 2], [2, 0], [3, 3]]
transposed = tf.transpose(pl, perm=[0, 3, 1, 2])
gathered = tf.gather_nd(transposed, indices) # Should be of shape (4, 2, 3)
result = tf.expand_dims(gathered, -1)
initial_value = np.arange(4*28*28*N_CHANNELS).reshape((4, 28, 28, N_CHANNELS))
sess = tf.InteractiveSession()
res = sess.run(result, feed_dict={pl: initial_value})
# print(res)
print("checking validity")
for i in range(4):
for x in range(28):
print(x)
for y in range(28):
assert res[i, x, y, 0] == initial_value[i, x, y, indices[i, 1].eval()]
print("All assertions passed")