Код ниже говорит вам, что вам не нужно заботиться о размере партии самостоятельно. Вы просто используете DatsetMixin
и SerialIterator
, как указано в руководстве по цепочке.
from chainer.dataset import DatasetMixin
from chainer.iterators import SerialIterator
import numpy as np
NUM_IMAGES = 9957
NUM_CHANNELS = 3 # RGB
IMAGE_WIDTH = 60
IMAGE_HEIGHT = 80
NUM_CLASSES = 10
BATCH_SIZE = 32
TRAIN_SIZE = min(8000, int(NUM_IMAGES * 0.9))
images = np.random.rand(NUM_IMAGES, NUM_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT)
labels = np.random.randint(0, NUM_CLASSES, (NUM_IMAGES,))
class MyDataset(DatasetMixin):
def __init__(self, images_, labels_):
# note: input arg.'s tailing underscore is just to avoid shadowing
super(MyDataset, self).__init__()
self.images_ = images_
self.labels_ = labels_
self.size_ = len(labels_)
def __len__(self):
return self.size_
def get_example(self, i):
return self.images_[i, ...], self.labels_[i]
dataset_train = MyDataset(images[:TRAIN_SIZE, ...], labels[:TRAIN_SIZE])
dataset_valid = MyDataset(images[TRAIN_SIZE:, ...], labels[TRAIN_SIZE:])
train_iter = SerialIterator(dataset_train, BATCH_SIZE)
valid_iter = SerialIterator(dataset_valid, BATCH_SIZE, repeat=False, shuffle=False)
###############################################################################
"""This block is just for the confirmation.
.. note: NOT recommended to call :func:`concat_examples` in your code.
Use :class:`chainer.updaters.StandardUpdater` instead.
"""
from chainer.dataset import concat_examples
batch_image, batch_label = concat_examples(next(train_iter))
print("batch_image.shape\n{}".format(batch_image.shape))
print("batch_label.shape\n{}".format(batch_label.shape))
выход
batch_image.shape
(32, 3, 60, 80)
batch_label.shape
(32,)
Следует отметить, что chainer.dataset.concat_example
немного сложная часть. Обычно пользователи не обращают внимания на эту функцию, если вы используете StandardUpdater
, которая скрывает нативную функцию chainer.dataset.concat_example
.
Поскольку цепочник спроектирован по схеме Trainer
, (Standard)Updater
, некоторых Optimizer
, (Serial)Iterator
и Dataset(Mixin)
, если вы не будете следовать этой схеме, вам придется погрузиться в море chainer
исходный код.