получить следующий тензор из набора данных в tf. while_loop - PullRequest
0 голосов
/ 04 мая 2018

Я хочу перебирать набор данных, пока не будет выполнено определенное условие, но я не знаю, как "перебрать". Ниже мой код.

import tensorflow as tf

c = tf.constant([1,2,6])
d = tf.data.Dataset.from_tensor_slices((c,))
t = d.make_one_shot_iterator().get_next()

def condition(t):
  return t < 5

def body(t):
  # I don't know what to do here to return the next t
  return [t]

t = tf.while_loop(condition, body, [t])

with tf.Session() as sess:
    print(sess.run([t]))

В ответ на ответ Алекса ниже, ниже приведен более реалистичный пример того, чего я хочу достичь.

import tensorflow as tf

# I want to "merge" the dataset da to dataset db by "backfilling" da.
# So session.run will return [[1,'a'], [1,'x']], then [[5, 'c'],[3, 'y']]
# note that one element from dataset da is skipped, which is what I want to achieve with the while loop.
ta = tf.constant([1,2,5]) 
va = tf.constant(['a','b','c'])
da = tf.data.Dataset.from_tensor_slices((ta, va))

tb = tf.constant([1,3,6])
vb = tf.constant(['x','y','z'])
db = tf.data.Dataset.from_tensor_slices((tb, vb))

ea = da.make_one_shot_iterator().get_next()
eb = db.make_one_shot_iterator().get_next()

def condition(ea, eb):
  return ea[0] < eb[0]

def body(ea, eb):
  # I don't know what to do here to get the next ea.
  return ea, eb

result = tf.while_loop(condition, body, (ea, eb))

with tf.Session() as sess:
  sess.run([result])

Я мог бы переместить логику цикла while в python, как предположил Алекс, но я полагаю, что если оставить ее в графе потока данных, производительность будет выше.

Ответы [ 2 ]

0 голосов
/ 05 мая 2018

Вы можете использовать метод Dataset.filter () для фильтрации элементов набора данных в соответствии с пользовательским предикатом. Вы должны передать функцию фильтра, которая возвращает тензор tf.bool, который оценивается как true, если вы хотите сохранить запись, или false в противном случае.

0 голосов
/ 04 мая 2018

Полагаю, вы еще не понимаете, как работает Tensorflow. Tensorflow tf. while_loop создает цикл while внутри вычислительного графа, добавляя операторы управления для повторного применения частей графа несколько раз, пока не будет выполнено определенное условие. Я бы посоветовал начать читать здесь , чтобы узнать, что такое графики и сессии.

Вы, конечно, не хотите перебирать свой набор данных в вычислительном графе, итерация, которую вы хотите, должна происходить в Python, а не в графе Tensorflow.

Вот как бы вы это сделали:

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)

Это объясняется более подробно здесь .

...