Как кэшировать модели Pytorch для использования, когда они не подключены к inte rnet? - PullRequest
1 голос
/ 20 февраля 2020

Я использую vgg19 в проблеме классификации. У меня есть доступ к исследовательскому компьютеру кампуса, но на узлах, где проводятся вычисления, нет доступа к inte rnet. Таким образом, выполнение строки кода, подобной self.net = models.vgg19(pretrained=True), завершается неудачно с ошибкой urllib.error.URLError: <urlopen error [Errno 101] Network is unreachable>

. Есть ли способ, которым я мог бы кэшировать модель на головном узле (где у меня есть inte rnet доступ) и загрузить модель из кеша вместо inte rnet на вычислительном узле?

1 Ответ

1 голос
/ 20 февраля 2020

Если вы просто где-то сохраняете веса предварительно обученных сетей, вы можете загружать их так же, как и любые другие веса сетей.

Сохранение:

import torchvision

#  I am assuming we have internet access here
model = torchvision.models.vgg16(pretrained=True)
torch.save(model.state_dict(), "Somewhere")

Загрузка:

import torchvision

def create_vgg16(dict_path=None):
    model = torchvision.models.vgg16(pretrained=False)
    if (dict_path != None):
        model.load_state_dict(torch.load(dict_path))
    return model

model = create_vgg16("Somewhere")
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...