Из-за преимуществ использования многопроцессорной обработки во время model.fit_generator
я переключился с простого метода генератора (с использованием yield в истинном цикле while) на keras.utils.Sequence
.В приведенном ниже коде я загружаю некоторые изображения и файлы масок (приложение представляет собой семантическую сегментацию), и мой метод __len__
возвращает длину набора данных, разделенную на размер пакета, как описано во многих руководствах (например, this * 1005).*)
class ImageSequence(Sequence):
def __init__(self, data_location, batch_size=32, is_training=False, class_to_detect='face'):
if is_training:
self.dataset = open(os.path.join(data_location, 'train.txt')).readlines()
else:
self.dataset = open(os.path.join(data_location, 'test.txt')).readlines()
self.class_to_detect = class_to_detect
self.data_location = data_location
self.batch_size = batch_size
def __len__(self):
return len(self.dataset) // self.batch_size
def __getitem__(self, i):
files = self.dataset[(i * self.batch_size):((i + 1) * self.batch_size)]
data = np.zeros((self.batch_size, 64, 64, 3))
labels = np.zeros((self.batch_size, 64, 64, 1))
for i, sample in enumerate(files):
image_file = sample.split(',')[0]
truth_file = sample.split(',')[1][:-1]
image = np.float32(cv2.imread(os.path.join(self.data_location, image_file)) / 255.0)
truth_mask = cv2.imread(os.path.join(self.data_location, truth_file), cv2.IMREAD_GRAYSCALE)
label = np.zeros_like(truth_mask)
label[truth_mask == object_label[self.class_to_detect]] = 1.0
data[i], labels[i] = crop_random(image, label)
return data, labels
Но, как вы можете видеть, я использую метод crop_random
из загруженного изображения, чтобы сделать мою модель немного более устойчивой.Другой метод, который я хочу использовать, - это случайное переключение, но во время реализации я спрашиваю себя
«Когда я взрываю свой набор данных и использую весь этот материал предварительной обработки, увеличивается ли длина последовательности?»
Но когда я увеличу длину последовательности, индекс для __getitem__
, очевидно, должен выйти за пределы диапазона?Но когда я не увеличиваю len, метод fit_generator всегда будет просто видеть «несколько» сэмплов для эпохи.
Использую ли я keras.utils.Sequence
неверным образом, или у меня может быть ложный ментальныймодель как это работает?