Как сохранить веса в файле .npy в аналогичном формате, который используется в CNN? - PullRequest
0 голосов
/ 04 мая 2019

Я использую репозиторий github, содержащий обученный CNN с весовыми параметрами, указанными в файле .npy. Модель загружает веса и использует параметры модели следующим образом: -

model = CNN_Model(batch_size)
filename = "weight_file.npy"
dtype = torch.FloatTensor    
model.load_state_dict(load_weights(model, weight_file, dtype))

А load_weights определяется как: -

def load_weights(model, filename, dtype):
    model_params = model.state_dict()
    data_dict = np.load(filename, encoding='latin1').item()
    model_params["conv1.weight"] = torch.from_numpy(data_dict["conv1"] ["weights"]).type(dtype).permute(3,2,0,1)
    model_params["conv1.bias"] = torch.from_numpy(data_dict["conv1"]["biases"]).type(dtype)
    model_params["bn1.weight"] = torch.from_numpy(data_dict["bn_conv1"]["scale"]).type(dtype)
    model_params["bn1.bias"] = torch.from_numpy(data_dict["bn_conv1"]["offset"]).type(dtype)
    return model_params

Я добавил к нему учебный модуль и пытаюсь подстроить веса в моем собственном наборе данных. После обучения я хочу сохранить новые веса в файле .npy с теми же индексами data_dict, что и в ранее загруженном файле весов, чтобы я мог использовать их снова для модели CNN.

Как мне выполнить индексацию с похожими именами перед сохранением массива data_dict, используя:

np.save("trained_weight_file.npy", data_dict)

РЕДАКТИРОВАТЬ 1: - Поэтому по рекомендации @ a-d я сделал

data_dict = model.state_dict()

Он сохранил все веса с индексом model_params. Вывод print data_dict был: -

OrderedDict([('conv1.weight', tensor([[[[....]]]])), ('conv1.bias', tensor([....])), , ('bn1.weight', tensor([....])), ('bn1.bias', tensor([....]))])

Но мне нужно сохранить индекс data_dict, чтобы я мог прочитать его с тем же алгоритмом из файла .npy. Также я попытался вернуть data_dict вместе с model_params из load_weights определения, а затем попытался использовать data_dict = model.state_dict(), но это дало мне ошибку в строке `model.load_state_dict (load_weights (model, weight_file, dtype)) ', которая является: -

Traceback (последний вызов был последним): model.load_state_dict (load_weights (model, weight_file, dtype)) state_dict = state_dict.copy () AttributeError: объект 'tuple' не имеет атрибута 'copy'

1 Ответ

0 голосов
/ 04 мая 2019

Я бы сделал что-то вроде data_dict = model.state_dict().

Вы можете прочитать официальную документацию с примером вывода state_dict() здесь .Существует репозиторий github , который является базой репозитория github, из которого вы можете получить свой код.Этот репозиторий также использует model.state_dict() для хранения значений.

...