Я сделал модель gan для работы с изображениями, но у меня возникла ошибка OOM в генеративной части. Как справиться с такой ошибкой? - PullRequest
0 голосов
/ 04 апреля 2019

я работаю над изображениями.Я сделал модель gan с keras для работы с этими изображениями, но у меня возникла ошибка OOM в генеративной части.

Я пытался уменьшить пакет (до 1), я также уменьшил размер изображений, я также уменьшил количество слоев в моей модели, но я всегда получал эту ошибку.Я думал, что функция предиката игнорировала размер пакета, поэтому я попытался поместить только 10 изображений в мои файлы (размер одного изображения = (400 500,3))

Когда я смотрю на обобщение генеративной модели, ясмотрите параметры 31k, это ничего.

Я также запустил свой код на экземпляре AWS с 8 GPU и 480 гигабайт оперативной памяти (смеется). Я тоже получил ошибку ...

Выесть идеи почему?

Вот часть кода генератора и часть поезда.

def res_block_gen(model, kernal_size, filters, strides):    
gen = model

model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
model = BatchNormalization(momentum = 0.5)(model)
# Relu avec parametres
model = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(model)
model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
model = BatchNormalization(momentum = 0.5)(model)

model = add([gen, model])

return model



def up_sampling_block(model, kernal_size, filters, strides):

model = Conv2DTranspose(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model) #a la place de Upsampling
#model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
model = UpSampling2D(size = 2)(model)
model = LeakyReLU(alpha = 0.2)(model)

return model


class Generator(object):
def __init__(self, noise_shape):

    self.noise_shape = noise_shape

def generator(self):

    gen_input = Input(shape = self.noise_shape)

    model = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(gen_input)
    model = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(model)

    gen_model = model

    # 16 blocks residuals avec skip connection
    #for index in range(2):
        #model = res_block_gen(model, 3, 64, 1)

    #model = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(model)
    #model = BatchNormalization(momentum = 0.5)(model)
    #model = add([gen_model, model])

    # 2 blocks upsampling
    #for index in range(2):
    #    model = up_sampling_block(model, 3, 256, 1)

    model = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same")(model)
    model = Activation('relu')(model) ## tanh

    generator_model = Model(inputs = gen_input, outputs = model)
    return generator_model

На этой линии в поезде есть ошибка:

generated_images_sr = generator.predict(image_batch_lr)

Спасибо за советы!

...