копирование весов с помощью pytorch работало только с двоеточием - PullRequest
0 голосов
/ 09 марта 2020
    resnet18 = models.resnet18(pretrained=True)
    reg_resnet = resnet_model()
    for each_param in resnet18.state_dict().keys():
        reg_resnet.state_dict()[each_param][:] = resnet18.state_dict()[each_param]

Приведенный выше код работал для меня при обновлении весов моей модели reg_re snet (которая инициализируется с нуля) с предварительно подготовленными весами re snet 18. Но он не работал, когда у меня был ниже:

resnet18 = models.resnet18(pretrained=True)
reg_resnet = resnet_model()
for each_param in resnet18.state_dict().keys():
    reg_resnet.state_dict()[each_param] = resnet18.state_dict()[each_param]

Почему при добавлении '[:]' это работает? Что это делает в pytorch?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...