Keras GPU out-of-memory when using keras.utils.Sequence and a generator

Dataset.py

import os
import random
from skimage import io
import cv2
from skimage.transform import resize
import numpy as np
import tensorflow as tf

import keras
import Augmentor

def iter_sequence_infinite(seq):
    """Iterate indefinitely over a Sequence.
    # Arguments
        seq: Sequence object
    # Returns
        Generator yielding batches.
    """
    while True:
        for item in seq:
            yield item

# data generator class
class DataGenerator(keras.utils.Sequence):
    def __init__(self, ids, imgs_dir, masks_dir, batch_size=10, img_size=128, n_classes=1, n_channels=3, shuffle=True):
        self.id_names = ids
        self.indexes = np.arange(len(self.id_names))
        self.imgs_dir = imgs_dir
        self.masks_dir = masks_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.n_classes = n_classes
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    # reshuffles the sample order at the end of every epoch
    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.id_names))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation__(self, id_name):
        'Loads and augments a single image/mask pair'  # one sample: (*dim, n_channels)
        # Initialization
        img_path = os.path.join(self.imgs_dir, id_name)  # polyp segmentation/images/id_name.jpg
        mask_path = os.path.join(self.masks_dir, id_name)  # polyp segmentation/masks/id_name.jpg

        img = io.imread(img_path)
        mask = cv2.imread(mask_path)

        p = Augmentor.DataPipeline([[img, mask]])
        p.resize(probability=1.0, width=self.img_size, height=self.img_size)
        p.rotate_without_crop(probability=0.3, max_left_rotation=10, max_right_rotation=10)
        #p.random_distortion(probability=0.3, grid_height=10, grid_width=10, magnitude=1)
        p.shear(probability=0.3, max_shear_left=1, max_shear_right=1)
        #p.skew_tilt(probability=0.3, magnitude=0.1)
        p.flip_random(probability=0.3)

        sample_p = p.sample(1)
        sample_p = np.array(sample_p).squeeze()

        p_img = sample_p[0]
        p_mask = sample_p[1]
        augmented_mask = (p_mask // 255) * 255  # denoising

        q = Augmentor.DataPipeline([[p_img]])
        q.random_contrast(probability=0.3, min_factor=0.2, max_factor=1.0)  # low to High
        q.random_brightness(probability=0.3, min_factor=0.2, max_factor=1.0)  # dark to bright

        sample_q = q.sample(1)
        sample_q = np.array(sample_q).squeeze()

        image = sample_q
        mask = augmented_mask[::, ::, 0]

        """
        # reading the image from dataset
        ## Reading Image
        image = io.imread(img_path)  # reading image to image vaiable
        image = resize(image, (self.img_size, self.img_size), anti_aliasing=True)  # resizing input image to 128 * 128

        mask = io.imread(mask_path, as_gray=True)  # mask image of same size with all zeros
        mask = resize(mask, (self.img_size, self.img_size), anti_aliasing=True)  # resizing mask to fit the 128 * 128 image
        mask = np.expand_dims(mask, axis=-1)
        """

        # image normalization
        image = image / 255.0
        mask = mask / 255.0

        return image, mask

    def __len__(self):
        "Denotes the number of batches per epoch"
        return int(np.floor(len(self.id_names) / self.batch_size))

    def __getitem__(self, index):  # index : batch no.
        # Generate indexes of the batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_ids = [self.id_names[k] for k in indexes]

        imgs = list()
        masks = list()

        for id_name in batch_ids:
            img, mask = self.__data_generation__(id_name)
            imgs.append(img)
            masks.append(np.expand_dims(mask,-1))

        imgs = np.array(imgs)
        masks = np.array(masks)

        return imgs, masks  # return batch

train.py

import argparse
import logging
import os
import sys
from tqdm import tqdm # progress bar
import numpy as np
import matplotlib.pyplot as plt

from keras import optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import segmentation_models as sm
from segmentation_models.utils import set_trainable
from dataset import DataGenerator, iter_sequence_infinite



def train_model(model, train_gen, valid_gen, epochs, save_cp=True):
    total_batch_count = 0
    train_img_num = len(train_gen.id_names)
    train_batch_num = len(train_gen)
    train_gen_out = iter_sequence_infinite(train_gen)

    valid_batch_num = len(valid_gen)
    valid_img_num = len(valid_gen.id_names)
    valid_gen_out = iter_sequence_infinite(valid_gen)

    for epoch in range(epochs):  # iterate over the requested number of epochs
        set_trainable(model)

        epoch_loss = 0 # loss in this epoch
        epoch_iou = 0
        count = 0

        with tqdm(total=train_img_num, desc=f'Epoch {epoch + 1}/{epochs}',  position=0, leave=True, unit='img') as pbar:  # make progress bar
            for _ in range(train_batch_num):
                batch = next(train_gen_out)
                imgs = batch[0]
                true_masks = batch[1]
                loss, iou = model.train_on_batch(imgs, true_masks)  # value of loss of this batch
                epoch_loss += loss
                epoch_iou += iou

                pbar.set_postfix(**{'Batch loss': loss, 'Batch IoU': iou})  # show the current batch metrics in the progress bar

                pbar.update(imgs.shape[0])  # update progress
                count += 1
                total_batch_count += 1

        train_gen.on_epoch_end()
        print("Epoch: loss: {}, IoU: {}".format(epoch_loss / count, epoch_iou / count))

        # Do validation
        validation_model(model, valid_gen_out, valid_batch_num, valid_img_num)
        valid_gen.on_epoch_end()

        if save_cp:
            try:
                if not os.path.isdir(checkpoint_dir):
                    os.mkdir(checkpoint_dir)
                    logging.info('Created checkpoint directory')
                else:
                    pass
            except OSError:
                pass
            model.save_weights(os.path.join(checkpoint_dir , f'CP_epoch{epoch + 1}.h5'))
            logging.info(f'Checkpoint {epoch + 1} saved !')

def validation_model(model, valid_gen_out, valid_batch_num, valid_img_num):
    epoch_loss = 0  # loss in this epoch
    epoch_iou = 0
    count = 0

    with tqdm(total=valid_img_num, desc='Validation round',  position=0, leave=True, unit='img') as pbar:  # make progress bar
        for _ in range(valid_batch_num):
            batch = next(valid_gen_out)
            imgs = batch[0]
            true_masks = batch[1]
            loss, iou = model.test_on_batch(imgs, true_masks)  # value of loss of this batch
            epoch_loss += loss
            epoch_iou += iou

            pbar.set_postfix(**{'Batch loss': loss, 'Batch IoU': iou})  # show the current batch metrics in the progress bar

            pbar.update(imgs.shape[0])  # update progress
            count += 1

    print("Validation loss: {}, IoU: {}".format(epoch_loss / count, epoch_iou / count))
    pred_mask = model.predict(np.expand_dims(imgs[0],0))
    plt.subplot(131)
    plt.imshow(imgs[0])
    plt.subplot(132)
    plt.imshow(true_masks[0].squeeze(), cmap="gray")
    plt.subplot(133)
    plt.imshow(pred_mask.squeeze(), cmap="gray")
    plt.show()
    print()


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=50,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch_size', metavar='B', type=int, nargs='?', default=2,
                        help='Batch size', dest='batch_size')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('-bb', '--backbone', default='resnet50', metavar='FILE',
                        help="backbone name")
    parser.add_argument('-w', '--weight', dest='load', type=str, default=False,
                        help='Load model from a .h5 file')
    parser.add_argument('-s', '--resizing', dest='resizing', type=int, default=384,
                        help='Target image size (images are resized to this)')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args()


if __name__ == '__main__':
    img_dir = './data/train/imgs/'  # ./data/train/imgs/CVC_Original/'
    mask_dir = './data/train/masks/'  # ./data/train/masks/CVC_Ground Truth/'
    checkpoint_dir = './checkpoints'
    args = get_args()

    # train path
    train_ids = os.listdir(img_dir)
    # Validation Data Size
    n_val = int(len(train_ids) * args.val/100)  # size of validation set


    valid_ids = train_ids[:n_val]  # list of image ids used for validation of result 0 to 9
    train_ids = train_ids[n_val:]  # list of image ids used for training dataset
    # print(valid_ids, "\n\n")
    print("training_size: ", len(train_ids), "validation_size: ", len(valid_ids))

    train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)
    valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)

    print("total training batches: ", len(train_gen))
    print("total validation batches: ", len(valid_gen))
    train_steps = len(train_ids) // args.batch_size
    valid_steps = len(valid_ids) // args.batch_size

    # define model
    model = sm.Unet(args.backbone, encoder_weights='imagenet')

    optimizer = optimizers.Adam(lr=args.lr, decay=1e-4)
    model.compile(
        optimizer=optimizer,
        #        "Adam",
        loss=sm.losses.bce_dice_loss,  # sm.losses.bce_jaccard_loss, # sm.losses.binary_crossentropy,
        metrics=[sm.metrics.iou_score],
    )
    #model.summary()

    callbacks = [
        EarlyStopping(patience=6, verbose=1),
        ReduceLROnPlateau(factor=0.1, patience=3, min_lr=1e-7, verbose=1),
        ModelCheckpoint('./weights.Epoch{epoch:02d}-Loss{loss:.3f}-VIou{val_iou_score:.3f}.h5', verbose=1,
                        monitor='val_accuracy', save_best_only=True, save_weights_only=True)
                ]


    train_model(model=model, train_gen=train_gen, valid_gen=valid_gen, epochs=args.epochs)

When I run this code, training progresses fine for a number of epochs, but at around epoch 20 a GPU out-of-memory error occurs, as shown below:

(0) Resource exhausted: OOM when allocating tensor with shape[2,64,96,96] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[{{node decoder_stage2b_bn/FusedBatchNorm}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

So I think it is caused by the data generation.
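
To check that guess, the generator can be run on its own, without the model, while watching whether process memory keeps growing between passes. A minimal sketch of such a check (assuming psutil is installed; this is not part of my scripts):

    # Hypothetical isolation test (not in the original scripts): iterate the
    # DataGenerator alone and print resident memory after every full pass.
    import os
    import psutil
    from dataset import DataGenerator

    proc = psutil.Process(os.getpid())
    ids = os.listdir('./data/train/imgs/')
    gen = DataGenerator(ids, './data/train/imgs/', './data/train/masks/',
                        img_size=384, batch_size=2)

    for epoch in range(3):
        for i in range(len(gen)):
            imgs, masks = gen[i]      # same work the training loop triggers
        gen.on_epoch_end()
        print('epoch', epoch, 'RSS MB:', proc.memory_info().rss / 1e6)

If resident memory stays flat here, the generator itself is probably not the source of the growth.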

This code produces batches in the following order.

  1. In train.py, instantiate the DataGenerator class, a keras.utils.Sequence implemented in Dataset.py

    train_gen = DataGenerator(train_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)

    valid_gen = DataGenerator(valid_ids, img_dir, mask_dir, img_size=args.resizing, batch_size=args.batch_size)

  2. Then, in the 'train_model' function, wrap each data generator (the Sequence) into an infinite generator with the 'iter_sequence_infinite' function

    train_gen_out = iter_sequence_infinite(train_gen)

    valid_gen_out = iter_sequence_infinite(valid_gen)

  3. Get each batch with the built-in 'next' function (the whole flow is condensed in the sketch after this list)

    batch = next(train_gen_out)

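Putting the three steps together, the per-epoch consumption in train.py boils down to this (a condensed sketch of the code above, nothing new):

    # Condensed view of how train.py pulls batches from the Sequence.
    train_gen = DataGenerator(train_ids, img_dir, mask_dir,
                              img_size=args.resizing, batch_size=args.batch_size)
    train_gen_out = iter_sequence_infinite(train_gen)     # endless generator over the Sequence

    for epoch in range(args.epochs):
        for _ in range(len(train_gen)):                   # one full pass per epoch
            imgs, true_masks = next(train_gen_out)        # batch built on the CPU as NumPy arrays
            loss, iou = model.train_on_batch(imgs, true_masks)
        train_gen.on_epoch_end()                          # reshuffle the sample order

As far as I understand, every batch therefore exists only as NumPy arrays on the host until train_on_batch feeds it to the GPU.
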
I expected no memory problem with this setup, but the error still occurs. What is the problem and how can I solve it? Thanks.
