lasagne.layers.set_all_param_values ​​возвращает значение None - PullRequest
0 голосов
/ 24 октября 2019

Я работаю в глубокой нейронной сети, в которой я тренируюсь и хочу сохранить модель на моей машине в виде рассола, и она была сохранена с использованием приведенного ниже кода

import lasagne
import theano
import theano.tensor as T
import pickle

if __name__ == "__main__":
    x_train,y_train,x_test,y_test = load_dataset()

    input_var = T.tensor4('inputs')
    target_var = T.ivector('targets')

    network = build_nn(input_var)

    prediction = lasagne.layers.get_output(network)

    loss = lasagne.objectives.categorical_crossentropy(prediction,target_var)

    loss = loss.mean()

    params = lasagne.layers.get_all_params(network, trainable=True)

    updates = lasagne.updates.nesterov_momentum(loss,params,learning_rate=0.01 , momentum=0.9)

    train_fn = theano.function([input_var,target_var],loss , updates=updates)

    num_training_steps = 2

    for steps in range(num_training_steps):
        train_err = train_fn(x_train,y_train)
        print("current step is " + str(steps))
    pickle_out = open('test_pickle','wb')
    netInfo = {'network': network, 'params': lasagne.layers.get_all_param_values(network)}
    pickle.dump(netInfo,pickle_out)

. Проблема заключается в загрузке сохраненного файла. следующим образом и установите параметры сети

pickle_in = open('test_pickle','rb')
    net = pickle.load(pickle_in)
    all_params = net['params']
    print(net['network'])
    print(all_params)
    network = lasagne.layers.set_all_param_values(net['network'], all_params)
    print(network)

значение сети является значением None, но я проверил как сеть, так и параметры, которые они содержат значения

...