Я воспроизвел что-то вроде того, что вы сказали, используя строковые данные:
import tensorflow as tf
def string_data(s):
return tf.sparse.to_dense(tf.strings.split([s]), default_value='')[0]
data = [' '.join(['neg'] * 30), ' '.join(['pos1'] * 10), ' '.join(['pos2'] * 10)]
step_sizes = tf.constant([3, 1, 1], dtype=tf.int64)
ds = (tf.data.Dataset.from_tensor_slices((data, step_sizes))
.interleave(lambda d, s: (tf.data.Dataset.from_tensor_slices(string_data(d))
.batch(s)),
cycle_length=len(data))
.flat_map(tf.data.Dataset.from_tensor_slices))
iter = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(iter).decode(), end=', ')
except tf.errors.OutOfRangeError: break
print()
Вывод:
neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2, neg, neg, neg, pos1, pos2,
В реальном случае вы бы заменили data
списком файловимена и tf.data.Dataset.from_tensor_slices(string_data(d))
с tf.data.TFRecordDataset(d)
, но в остальном все должно работать аналогично.
РЕДАКТИРОВАТЬ: Я только что понял, что вы на самом деле хотели, чтобы пакет всех элементов был упорядочен таким образом, а не только один элемент за раз,так что я полагаю, вам нужно добавить еще один batch
вызов в конце.