Как выбрать несколько строк с помощью Tensorflow? - PullRequest
0 голосов
/ 09 июля 2020
• 1000 * и это не работает:
x[tf.argmax(x, axis=0), :]

TypeError: только целые числа, срезы (:), многоточие (...), tf.newaxis (None) и скаляр tf .int32 / tf.int64 тензоры - допустимые индексы, получившие

Как правильно это сделать?

Ответы [ 2 ]

1 голос
/ 09 июля 2020

Сегодня я узнал, что tf.gather() делает это.

tf.gather(x, tf.argmax(x, axis=0))
<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[ 1.8891758 ,  0.7073202 , -0.78521085, -2.7632885 ],
       [-1.3851309 ,  1.4023514 , -0.9735394 , -0.81982684],
       [-0.22595228, -0.7155944 ,  0.37807527,  2.2081604 ],
       [-0.22595228, -0.7155944 ,  0.37807527,  2.2081604 ]],
      dtype=float32)>
1 голос
/ 09 июля 2020

Преобразуйте тензор в numpy, прежде чем использовать его в качестве матрицы

x.numpy()[tf.argmax(x, axis=0), :]
...