Pytorch: AttributeError: у объекта 'function' нет атрибута 'cuda' - PullRequest
0 голосов
/ 10 января 2020
import torch
import models
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))
model = models.__dict__['resnet18']
model = torch.nn.DataParallel(model,device_ids = [0])  #PROBLEM CAUSING LINE
model.to('cuda:0')

Для запуска этого кода вам необходимо клонировать этот репозиторий: https://github.com/SoftwareGift/FeatherNets_Face-Anti-spoofing-Attack-Detection-Challenge-CVPR2019.git

Пожалуйста, запустите этот фрагмент кода в папке root клонированного каталога.

Я получаю следующую ошибку AttributeError: 'function' object has no attribute 'cuda' Я попытался использовать объект torch.device для той же функции, и это приводит к той же ошибке. Пожалуйста, попросите любые другие детали, которые требуются. PyTorch newb ie здесь python: 3.7 pytorch: 1.3.1

1 Ответ

1 голос
/ 11 января 2020

Заменить

model = torch.nn.DataParallel(model,device_ids = [0])

на

model = torch.nn.DataParallel(model(), device_ids=[0])

(обратите внимание на () после модели внутри DataParallel). Разница проста: модуль models содержит классы / функции, которые создают модели, а не экземпляры моделей. Если вы проследите за импортом, вы обнаружите, что models.__dict__['resnet18'] разрешает эту функцию. Поскольку DataParallel переносит экземпляр, а не сам класс, он несовместим. () вызывает эту функцию построения модели / конструктор класса для создания экземпляра этой модели.

Гораздо более простым примером этого будет следующий

class MyNet(nn.Model):
    def __init__(self):
        self.linear = nn.Linear(4, 4)
    def forward(self, x):
        return self.linear(x)

model = nn.DataParallel(MyNet) # this is what you're doing
model = nn.DataParallel(MyNet()) # this is what you should be doing

Ваше сообщение об ошибке жалуется на то, что function (поскольку model без () имеет тип function) не имеет атрибута cuda, который является методом nn.Model экземпляров .

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...