я работаю над изображениями.Я сделал модель gan с keras для работы с этими изображениями, но у меня возникла ошибка OOM в генеративной части.
Я пытался уменьшить пакет (до 1), я также уменьшил размер изображений, я также уменьшил количество слоев в моей модели, но я всегда получал эту ошибку.Я думал, что функция предиката игнорировала размер пакета, поэтому я попытался поместить только 10 изображений в мои файлы (размер одного изображения = (400 500,3))
Когда я смотрю на обобщение генеративной модели, ясмотрите параметры 31k, это ничего.
Я также запустил свой код на экземпляре AWS с 8 GPU и 480 гигабайт оперативной памяти (смеется). Я тоже получил ошибку ...
Выесть идеи почему?
Вот часть кода генератора и часть поезда.
def res_block_gen(model, kernal_size, filters, strides):
gen = model
model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
model = BatchNormalization(momentum = 0.5)(model)
# Relu avec parametres
model = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(model)
model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
model = BatchNormalization(momentum = 0.5)(model)
model = add([gen, model])
return model
def up_sampling_block(model, kernal_size, filters, strides):
model = Conv2DTranspose(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model) #a la place de Upsampling
#model = Conv2D(filters = filters, kernel_size = kernal_size, strides = strides, padding = "same")(model)
model = UpSampling2D(size = 2)(model)
model = LeakyReLU(alpha = 0.2)(model)
return model
class Generator(object):
def __init__(self, noise_shape):
self.noise_shape = noise_shape
def generator(self):
gen_input = Input(shape = self.noise_shape)
model = Conv2D(filters = 64, kernel_size = 9, strides = 1, padding = "same")(gen_input)
model = PReLU(alpha_initializer='zeros', alpha_regularizer=None, alpha_constraint=None, shared_axes=[1,2])(model)
gen_model = model
# 16 blocks residuals avec skip connection
#for index in range(2):
#model = res_block_gen(model, 3, 64, 1)
#model = Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same")(model)
#model = BatchNormalization(momentum = 0.5)(model)
#model = add([gen_model, model])
# 2 blocks upsampling
#for index in range(2):
# model = up_sampling_block(model, 3, 256, 1)
model = Conv2D(filters = 3, kernel_size = 9, strides = 1, padding = "same")(model)
model = Activation('relu')(model) ## tanh
generator_model = Model(inputs = gen_input, outputs = model)
return generator_model
На этой линии в поезде есть ошибка:
generated_images_sr = generator.predict(image_batch_lr)
Спасибо за советы!