Я использую пример Keras OCR: https://github.com/keras-team/keras/blob/master/examples/image_ocr.py для распознавания рукописного текста онлайн, но столкнулся с проблемой выделения памяти после обучения модели при использовании функции theano для получения вывода softmax.Форма x_train: (1200,1586,4).Я передаю 1200 последовательностей штрихов в пакетах по 12. Вот фрагмент кода:
inputs = Input(name='the_input', shape=x_train.shape[1:], dtype='float32')
rnn_encoded = Bidirectional(LSTM(64, return_sequences=True,kernel_initializer=init,bias_initializer=bias),name='bidirectional_1',merge_mode='concat',trainable=trainable)(inputs)
birnn_encoded = Bidirectional(LSTM(32, return_sequences=True,kernel_initializer=init,bias_initializer=bias),name='bidirectional_2',merge_mode='concat',trainable=trainable)(rnn_encoded)
trirnn_encoded=Bidirectional(LSTM(16,return_sequences=True,kernel_initializer=init,bias_initializer=bias),name='bidirectional_3',merge_mode='concat',trainable=trainable)(birnn_encoded)
output = TimeDistributed(Dense(28, name='dense',kernel_initializer=init,bias_initializer=bias))(trirnn_encoded)
y_pred = Activation('softmax', name='softmax')(output)
model=Model(inputs=inputs,outputs=y_pred)
labels = Input(name='the_labels', shape=[max_len], dtype='int32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([y_pred, labels, input_length, label_length])
model = Model(inputs=[inputs, labels, input_length, label_length], outputs=loss_out)
opt=RMSprop(lr=0.001,clipnorm=1.)
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=opt)
gc.collect()
my_generator = generator(x_train,y_train,batch_size)
hist= model.fit_generator(my_generator,epochs=80,steps_per_epoch=100,shuffle=True,use_multiprocessing=False,workers=1)
model.save(mfile)
test_func = K.function([inputs], [y_pred])
Ошибка выделения памяти происходит в последней строке.Я использую 32 ГБ ОЗУ с 8vCPU на AWS.Ошибка не возникает, когда я запускаю код для меньшего количества эпох (около 30-40), но в основном, когда я бегу для большого количества эпох, например, 80-100.Я также приложил скриншот ошибки. 1 Пожалуйста, предложите мне, если есть какое-либо решение проблемы, кроме уменьшения размера набора данных или количества эпох.