Вы можете использовать функцию tf.gather_nd
, но сначала вам нужно объявить class_colors
как переменную тензорного потока. Проверьте следующий пример (размер изображения 50x50, 2 класса):
import tensorflow as tf
predictions = tf.argmax(tf.nn.softmax(tf.random_normal([50,50,2])),axis=-1) #(50,50)
class_colors = tf.Variable([[255,0,0],[0,255,0]]) #(2,3)
prediction_image = tf.gather_nd(class_colors, tf.expand_dims(predictions,axis=-1)) #(50,50,3)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(prediction_image).shape) #(50, 50, 3)
В качестве альтернативы вы можете вычислить тензор predictions
и использовать пустые операции.