Tensorflow dataset.shuffle, кажется, не перемешивается без повторения () - PullRequest
0 голосов
/ 03 июля 2019

Мой код имеет аналогичный шаблон с tenorflow 2.0 учебник . Я хочу, чтобы мой объект набора данных перетасовывался в каждую эпоху.

dataset = tf.data.Dataset.from_tensor_slices(['a','b','c','d'])
dataset = dataset.shuffle(100)

for epoch in range(10):
    for d in dataset:
        print(d)

Результат:

tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
...

Кажется, набор данных не перетасовывается для каждой эпохи. Должен ли я вызывать .shuffle () для каждой эпохи?

1 Ответ

1 голос
/ 03 июля 2019

Да, вы должны вызывать .shuffle во время внутреннего цикла.Более того, лучше не смешивать код Python и код TensorFlow, когда доступен чистый метод tf. *, Эквивалентный операторам Python.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c", "d"])
# dataset = dataset.shuffle(2)


@tf.function
def loop():
    for epoch in tf.range(10):
        for d in dataset.shuffle(2):
            tf.print(d)


loop()

При вызове цикла каждый раз создаются разные значения (и tf.print печатает содержимое tf.Tensor, отличное от print, которое печатает объект).

...