Я использую репозиторий github, содержащий обученный CNN с весовыми параметрами, указанными в файле .npy
. Модель загружает веса и использует параметры модели следующим образом: -
model = CNN_Model(batch_size)
filename = "weight_file.npy"
dtype = torch.FloatTensor
model.load_state_dict(load_weights(model, weight_file, dtype))
А load_weights
определяется как: -
def load_weights(model, filename, dtype):
model_params = model.state_dict()
data_dict = np.load(filename, encoding='latin1').item()
model_params["conv1.weight"] = torch.from_numpy(data_dict["conv1"] ["weights"]).type(dtype).permute(3,2,0,1)
model_params["conv1.bias"] = torch.from_numpy(data_dict["conv1"]["biases"]).type(dtype)
model_params["bn1.weight"] = torch.from_numpy(data_dict["bn_conv1"]["scale"]).type(dtype)
model_params["bn1.bias"] = torch.from_numpy(data_dict["bn_conv1"]["offset"]).type(dtype)
return model_params
Я добавил к нему учебный модуль и пытаюсь подстроить веса в моем собственном наборе данных. После обучения я хочу сохранить новые веса в файле .npy
с теми же индексами data_dict
, что и в ранее загруженном файле весов, чтобы я мог использовать их снова для модели CNN.
Как мне выполнить индексацию с похожими именами перед сохранением массива data_dict, используя:
np.save("trained_weight_file.npy", data_dict)
РЕДАКТИРОВАТЬ 1: -
Поэтому по рекомендации @ a-d я сделал
data_dict = model.state_dict()
Он сохранил все веса с индексом model_params
. Вывод print data_dict
был: -
OrderedDict([('conv1.weight', tensor([[[[....]]]])), ('conv1.bias', tensor([....])), , ('bn1.weight', tensor([....])), ('bn1.bias', tensor([....]))])
Но мне нужно сохранить индекс data_dict
, чтобы я мог прочитать его с тем же алгоритмом из файла .npy
. Также я попытался вернуть data_dict
вместе с model_params
из load_weights
определения, а затем попытался использовать data_dict = model.state_dict()
, но это дало мне ошибку в строке `model.load_state_dict (load_weights (model, weight_file, dtype)) ', которая является: -
Traceback (последний вызов был последним):
model.load_state_dict (load_weights (model, weight_file, dtype))
state_dict = state_dict.copy ()
AttributeError: объект 'tuple' не имеет атрибута 'copy'