PyTorch: как определить новую нейронную сеть, которая использует трансферное обучение - PullRequest
1 голос
/ 22 января 2020

Я мигрирую из фреймворков Keras / TF и ​​у меня возникли небольшие проблемы с пониманием процесса обучения переносу в PyTorch.

Я хочу использовать каркас Pytorch-Lightning и хочу переключаться между разными нейронными сетями в одном скрипте .

По этому примеру мы можем переключаться между различными нейронными сетями в их реализации:

class BERT(pl.LightningModule):
def __init__(self, model_name, task):
    self.task = task

    if model_name == 'transformer':
        self.net = Transformer()
    elif model_name == 'my_cool_version':
        self.net = MyCoolVersion()

Вопрос: как создать новую нейронную сеть что расширяет nn.Module и использует процесс обучения передачи?

Моя собственная реализация выглядит следующим образом: я использую сеть vgg16 и заменил слой классификатора только одним f c двумя выходными нейронами.

class VGGNetwork(nn.Module):
    def __init__(self):
        super(VGGNetwork, self).__init__()
        # vgg16 is the default model here, we can use bn etc...
        self.model = vgg16(pretrained=True)

        # removing the last three layers of classifier only 2 ...
        self.model.classifier = nn.Sequential(nn.Linear(512 * 7 * 7, 2))

def forward(self, x):
    return self.model.forward(x)

Это правильный способ, как это сделать?

Ответы [ 2 ]

2 голосов
/ 22 января 2020

вы можете заморозить веса и байты для слоя нейронной сети, за исключением последнего слоя.

вы можете использовать require_grad = False

for param in model_conv.parameters():
    param.requires_grad = False

Подробнее об этом можно узнать по следующей ссылке https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

0 голосов
/ 11 марта 2020

https://pytorch-lightning.readthedocs.io/en/0.7.1/transfer_learning.html

    ...

class AutoEncoder(pl.LightningModule):
    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

class CIFAR10Classifier(pl.LightingModule):
    def __init__(self):
        # init the pretrained LightningModule
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...
...