Попытка лучше понять память ООМ в керасе - PullRequest
0 голосов
/ 16 апреля 2020

Я использую очень большую сверточную нейронную сеть с использованием Keras. Когда я тренирую сеть с помощью команды fit, у меня обычно не возникает проблем. Когда я запускаю команду train_on_batch, у меня обычно возникают проблемы с нехваткой памяти. Странная вещь, однако, заключается в том, что при попытке передать данные в сеть не хватает памяти сразу, чего я и ожидал при ошибке OOM. Вместо этого у меня заканчивается память в 4-й или 5-й эпохе обучения, когда я ожидаю, что вся память уже будет выделена.

Одна вещь, которую я делаю, может быть, причина в том, что у меня огромная сеть и большие обучающие наборы данных для сопровождения этой крупной сети, я стараюсь пропустить только небольшое количество данных за раз , сбрасывая новые входы на выход обучения каждый раз. Может ли это быть возможной причиной нехватки памяти через короткое время?

Код добавлен:

def read_case(name,file_,parameter):
    """
    Reads the specific parameter data from the given file for the specific case.
    """
    case = file_[name]
    keylist = list(case.keys())
    #print("Keylist {}".format(keylist))
    if keylist:
        shape = list(case[keylist[0]][parameter].shape)
        shape.insert(0,len(keylist))
        mymatrix = numpy.zeros(shape)
        for i,key in enumerate(keylist):
            mymatrix[i] = case[key][parameter][:]

        return mymatrix

Затем основной l oop, где я запускаю нейронную сеть

    ash_ketchum = NN_Trainer()
    ash_ketchum.training_list = training_list
    ash_ketchum.validation_list = validation_list
    ash_ketchum.library_list.append(lib_one)
    #ash_ketchum.library_list.append(lib_two)

    keylist = list(lib_one.keys())
    testkey = keylist[0]
    input1 =  read_case(testkey,lib_one,'pin_power')
    output1 = read_case(testkey,lib_one,'pin_steamrate')

    model = assemble(input1,output1)

    model.compile(optimizer='adam',loss='mse',metrics=['mae'])

    for i in range(50):
        print(i)
        ash_ketchum.train(model)
        ash_ketchum.validate(model,ash_ketchum.training_list)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...