Переполнение памяти графического процессора в Keras при использовании keras.utils.Sequence и генератора - PullRequest
/ 05 февраля 2020


import os
import random
from skimage import io
import cv2
from skimage.transform import resize
import numpy as np
import tensorflow as tf

import keras
import Augmentor

def iter_sequence_infinite(seq):
    """Iterate indefinitely over a Sequence.

    # Arguments
        seq: Sequence object (any object that can be iterated repeatedly).

    # Returns
        Generator yielding batches. When the end of `seq` is reached the
        iteration restarts from its beginning, so the generator never stops.

    Note: if `seq` yields no items at all, the first `next()` call on the
    returned generator never returns, because the outer loop spins forever.
    """
    while True:
        # A new `for` statement re-iterates `seq` from the start on each pass.
        for item in seq:
            yield item

# data generator class
class DataGenerator(keras.utils.Sequence):
    """Keras Sequence producing (images, masks) batches read from disk.

    Every id in `ids` is used as a file name inside both `imgs_dir` and
    `masks_dir`. Indexing the object with a batch number loads, resizes and
    stacks `batch_size` image/mask pairs.

    # Arguments
        ids: list of file names shared by the image and the mask directory.
        imgs_dir: directory holding the input images.
        masks_dir: directory holding the mask images.
        batch_size: number of samples per batch.
        img_size: side length that images and masks are resized to.
        n_classes: stored on the instance; not used by this class itself.
        n_channels: stored on the instance; not used by this class itself.
        shuffle: whether `on_epoch_end` reshuffles the sample order.
    """

    def __init__(self, ids, imgs_dir, masks_dir, batch_size=10, img_size=128, n_classes=1, n_channels=3, shuffle=True):
        super().__init__()
        self.id_names = ids
        self.indexes = np.arange(len(self.id_names))
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.shuffle = shuffle

    def on_epoch_end(self):
        """Reset the sample order and reshuffle it when `shuffle` is set."""
        self.indexes = np.arange(len(self.id_names))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation__(self, id_name):
        """Load and preprocess a single (image, mask) pair.

        # Arguments
            id_name: file name looked up in both `imgs_dir` and `masks_dir`.

        # Returns
            Tuple `(image, mask)`. Both are resized to
            `(img_size, img_size)`; the mask gets a trailing axis of length 1.
        """
        img_path = os.path.join(self.imgs_dir, id_name)
        mask_path = os.path.join(self.masks_dir, id_name)
        target_shape = (self.img_size, self.img_size)

        # skimage's `resize` (default `preserve_range=False`) converts
        # integer-typed images to floats scaled to [0, 1], so no further
        # division by 255 is applied; doing so would squash every value
        # into [0, 1/255].
        image = io.imread(img_path)
        image = resize(image, target_shape, anti_aliasing=True)

        mask = io.imread(mask_path, as_gray=True)
        mask = resize(mask, target_shape, anti_aliasing=True)
        # assumes the mask is 2-D at this point, giving (H, W, 1) - TODO confirm
        mask = np.expand_dims(mask, axis=-1)

        return image, mask

    def __len__(self):
        """Return the number of full batches per epoch (a remainder is dropped)."""
        return len(self.id_names) // self.batch_size

    def __getitem__(self, index):
        """Build batch number `index`.

        # Arguments
            index: zero-based batch number.

        # Returns
            Tuple `(imgs, masks)` of numpy arrays stacked along a new
            leading batch axis.
        """
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]

        imgs = []
        masks = []
        for k in batch_indexes:
            img, mask = self.__data_generation__(self.id_names[k])
            # Collect every sample; without this the batch would be empty.
            imgs.append(img)
            masks.append(mask)

        return np.array(imgs), np.array(masks)


import argparse
import logging
import os
import sys
from tqdm import tqdm # progress bar
import numpy as np
import matplotlib.pyplot as plt

from keras import optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import segmentation_models as sm
from segmentation_models.utils import set_trainable
from dataset import DataGenerator, iter_sequence_infinite

