Использование Pypi предварительно обученных моделей против PyTorch - PullRequest
1 голос
/ 10 июля 2019

У меня есть две разные установки - одна занимает ок. 10 минут, чтобы запустить другой все еще идет через час:

10 м:

import pretrainedmodels 

def resnext50_32x4d(pretrained=False):
    pretrained = 'imagenet' if pretrained else None
    model = pretrainedmodels.se_resnext50_32x4d(pretrained=pretrained)
    return nn.Sequential(*list(model.children()))

learn = cnn_learner(data, resnext50_32x4d, pretrained=True, cut=-2, split_on=lambda m: (m[0][3], m[1]),metrics=[accuracy, error_rate])

Не отделка:

import torchvision.models as models

def get_model(pretrained=True, model_name = 'resnext50_32x4d', **kwargs ):
    arch = models.resnext50_32x4d(pretrained, **kwargs )
    return arch

learn = Learner(data, get_model(), metrics=[accuracy, error_rate])

Это все скопировано и взломано из кода других людей, поэтому есть части, которые я не понимаю. Но самое непонятное, почему один будет намного быстрее другого. Я хотел бы использовать второй вариант, потому что его легче понять, и я могу просто поменять предварительно обученную модель, чтобы протестировать разные.

1 Ответ

1 голос
/ 10 июля 2019

Обе архитектуры разные. Я предполагаю, что вы используете pretrained-models.pytorch .

Обратите внимание, что вы используете SE -ResNeXt в первом примере и ResNeXt во втором (стандартный от torchvision).

Первая версия использует более быструю блочную архитектуру (Squeeze and Excitation), исследовательская работа описывает ее здесь .

Я не уверен в точных различиях между архитектурами и реализациями, за исключением того, что используются разные строительные блоки, но Вы можете print обе модели и проверить различия.

Наконец здесь - хорошая статья, обобщающая, что такое Squeeze And Excitation . По сути, вы делаете GlobalAveragePooling на всех каналах (в дальнейшем это будет torch.nn.AdaptiveAvgPoo2d(1) и flatten впоследствии), проталкиваете его через два линейных слоя (с активацией ReLU), завершенных к sigmoid, чтобы получить веса для каждого канала. Наконец вы умножаете каналы на эти.

Кроме того, вы делаете что-то странное с модулями, превращающими их в torch.nn.Sequential. Может быть какая-то логика в forward вызове предварительно обученной сети, которую вы удаляете путем копирования модулей, это может также сыграть свою роль.

...