Tensorflow - индексация и трансляция типа Numpy. ValueError на разных формах - PullRequest
0 голосов
/ 05 марта 2020

В numpy (мой код будет точным) я бы индексировал и передавал список x формы (5,2) со мной sh, например:

mesh = np.mgrid[0:x.shape[0],0:x.shape[0]] # shape(2,5,5)
out = x[mesh] # shape (2, 5, 5, 2)

В порядке чтобы создать ту же сетку в Tensorflow, применяется следующий код:

with tf.compatv1.Session() as sess:
    x = tf.random.uniform((5,2),minval=0,maxval=10)
    M, _= tf.meshgrid(tf.range(0,tf.shape(x)[0]),tf.range(0,tf.shape(x)[0]))
    MT = tf.transpose(M)
    M2 = tf.expand_dims(M,0)
    M3 = tf.expand_dims(MT,0)
    mesh = tf.concat((M2,M3),0)
    out=x[mesh] # raises error here
    output = tf.concat((out[1],out[2]),axis=-1)
    a = output.eval()

Но когда я ее запускаю, я получаю следующую ошибку:

ValueError: Shape должен иметь ранг 1 но имеет ранг 4 для 'strided_slice_13' (op: 'StridedSlice') с входными формами: [5,2], [1,2,5,5], [1,2,5,5], [1]

Как можно добиться того же поведения, что и у меня в numpy, но в тензорном потоке? (ps: я уже пробовал плитку и collect_nd безуспешно)

...