PyTorch, поскольку 1.4.0
обеспечивает обрезку модели из коробки, см. Официальное руководство .
Поскольку в настоящее время в PyTorch нет метода threshold
для обрезки, вам необходимо реализовать это сделать самостоятельно, хотя это довольно просто, если вы получите общее представление.
Метод обрезки порога
Ниже приведен код, выполняющий обрезку:
from torch.nn.utils import prune
class ThresholdPruning(prune.BasePruningMethod):
PRUNING_TYPE = "unstructured"
def __init__(self, threshold):
self.threshold = threshold
def compute_mask(self, tensor, default_mask):
return torch.abs(tensor) > self.threshold
Пояснение:
PRUNING_TYPE
может быть одним из global
, structured
, unstructured
. global
действует для всего модуля (например, удаляет 20%
веса с наименьшим значением), structured
действует для целых каналов / модулей. Нам нужно unstructured
, так как мы хотели бы изменить каждое соединение в конкретном тензоре параметров c (скажем, weight
или bias
) __init__
- передайте здесь все, что вы хотите или нужно сделать работа, нормальные вещи compute_mask
- маска, которая будет использоваться для удаления указанного c тензора. В нашем случае все параметры ниже порога должны быть равны нулю. Я сделал это с абсолютной ценностью, поскольку это имеет больше смысла. default_mask
здесь не требуется, но остается как именованный параметр, так как API требует атм.
Более того, наследование от prune.BasePruningMethod
определяет методы для применения маски к каждому параметру, что делает сокращение постоянным et c. См. документацию базового класса для получения дополнительной информации.
Пример модуля
Ничего особенного, вы можете разместить здесь все, что захотите:
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.first = torch.nn.Linear(50, 30)
self.second = torch.nn.Linear(30, 10)
def forward(self, inputs):
return self.second(torch.relu(self.first(inputs)))
module = MyModule()
Вы также можете загрузить ваш модуль через module = torch.load('checkpoint.pth')
, если вам нужно, это не имеет значения.
Параметры модуля Prune
Мы должны определить, какой параметр нашего модуля (и будет ли он weight
или bias
) следует обрезать, например:
parameters_to_prune = ((module.first, "weight"), (module.second, "weight"))
Теперь мы можем применить global
ly нашу unstructured
обрезку ко всем определенным parameters
(threshold
передается как kwarg
к __init__
из ThresholdPruning
):
prune.global_unstructured(
parameters_to_prune, pruning_method=ThresholdPruning, threshold=0.1
)
Результаты
weight
атрибут
Чтобы увидеть эффект, проверьте веса подмодуля first
просто с помощью :
print(module.first.weight)
Это вес с примененной нашей техникой обрезки, но обратите внимание, что это больше не torch.nn.Parameter
! Теперь это просто атрибут нашей модели, следовательно, он не будет участвовать в обучении или оценке в настоящее время.
weight_mask
Мы можем проверить созданную маску через module.first.weight_mask
, чтобы увидеть, что все сделано правильно (в этом случае он будет двоичным).
weight_orig
Применение обрезки создает новый torch.nn.Parameter
с исходными весами с именем name + _orig
, в данном случае weight_orig
, посмотрим:
print(module.first.weight_orig)
Этот параметр будет использоваться во время обучения и оценки в настоящее время! . После применения pruning
с помощью методов, описанных выше, добавляется forward_pre_hooks
, который "переключает" исходный weight
на weight_orig
.
Благодаря такому подходу вы можете определить и применить обрезку к любой части training
или inference
без «разрушения» исходных весов.
Постоянное применение обрезки
Если вы используете sh, чтобы применить сокращение на постоянной основе, просто введите:
prune.remove(module.first, "weight")
И теперь наш module.first.weight
снова является параметром с соответствующими сокращенными записями, module.first.weight_mask
удаляется, как и module.first.weight_orig
. Это то, чем вы, вероятно, станете после .
Вы можете перебрать children
, чтобы сделать его постоянным:
for child in module.children():
prune.remove(child, "weight")
Вы можете определить parameters_to_prune
, используя тот же logi c:
parameters_to_prune = [(child, "weight") for child in module.children()]
Или, если вы хотите обрезать только convolution
слоев (или что-то еще):
parameters_to_prune = [
(child, "weight")
for child in module.children()
if isinstance(child, torch.nn.Conv2d)
]
Преимущества
- использует "способ обрезки PyTorch", чтобы было легче общаться ваше намерение по отношению к другим программистам
- определение сокращения для каждого тензора, единственная ответственность вместо того, чтобы проходить через все
- ограничиваться предопределенными способами
- сокращение не является постоянным , следовательно, вы можете оправиться от него при необходимости. Модуль может быть сохранен с масками обрезки и исходными весами, поэтому он оставляет вам место для исправления возможной ошибки (например,
threshold
было слишком большим, и теперь все ваши веса равны нулю, что делает результаты бессмысленными) - работает с исходным весом во время вызовов
forward
, если вы не хотите окончательно перейти на сокращенную версию (простой вызов remove
)
Недостатки
- Удаление IMO API мог бы быть более понятным
- Вы можете сделать это короче (как предусмотрено
Shai
) - может сбивать с толку тех, кто не знает, что такая вещь "определена" от PyTorch (все еще есть учебники и документы, поэтому я не думаю, что это серьезная проблема)