def train_model(model, train_gen, valid_gen, epochs, save_cp=True):
    """Train `model` with a manual batch loop and validate after each epoch.

    # Arguments
        model: compiled model whose `train_on_batch` returns `(loss, iou)`.
        train_gen: training data source; must provide `id_names`, `__len__`
            (number of batches), iteration over batches and `on_epoch_end`.
        valid_gen: validation data source with `id_names` and `__len__`.
        epochs: number of passes over the training batches.
        save_cp: when true, weights are saved after every epoch into the
            module-level `checkpoint_dir`.

    # Raises
        ValueError: if `train_gen` reports zero batches.
    """
    train_img_num = len(train_gen.id_names)
    train_batch_num = len(train_gen)
    if train_batch_num == 0:
        # Without this guard the per-epoch averages below divide by zero.
        raise ValueError(
            'train_gen yields no batches; check the dataset size and batch size')
    train_gen_out = iter_sequence_infinite(train_gen)

    valid_batch_num = len(valid_gen)
    valid_img_num = len(valid_gen.id_names)
    valid_gen_out = iter_sequence_infinite(valid_gen)

    if save_cp:
        # Make sure the target directory exists before the first save.
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(epochs):
        epoch_loss = 0
        epoch_iou = 0

        with tqdm(total=train_img_num, desc=f'Epoch {epoch + 1}/{epochs}',  position=0, leave=True, unit='img') as pbar:
            for _ in range(train_batch_num):
                batch = next(train_gen_out)
                imgs = batch[0]
                true_masks = batch[1]
                loss, iou = model.train_on_batch(imgs, true_masks)
                epoch_loss += loss
                epoch_iou += iou

                # Show the metrics of the current batch next to the bar.
                pbar.set_postfix(**{'Batch loss': loss, 'Batch IoU': iou})
                pbar.update(imgs.shape[0])

        # This loop replaces Keras' own fit loop, so the end-of-epoch hook
        # must be invoked by hand for the generator to reorder its samples.
        train_gen.on_epoch_end()

        print( "Epoch : loss: {}, IoU : {}".format(epoch_loss/train_batch_num, epoch_iou/train_batch_num))

        # Do validation
        validation_model(model, valid_gen_out, valid_batch_num, valid_img_num)

        if save_cp:
            model.save_weights(os.path.join(checkpoint_dir , f'CP_epoch{epoch + 1}.h5'))
            logging.info(f'Checkpoint {epoch + 1} saved !')

def validation_model(model, valid_gen_out, valid_batch_num, valid_img_num):
    """Evaluate `model` on `valid_batch_num` batches and plot one prediction.

    # Arguments
        model: compiled model whose `test_on_batch` returns `(loss, iou)`.
        valid_gen_out: iterator yielding `(imgs, true_masks)` batches.
        valid_batch_num: number of batches to pull from `valid_gen_out`.
        valid_img_num: total used for the progress bar.

    # Returns
        None. The mean loss / IoU are printed, and the first sample of the
        last batch is drawn next to its predicted mask on the current
        matplotlib figure.
    """
    if valid_batch_num <= 0:
        # Nothing to evaluate: the averages and the preview below would fail.
        logging.warning('Validation skipped: no validation batches available')
        return

    epoch_loss = 0
    epoch_iou = 0

    with tqdm(total=valid_img_num, desc='Validation round',  position=0, leave=True, unit='img') as pbar:
        for _ in range(valid_batch_num):
            batch = next(valid_gen_out)
            imgs = batch[0]
            true_masks = batch[1]
            loss, iou = model.test_on_batch(imgs, true_masks)
            epoch_loss += loss
            epoch_iou += iou

            # Show the metrics of the current batch next to the bar.
            pbar.set_postfix(**{'Batch, loss': loss, 'Batch IoU': iou})
            pbar.update(imgs.shape[0])

    print("Validation loss: {}, IoU: {}".format(epoch_loss / valid_batch_num, epoch_iou / valid_batch_num))

    pred_mask = model.predict(np.expand_dims(imgs[0],0))
    # Clear the figure first so images from earlier calls do not accumulate,
    # and use two panels so the prediction does not cover the ground truth.
    plt.clf()
    plt.subplot(1, 2, 1)
    plt.imshow(true_masks[0].squeeze(), cmap="gray")
    plt.subplot(1, 2, 2)
    plt.imshow(pred_mask.squeeze(), cmap="gray")

