использование while_loop над тензором для создания маски в тензорном потоке - PullRequest
0 голосов
/ 18 июня 2019

Я хочу создать маску с итерацией по тензору.У меня есть этот код:

import tensorflow as tf

out = tf.Variable(tf.zeros_like(alp, dtype=tf.int32))

rows_tf = tf.constant (
[[1, 2, 5],
 [1, 2, 5],
 [1, 2, 5],
 [1, 4, 6],
 [1, 4, 6],
 [2, 3, 6],
 [2, 3, 6],
 [2, 4, 7]])

columns_tf = tf.constant(
[[1],
 [2],
 [3],
 [2],
 [3],
 [2],
 [3],
 [2]])

Я хочу перебрать rows_tf и соответственно columns_tf, чтобы создать маску над out.

. Например, она замаскируетИндекс на [1,1] [2,1] and [5,1] в тензоре out равен 1.

для второй строки в rows_tf Индексы на [1,2] [2,2] [5,2] в тензоре выхода будут установлены на 1 и так далее длявсего 8 строк

Пока я это сделал, хотя он не работает успешно:

body = lambda k, i: (tf.add(out[rows_tf[i][k]][columns_tf[i][i]], 1)) # find the corresponding element in out tensor and add 1 to it (0+1=1)
k = 0
n2, m2 = rows_tf.shape
for i in tf.range(0,n2): # loop through rows in rows_tf    
    cond = lambda k, _: tf.less(k, m2) #this check to go over the columns in rows_tf
    tf.while_loop(cond, body, (k, i))

выдает эту ошибку:

TypeError: Cannot iterate over a scalar tensor. 
in this while cond(*loop_vars):

Я ушелчерез несколько ссылок, а именно здесь , чтобы убедиться, что я следую инструкции, но не смог исправить эту.

Спасибо за помощь

1 Ответ

2 голосов
/ 18 июня 2019

Вы можете сделать это без цикла, используя tf.scatter_nd следующим образом:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    out = tf.zeros([10, 4], dtype=tf.int32)
    rows_tf = tf.constant(
        [[1, 2, 5],
         [1, 2, 5],
         [1, 2, 5],
         [1, 4, 6],
         [1, 4, 6],
         [2, 3, 6],
         [2, 3, 6],
         [2, 4, 7]], dtype=tf.int32)
    columns_tf = tf.constant(
        [[1],
         [2],
         [3],
         [2],
         [3],
         [2],
         [3],
         [2]], dtype=tf.int32)
    # Broadcast columns
    columns_bc = tf.broadcast_to(columns_tf, tf.shape(rows_tf))
    # Scatter values to indices
    scatter_idx = tf.stack([rows_tf, columns_bc], axis=-1)
    mask = tf.scatter_nd(scatter_idx, tf.ones_like(rows_tf, dtype=tf.bool), tf.shape(out))
    print(sess.run(mask))

Выход:

[[False False False False]
 [False  True  True  True]
 [False  True  True  True]
 [False False  True  True]
 [False False  True  True]
 [False  True  True  True]
 [False False  True  True]
 [False False  True False]
 [False False False False]
 [False False False False]]

В качестве альтернативы, вы также можете сделать этоиспользуя только логические операции:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    out = tf.zeros([10, 4], dtype=tf.int32)
    rows_tf = tf.constant(
        [[1, 2, 5],
         [1, 2, 5],
         [1, 2, 5],
         [1, 4, 6],
         [1, 4, 6],
         [2, 3, 6],
         [2, 3, 6],
         [2, 4, 7]], dtype=tf.int32)
    columns_tf = tf.constant(
        [[1],
         [2],
         [3],
         [2],
         [3],
         [2],
         [3],
         [2]], dtype=tf.int32)
    # Compare indices
    row_eq = tf.equal(tf.range(out.shape[0])[:, tf.newaxis],
                      rows_tf[..., np.newaxis, np.newaxis])
    col_eq = tf.equal(tf.range(out.shape[1])[tf.newaxis, :],
                      columns_tf[..., np.newaxis, np.newaxis])
    # Aggregate
    mask = tf.reduce_any(row_eq & col_eq, axis=[0, 1])
    print(sess.run(mask))
    # Same as before

Однако в принципе это потребует больше памяти.

...