Набор данных Tensorflow, созданный с помощью from_generator, не повторяется с помощью batch_size - PullRequest
0 голосов
/ 26 марта 2020

Я создал набор данных тензорного потока из генератора, но не могу понять, как его перебрать, используя batch_size

def ds_gen():
    x = np.random.random((10, 10, 3))
    y = np.random.random((2))
    yield x, y

def create_tf_dataset():
    dataset = tf.data.Dataset.from_generator(ds_gen, output_types=(tf.float32, tf.float32), output_shapes=((10, 10, 3), (2,)))
    return dataset

ds = create_tf_dataset()
ds = ds.batch(10)
for x_batch, y_batch in ds:
  print(x_batch.shape, y_batch.shape)

Этот код повторяет циклы по пакетам размером 1, а не 10

1 Ответ

2 голосов
/ 27 марта 2020

Пожалуйста, обратитесь к приведенному ниже коду для итерации по размеру партии

import numpy as np
import tensorflow as tf 

def ds():
    for i in range(1000):
      x = np.random.rand(10,10,3)
      y = np.random.rand(2)
      yield x,y

ds = tf.data.Dataset.from_generator(ds, output_types=(tf.float32, tf.float32), output_shapes=((10, 10, 3), (2,)))

ds = ds.batch(10)

for batch, (x,y) in enumerate(ds):
  pass
print("Data shape: ", x.shape, y.shape)

Вывод:

Data shape:  (10, 10, 10, 3) (10, 2)

Если вы измените ds = ds.batch(1), тогда вывод будет Data shape: (1, 10, 10, 3) (1, 2)

...