def get_args(argv=None):
    """Parse the command-line options of the training script.

    # Arguments
        argv: optional list of argument strings. When None (the default),
            argparse reads the process command line.

    # Returns
        argparse.Namespace with the attributes `epochs`, `batch_size`, `lr`,
        `backbone`, `load`, `resizing` and `val`.
    """
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=50,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch_size', metavar='B', type=int, nargs='?', default=2,
                        help='Batch size', dest='batch_size')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('-bb', '--backbone', default='resnet50', metavar='FILE',
                        help="backcone name")
    parser.add_argument('-w', '--weight', dest='load', type=str, default=False,
                        help='Load model from a .h5 file')
    # The value is used as the target side length, not as a scale factor.
    parser.add_argument('-s', '--resizing', dest='resizing', type=int, default=384,
                        help='Side length in pixels that images and masks are resized to')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args(argv)

# Script entry point: split the image ids, build the generators and the model,
# then run the manual training loop.
if __name__ == '__main__':
    img_dir = './data/train/imgs/'  # ./data/train/imgs/CVC_Original/'
    mask_dir = './data/train/masks/'  # ./data/train/masks/CVC_Ground Truth/'
    checkpoint_dir = './checkpoints'  # read as a module-level name by train_model
    args = get_args()

    # train path
    train_ids = os.listdir(img_dir)
    # Validation Data Size
    n_val = int(len(train_ids) * args.val/100)  # size of validation set

    valid_ids = train_ids[:n_val]  # first n_val ids are held out for validation
    train_ids = train_ids[n_val:]  # remaining ids form the training set
    print("training_size: ", len(train_ids), "validation_size: ", len(valid_ids))

    train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
    valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)

    print("total training batches: ", len(train_gen))
    print("total validaton batches: ", len(valid_gen))
    train_steps = len(train_ids) // args.batch_size
    valid_steps = len(valid_ids) // args.batch_size

    # define model
    model = sm.Unet(args.backbone, encoder_weights='imagenet')
    if args.load:
        # Honour the -w/--weight option, which is otherwise parsed but unused.
        model.load_weights(args.load)

    optimizer = optimizers.Adam(lr=args.lr, decay=1e-4)
    # The training loop unpacks `(loss, iou)` from train_on_batch, so exactly
    # one metric (IoU) is compiled in alongside the loss.
    model.compile(
        optimizer=optimizer,
        loss=sm.losses.bce_dice_loss,  # sm.losses.bce_jaccard_loss, # sm.losses.binary_crossentropy,
        metrics=[sm.metrics.iou_score],
    )

    # NOTE: this list is not passed to train_model, which runs its own loop.
    callbacks = [
        EarlyStopping(patience=6, verbose=1),
        ReduceLROnPlateau(factor=0.1, patience=3, min_lr=1e-7, verbose=1),
        ModelCheckpoint('./weights.Epoch{epoch:02d}-Loss{loss:.3f}-VIou{val_iou_score:.3f}.h5', verbose=1,
                        monitor='val_accuracy', save_best_only=True, save_weights_only=True),
    ]

    train_model(model=model, train_gen=train_gen, valid_gen=valid_gen, epochs=args.epochs)

Когда я запускаю этот код, несколько эпох проходят нормально, но на 20-й эпохе возникает ошибка нехватки памяти GPU, как показано ниже:

(0) Resource exhausted: OOM when allocating tensor with shape[2,64,96,96] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[{{node decoder_stage2b_bn/FusedBatchNorm}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

Поэтому я думаю, что это происходит из-за генерации данных.

Этот код формирует пакет (batch) в следующем порядке:

  1. В train.py создаётся экземпляр класса DataGenerator — наследника keras.utils.Sequence, реализованного в dataset.py:

    train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)

    valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)

  2. В начале функции 'train_model' объект DataGenerator (Sequence) преобразуется в бесконечный генератор с помощью функции 'iter_sequence_infinite':

    train_gen_out = iter_sequence_infinite (train_gen)

    valid_gen_out = iter_sequence_infinite (valid_gen)

  3. С помощью «магической» функции 'next' из генератора извлекается очередной пакет:

    batch = next (train_gen_out)

Я полагал, что проблем с памятью быть не должно, но они возникли. В чём проблема и как её решить? Спасибо.
