Dataset.py
import os
import random
from skimage import io
import cv2
from skimage.transform import resize
import numpy as np
import tensorflow as tf
import keras
import Augmentor
def iter_sequence_infinite(seq):
    """Iterate indefinitely over a Sequence.

    # Arguments
        seq: Sequence object (anything that can be iterated repeatedly).

    # Returns
        Generator yielding batches. `seq` is restarted after every full
        pass. If a pass yields nothing the generator stops, so an empty
        sequence ends the iteration instead of spinning forever.
    """
    while True:
        yielded_any = False
        for item in seq:
            yielded_any = True
            yield item
        if not yielded_any:
            # Nothing to repeat: without this guard the outer loop would
            # busy-loop and the caller's next() would never return.
            return
# data generator class
class DataGenerator(keras.utils.Sequence):
    """Keras Sequence that yields augmented (images, masks) batches.

    Each sample is read from disk on demand, resized to
    ``img_size`` x ``img_size``, randomly augmented with Augmentor and
    scaled by 1/255.

    # Arguments
        ids: list of file names; the same name is looked up in both
            `imgs_dir` and `masks_dir`.
        imgs_dir: directory containing the input images.
        masks_dir: directory containing the mask images.
        batch_size: number of samples per batch.
        img_size: width and height every sample is resized to.
        n_classes: stored on the instance; not used by this class.
        n_channels: stored on the instance; not used by this class.
        shuffle: whether the sample order is reshuffled in `on_epoch_end`.
    """

    def __init__(self, ids, imgs_dir, masks_dir, batch_size=10, img_size=128,
                 n_classes=1, n_channels=3, shuffle=True):
        super().__init__()
        self.id_names = ids
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.shuffle = shuffle
        # Builds self.indexes (and shuffles it when requested).
        self.on_epoch_end()

    def on_epoch_end(self):
        """Rebuild the sample index order, reshuffling it if enabled."""
        self.indexes = np.arange(len(self.id_names))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation__(self, id_name):
        """Load, augment and normalize a single image/mask pair.

        # Arguments
            id_name: file name shared by the image and its mask.

        # Returns
            Tuple `(image, mask)`, both divided by 255.0. `mask` is the
            first channel of the augmented mask.

        # Raises
            OSError: if OpenCV cannot read the mask file.
        """
        img_path = os.path.join(self.imgs_dir, id_name)
        mask_path = os.path.join(self.masks_dir, id_name)

        img = io.imread(img_path)
        mask = cv2.imread(mask_path)
        if mask is None:
            # cv2.imread signals failure by returning None instead of
            # raising; fail here with the path rather than deep inside
            # the augmentation pipeline.
            raise OSError(f"Could not read mask image: {mask_path}")

        # Geometric augmentations are applied to image and mask together so
        # that they stay aligned.
        p = Augmentor.DataPipeline([[img, mask]])
        p.resize(probability=1.0, width=self.img_size, height=self.img_size)
        p.rotate_without_crop(probability=0.3, max_left_rotation=10, max_right_rotation=10)
        # Disabled: p.random_distortion(probability=0.3, grid_height=10, grid_width=10, magnitude=1)
        p.shear(probability=0.3, max_shear_left=1, max_shear_right=1)
        # Disabled: p.skew_tilt(probability=0.3, magnitude=0.1)
        p.flip_random(probability=0.3)
        sample_p = np.array(p.sample(1)).squeeze()
        p_img = sample_p[0]
        p_mask = sample_p[1]

        # Denoising: integer-divide by 255 and scale back, so (assuming
        # values in 0..255) only pixels equal to 255 stay set.
        augmented_mask = (p_mask // 255) * 255

        # Photometric augmentations are applied to the image only.
        q = Augmentor.DataPipeline([[p_img]])
        q.random_contrast(probability=0.3, min_factor=0.2, max_factor=1.0)  # low to high
        q.random_brightness(probability=0.3, min_factor=0.2, max_factor=1.0)  # dark to bright
        image = np.array(q.sample(1)).squeeze()

        mask = augmented_mask[::, ::, 0]  # keep the first channel only

        # Normalization
        image = image / 255.0
        mask = mask / 255.0
        return image, mask

    def __len__(self):
        """Return the number of full batches per epoch (remainder dropped)."""
        return int(np.floor(len(self.id_names) / self.batch_size))

    def __getitem__(self, index):
        """Build batch number `index`.

        # Returns
            Tuple `(imgs, masks)` of numpy arrays; every mask gets a
            trailing axis of size 1 added before stacking.
        """
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        imgs = []
        masks = []
        for k in batch_indexes:
            img, mask = self.__data_generation__(self.id_names[k])
            imgs.append(img)
            masks.append(np.expand_dims(mask, -1))
        return np.array(imgs), np.array(masks)
train.py
import argparse
import logging
import os
import sys
from tqdm import tqdm # progress bar
import numpy as np
import matplotlib.pyplot as plt
from keras import optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import segmentation_models as sm
from segmentation_models.utils import set_trainable
from dataset import DataGenerator, iter_sequence_infinite
def train_model(model, train_gen, valid_gen, epochs, save_cp=True, checkpoint_dir='./checkpoints'):
    """Train `model` with a manual batch loop, validating after every epoch.

    # Arguments
        model: compiled Keras model whose `train_on_batch` returns
            `(loss, iou)`.
        train_gen: training DataGenerator (needs `id_names`, `__len__`
            and `on_epoch_end`).
        valid_gen: validation DataGenerator with the same interface.
        epochs: number of epochs to run.
        save_cp: when true, save the weights after every epoch.
        checkpoint_dir: directory the per-epoch weight files are written to.
    """
    train_img_num = len(train_gen.id_names)
    train_batch_num = len(train_gen)
    train_gen_out = iter_sequence_infinite(train_gen)
    valid_batch_num = len(valid_gen)
    valid_img_num = len(valid_gen.id_names)
    valid_gen_out = iter_sequence_infinite(valid_gen)

    # set_trainable marks all layers trainable and recompiles the model.
    # It used to run at the start of every epoch; recompiling each epoch
    # rebuilds the training function (and presumably the optimizer state)
    # over and over, a likely cause of GPU memory growing until OOM after
    # a number of epochs. Doing it once has the same effect on the layers.
    set_trainable(model)

    for epoch in range(epochs):
        epoch_loss = 0  # accumulated batch losses of this epoch
        epoch_iou = 0
        count = 0
        with tqdm(total=train_img_num, desc=f'Epoch {epoch + 1}/{epochs}',
                  position=0, leave=True, unit='img') as pbar:
            for _ in range(train_batch_num):
                batch = next(train_gen_out)
                imgs = batch[0]
                true_masks = batch[1]
                loss, iou = model.train_on_batch(imgs, true_masks)
                epoch_loss += loss
                epoch_iou += iou
                pbar.set_postfix(**{'Batch loss': loss, 'Batch IoU': iou})
                pbar.update(imgs.shape[0])
                count += 1

        # The generators are consumed by hand, so the epoch hook (which
        # reshuffles the sample order) has to be triggered by hand as well.
        train_gen.on_epoch_end()
        if count:
            print("Epoch : loss: {}, IoU : {}".format(epoch_loss / count, epoch_iou / count))
        else:
            logging.warning('No training batches were produced in this epoch')

        # Do validation
        validation_model(model, valid_gen_out, valid_batch_num, valid_img_num)
        valid_gen.on_epoch_end()

        if save_cp:
            if not os.path.isdir(checkpoint_dir):
                # exist_ok guards against the directory appearing between
                # the check and the call; other errors are not swallowed
                # because saving the weights below would fail anyway.
                os.makedirs(checkpoint_dir, exist_ok=True)
                logging.info('Created checkpoint directory')
            model.save_weights(os.path.join(checkpoint_dir, f'CP_epoch{epoch + 1}.h5'))
            logging.info(f'Checkpoint {epoch + 1} saved !')
def validation_model(model, valid_gen_out, valid_batch_num, valid_img_num):
    """Run one validation round, then plot a sample prediction.

    # Arguments
        model: compiled Keras model whose `test_on_batch` returns
            `(loss, iou)`.
        valid_gen_out: generator yielding `(imgs, true_masks)` batches.
        valid_batch_num: number of batches to draw from `valid_gen_out`.
        valid_img_num: total number of images, used as the progress-bar
            total.

    # Returns
        None. The mean loss and IoU are printed, and the first sample of
        the last batch is shown next to its true and predicted mask.
    """
    if valid_batch_num <= 0:
        # Without batches there is nothing to average (division by zero)
        # and no sample to plot, so skip the round altogether.
        print("Validation skipped: no validation batches available")
        return

    epoch_loss = 0  # accumulated batch losses of this round
    epoch_iou = 0
    count = 0
    with tqdm(total=valid_img_num, desc='Validation round', position=0,
              leave=True, unit='img') as pbar:
        for _ in range(valid_batch_num):
            batch = next(valid_gen_out)
            imgs = batch[0]
            true_masks = batch[1]
            loss, iou = model.test_on_batch(imgs, true_masks)
            epoch_loss += loss
            epoch_iou += iou
            pbar.set_postfix(**{'Batch, loss': loss, 'Batch IoU': iou})
            pbar.update(imgs.shape[0])
            count += 1
    print("Validation loss: {}, IoU: {}".format(epoch_loss / count, epoch_iou / count))

    # Visual check on the first sample of the last validation batch.
    pred_mask = model.predict(np.expand_dims(imgs[0], 0))
    fig = plt.figure()
    plt.subplot(131)
    plt.imshow(imgs[0])
    plt.subplot(132)
    plt.imshow(true_masks[0].squeeze(), cmap="gray")
    plt.subplot(133)
    plt.imshow(pred_mask.squeeze(), cmap="gray")
    plt.show()
    # Release the figure so figures do not pile up across epochs.
    plt.close(fig)
    print()
def get_args(argv=None):
    """Parse the command-line options of the training script.

    # Arguments
        argv: optional list of argument strings. When None, argparse
            falls back to `sys.argv[1:]`, which is the previous behavior.

    # Returns
        argparse.Namespace with the attributes `epochs`, `batch_size`,
        `lr`, `backbone`, `load`, `resizing` and `val`.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=50,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch_size', metavar='B', type=int, nargs='?', default=2,
                        help='Batch size', dest='batch_size')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('-bb', '--backbone', default='resnet50', metavar='FILE',
                        help="backcone name")
    parser.add_argument('-w', '--weight', dest='load', type=str, default=False,
                        help='Load model from a .h5 file')
    # The value is passed on as the generators' img_size, i.e. a target
    # edge length in pixels, not a scaling factor.
    parser.add_argument('-s', '--resizing', dest='resizing', type=int, default=384,
                        help='Width and height (in pixels) the images are resized to')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,
                        help='Percent of the data that is used as validation (0-100)')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Entry point: build the data generators and the model, then train.
    img_dir = './data/train/imgs/'  # ./data/train/imgs/CVC_Original/'
    mask_dir = './data/train/masks/'  # ./data/train/masks/CVC_Ground Truth/'
    # Kept as a module-level name: code elsewhere in this file may read it
    # as a global.
    checkpoint_dir = './checkpoints'
    args = get_args()

    # os.listdir returns entries in arbitrary order; sorting makes the
    # train/validation split below the same on every run.
    train_ids = sorted(os.listdir(img_dir))

    # The first `val` percent of the files form the validation set.
    n_val = int(len(train_ids) * args.val / 100)
    valid_ids = train_ids[:n_val]
    train_ids = train_ids[n_val:]
    print("training_size: ", len(train_ids), "validation_size: ", len(valid_ids))

    train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
    valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
    print("total training batches: ", len(train_gen))
    print("total validaton batches: ", len(valid_gen))

    # define model
    model = sm.Unet(args.backbone, encoder_weights='imagenet')
    if args.load:
        # -w/--weight used to be parsed but silently ignored.
        model.load_weights(args.load)
        logging.info(f'Model weights loaded from {args.load}')

    optimizer = optimizers.Adam(lr=args.lr, decay=1e-4)
    model.compile(
        optimizer=optimizer,
        loss=sm.losses.bce_dice_loss,  # alternatives: sm.losses.bce_jaccard_loss, sm.losses.binary_crossentropy
        metrics=[sm.metrics.iou_score],
    )

    # Training uses the manual loop in train_model, which takes no Keras
    # callbacks, so none are constructed here.
    train_model(model=model, train_gen=train_gen, valid_gen=valid_gen, epochs=args.epochs)
Когда я запускаю этот код, первые эпохи проходят нормально, но примерно на 20-й эпохе возникает ошибка нехватки памяти GPU (OOM), как показано ниже:
(0) Resource exhausted: OOM when allocating tensor with shape[2,64,96,96] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
[[{{node decoder_stage2b_bn/FusedBatchNorm}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
Поэтому я думаю, что проблема связана с генерацией данных.
Этот код формирует пакеты (батчи) в следующем порядке.
В train.py создаётся экземпляр класса DataGenerator — наследника keras.utils.Sequence, реализованного в Dataset.py:
train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
Затем в функции 'train_model' объект Sequence преобразуется в бесконечный генератор с помощью функции 'iter_sequence_infinite':
train_gen_out = iter_sequence_infinite (train_gen)
valid_gen_out = iter_sequence_infinite (valid_gen)
После этого с помощью встроенной функции 'next' из генератора извлекается очередной пакет:
batch = next (train_gen_out)
Я полагал, что проблем с памятью быть не должно, но ошибка всё же возникает. В чём причина и как это исправить? Спасибо.