Получите предварительно обученную модель ImageNet (resnet152
имеет лучшую точность):
from torchvision import models
# https://pytorch.org/docs/stable/torchvision/models.html
model = models.resnet152(pretrained=True)
Распечатайте ее структуру, чтобы мы могли сравнить ее с конечным состоянием:
print(model)
Удалите последний модуль (как правило, один полностью связанный слой) из модели:
classifier_name, old_classifier = model._modules.popitem()
Заморозьте параметры детектора элементов модели, чтобы они не были отрегулированы при обратном распространении:
for param in model.parameters():
param.requires_grad = False
Создайте новый классификатор:
classifier_input_size = old_classifier.in_features
classifier = nn.Sequential(OrderedDict([
('fc1', nn.Linear(classifier_input_size, hidden_layer_size)),
('activation', nn.SELU()),
('dropout', nn.Dropout(p=0.5)),
('fc2', nn.Linear(hidden_layer_size, output_layer_size)),
('output', nn.LogSoftmax(dim=1))
]))
Имя модуля для нашего классификатора должно совпадать с именем, которое было удалено.Добавьте наш новый классификатор в конец детектора функций:
model.add_module(classifier_name, classifier)
Наконец, распечатайте структуру новой сети:
print(model)