Я рассмотрел различные способы нарезки в тензорном потоке, а именно tf.gather
и tf.gather_nd
.
В tf.gather он просто нарезает измерение, а также в tf.gather_nd
он просто принимает один indices
для применения к входному тензору.
Мне нужно другое, я хочу нарезать входной тензор, используя два разных тензора: один разрезает по строкам, второй разрезает по столбцу, и они не обязательно имеют одинаковую форму.
Например:
предположим, что это мой входной тензор, в котором я хочу извлечь его часть.
input_tf = tf.Variable([ [9.968594, 8.655439, 0., 0. ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ],
[0., 0., 0., 0. ]])
второй:
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
Третий тензор:
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
Теперь я хочу нарезать input_tf
, используя rows_tf
и columns_tf
. индекс [1 2 5]
в строках и [1]
в columns_tf
. Опять строки [1 2 5]
с [2]
в columns_tf
.
или [1 4 6]
с [2]
.
В целом, каждый индекс в rows_tf
с таким же индексом в columns_tf
будет извлекать часть из input_tf
.
Итак, ожидаемый результат будет:
[[8.3356, 0., 8.457685 ],
[0., 6.103182, 8.602337 ],
[8.8974, 7.330564, 0. ],
[0., 3.8914037, 5.826657 ],
[8.8974, 0., 8.283971 ],
[6.103182, 3.0614321, 5.826657 ],
[7.330564, 0., 8.283971 ],
[6.103182, 3.8914037, 0. ]]
например, здесь первая строка [8.3356, 0., 8.457685 ]
извлекается с помощью
rows in rows_tf [1,2,5] and column in columns_tf [1](row 1 and column 1, row 2 and column 1 and row 5 and column 1 in the input_tf)
Было несколько вопросов относительно нарезки в тензорном потоке, хотя они использовали tf.gather
или tf.gather_nd
и tf.stack
, которые не дали желаемого результата.
Нет необходимости упоминать, что в numpy
мы можем легко сделать это, позвонив по номеру: input_tf[rows_tf, columns_tf]
.
Я также посмотрел на эту расширенную индексацию, которая пытается имитировать расширенную индексацию, доступную в numpy, однако она по-прежнему не похожа на numpy Flexible https://github.com/SpinachR/ubuntuTest/blob/master/beautifulCodes/tensorflow_advanced_index_slicing.ipynb
Это то, что я пробовал, и это не правильно:
tf.gather(tf.transpose(tf.gather(input_tf,rows_tf)),columns_tf)
размерность этого кода равна (8, 1, 3, 8)
, что совершенно неверно.
Заранее спасибо!