Пожалуйста, обратитесь к приведенному ниже коду для итерации по размеру партии
import numpy as np
import tensorflow as tf
def ds():
for i in range(1000):
x = np.random.rand(10,10,3)
y = np.random.rand(2)
yield x,y
ds = tf.data.Dataset.from_generator(ds, output_types=(tf.float32, tf.float32), output_shapes=((10, 10, 3), (2,)))
ds = ds.batch(10)
for batch, (x,y) in enumerate(ds):
pass
print("Data shape: ", x.shape, y.shape)
Вывод:
Data shape: (10, 10, 10, 3) (10, 2)
Если вы измените ds = ds.batch(1)
, тогда вывод будет Data shape: (1, 10, 10, 3) (1, 2)