Удаление определенных строк из тензора в тензорном потоке без использования tf.RaggedTensor - PullRequest
1 голос
/ 09 июля 2020

Данные тензорные данные

   [[[ 0.,  0.],
    [ 1.,  1.],
    [-1., -1.]],

   [[-1., -1.],
    [ 4.,  4.],
    [ 5.,  5.]]]

Я хочу удалить [-1, -1] и получить

   [[[ 0.,  0.],
    [ 1.,  1.]],

   [[ 4.,  4.],
    [ 5.,  5.]]]

Как получить указанное выше, не используя рваную функцию в тензорном потоке?

Ответы [ 2 ]

0 голосов
/ 09 июля 2020

Вы можете попробовать это:

x = tf.constant(
      [[[ 0.,  0.],
      [ 1.,  1.],
      [-1., -2.]],

     [[-1., -2.],
      [ 4.,  4.],
      [ 5.,  5.]]])

mask = tf.math.not_equal(x, np.array([-1, -1]))

result = tf.boolean_mask(x, mask)
shape = tf.shape(x)
result = tf.reshape(result, (shape[0], -1, shape[2]))
0 голосов
/ 09 июля 2020

Вы можете сделать это так:

import tensorflow as tf
import numpy as np

data = [[[ 0.,  0.],
         [ 1.,  1.],
         [-1., -1.]],
        [[-1., -1.],
         [ 4.,  4.],
         [ 5.,  5.]]]
data = tf.constant(data)
indices = tf.math.not_equal(data, tf.constant([-1., -1.]))
res = data[indices]

shape = tf.shape(data)
total = tf.reduce_sum(
    tf.cast(tf.math.logical_and(indices[:, :, 0], indices[:, :, 1])[0], tf.int32))

res = tf.reshape(res, (shape[0], total, shape[-1]))

with tf.Session() as sess:
    print(sess.run(res))
# [[[0. 0.]
#   [1. 1.]]

#  [[4. 4.]
#   [5. 5.]]]
...