Пакетный размер в форме ввода цепей CNN - PullRequest
0 голосов
/ 07 января 2019

У меня есть тренировочный набор из 9957 изображений. Тренировочный набор имеет форму (9957, 3, 60, 80). Требуется ли размер партии при установке тренировочного набора на модель? При необходимости можно ли считать исходную форму правильной для подгонки к слою conv2D или мне нужно добавить размер пакета в input_shape?

X_train.shape

(9957, 60,80,3) из chainer.datasets импорт split_dataset_random из chainer.dataset import DatasetMixin

import numpy as np


class MyDataset(DatasetMixin):
   def __init__(self, X, labels):
       super(MyDataset, self).__init__()
       self.X_ = X
       self.labels_ = labels
       self.size_ = X.shape[0]

   def __len__(self):
       return self.size_

   def get_example(self, i):
       return np.transpose(self.X_[i, ...], (2, 0, 1)), self.labels_[i] 


batch_size = 3

label_train = y_trainHot1
dataset = MyDataset(X_train1, label_train)
dataset_train, valid = split_dataset_random(dataset, 8000, seed=0)
train_iter = iterators.SerialIterator(dataset_train, batch_size)
valid_iter = iterators.SerialIterator(valid, batch_size, repeat=False, 
shuffle=False)

1 Ответ

0 голосов
/ 07 января 2019

Код ниже говорит вам, что вам не нужно заботиться о размере партии самостоятельно. Вы просто используете DatsetMixin и SerialIterator, как указано в руководстве по цепочке.

from chainer.dataset import DatasetMixin
from chainer.iterators import SerialIterator
import numpy as np

NUM_IMAGES = 9957
NUM_CHANNELS = 3  # RGB
IMAGE_WIDTH = 60
IMAGE_HEIGHT = 80

NUM_CLASSES = 10

BATCH_SIZE = 32

TRAIN_SIZE = min(8000, int(NUM_IMAGES * 0.9))

images = np.random.rand(NUM_IMAGES, NUM_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT)
labels = np.random.randint(0, NUM_CLASSES, (NUM_IMAGES,))


class MyDataset(DatasetMixin):
    def __init__(self, images_, labels_):
        # note: input arg.'s tailing underscore is just to avoid shadowing
        super(MyDataset, self).__init__()
        self.images_ = images_
        self.labels_ = labels_
        self.size_ = len(labels_)

    def __len__(self):
        return self.size_

    def get_example(self, i):
        return self.images_[i, ...], self.labels_[i]


dataset_train = MyDataset(images[:TRAIN_SIZE, ...], labels[:TRAIN_SIZE])
dataset_valid = MyDataset(images[TRAIN_SIZE:, ...], labels[TRAIN_SIZE:])
train_iter = SerialIterator(dataset_train, BATCH_SIZE)
valid_iter = SerialIterator(dataset_valid, BATCH_SIZE, repeat=False, shuffle=False)

###############################################################################
"""This block is just for the confirmation.

.. note: NOT recommended to call :func:`concat_examples` in your code.
    Use :class:`chainer.updaters.StandardUpdater` instead. 
"""
from chainer.dataset import concat_examples

batch_image, batch_label = concat_examples(next(train_iter))
print("batch_image.shape\n{}".format(batch_image.shape))
print("batch_label.shape\n{}".format(batch_label.shape))

выход

batch_image.shape
(32, 3, 60, 80)
batch_label.shape
(32,)

Следует отметить, что chainer.dataset.concat_example немного сложная часть. Обычно пользователи не обращают внимания на эту функцию, если вы используете StandardUpdater, которая скрывает нативную функцию chainer.dataset.concat_example.

Поскольку цепочник спроектирован по схеме Trainer, (Standard)Updater, некоторых Optimizer, (Serial)Iterator и Dataset(Mixin), если вы не будете следовать этой схеме, вам придется погрузиться в море chainer исходный код.

...