Я думаю, вы можете использовать tf.gather
:
tf.gather(tf_a1, tf_a2, axis=0)
# <tf.Tensor 'GatherV2_10:0' shape=(5, 3, 4) dtype=float32>
Воспроизводимый пример на TensorFlow 2.0
tf.__version__
# '2.0.0-beta0'
tf.gather(tf_a1, tf_a2, axis=0)
<tf.Tensor: id=9, shape=(5, 3, 4), dtype=float32, numpy=
array([[[0. , 8.3356 , 0. , 8.8974 ],
[0. , 0. , 6.103182 , 7.330564 ],
[0. , 8.457685 , 8.602337 , 0. ]],
[[0. , 8.3356 , 0. , 8.8974 ],
[9.497023 , 0. , 3.8914037, 0. ],
[0. , 0. , 5.826657 , 8.283971 ]],
[[9.968594 , 8.655439 , 0. , 0. ],
[0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. ]],
[[0. , 0. , 6.103182 , 7.330564 ],
[6.609862 , 0. , 3.0614321, 0. ],
[0. , 0. , 5.826657 , 8.283971 ]],
[[0. , 0. , 6.103182 , 7.330564 ],
[9.497023 , 0. , 3.8914037, 0. ],
[0. , 0. , 0. , 0. ]]], dtype=float32)>