Как обрезать веса меньше порога в PyTorch? - PullRequest
4 голосов
/ 06 мая 2020

Как отсечь веса модели CNN (сверточной нейронной сети), которая меньше порогового значения (давайте рассмотрим отсечение всех весов, которые <= 1). </em>

Как мы можем добиться этого для файла веса, сохраненного в формате .pth в pytorch?

Ответы [ 2 ]

5 голосов
/ 06 мая 2020

PyTorch, поскольку 1.4.0 обеспечивает обрезку модели из коробки, см. Официальное руководство .

Поскольку в настоящее время в PyTorch нет метода threshold для обрезки, вам необходимо реализовать это сделать самостоятельно, хотя это довольно просто, если вы получите общее представление.

Метод обрезки порога

Ниже приведен код, выполняющий обрезку:

from torch.nn.utils import prune


class ThresholdPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, tensor, default_mask):
        return torch.abs(tensor) > self.threshold

Пояснение:

  • PRUNING_TYPE может быть одним из global, structured, unstructured. global действует для всего модуля (например, удаляет 20% веса с наименьшим значением), structured действует для целых каналов / модулей. Нам нужно unstructured, так как мы хотели бы изменить каждое соединение в конкретном тензоре параметров c (скажем, weight или bias)
  • __init__ - передайте здесь все, что вы хотите или нужно сделать работа, нормальные вещи
  • compute_mask - маска, которая будет использоваться для удаления указанного c тензора. В нашем случае все параметры ниже порога должны быть равны нулю. Я сделал это с абсолютной ценностью, поскольку это имеет больше смысла. default_mask здесь не требуется, но остается как именованный параметр, так как API требует атм.

Более того, наследование от prune.BasePruningMethod определяет методы для применения маски к каждому параметру, что делает сокращение постоянным et c. См. документацию базового класса для получения дополнительной информации.

Пример модуля

Ничего особенного, вы можете разместить здесь все, что захотите:

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Linear(50, 30)
        self.second = torch.nn.Linear(30, 10)

    def forward(self, inputs):
        return self.second(torch.relu(self.first(inputs)))


module = MyModule()

Вы также можете загрузить ваш модуль через module = torch.load('checkpoint.pth'), если вам нужно, это не имеет значения.

Параметры модуля Prune

Мы должны определить, какой параметр нашего модуля (и будет ли он weight или bias) следует обрезать, например:

parameters_to_prune = ((module.first, "weight"), (module.second, "weight"))

Теперь мы можем применить global ly нашу unstructured обрезку ко всем определенным parameters (threshold передается как kwarg к __init__ из ThresholdPruning):

prune.global_unstructured(
    parameters_to_prune, pruning_method=ThresholdPruning, threshold=0.1
)

Результаты

weight атрибут

Чтобы увидеть эффект, проверьте веса подмодуля first просто с помощью :

print(module.first.weight)

Это вес с примененной нашей техникой обрезки, но обратите внимание, что это больше не torch.nn.Parameter! Теперь это просто атрибут нашей модели, следовательно, он не будет участвовать в обучении или оценке в настоящее время.

weight_mask

Мы можем проверить созданную маску через module.first.weight_mask, чтобы увидеть, что все сделано правильно (в этом случае он будет двоичным).

weight_orig

Применение обрезки создает новый torch.nn.Parameter с исходными весами с именем name + _orig, в данном случае weight_orig, посмотрим:

print(module.first.weight_orig)

Этот параметр будет использоваться во время обучения и оценки в настоящее время! . После применения pruning с помощью методов, описанных выше, добавляется forward_pre_hooks, который "переключает" исходный weight на weight_orig.

Благодаря такому подходу вы можете определить и применить обрезку к любой части training или inference без «разрушения» исходных весов.

Постоянное применение обрезки

Если вы используете sh, чтобы применить сокращение на постоянной основе, просто введите:

prune.remove(module.first, "weight")

И теперь наш module.first.weight снова является параметром с соответствующими сокращенными записями, module.first.weight_mask удаляется, как и module.first.weight_orig. Это то, чем вы, вероятно, станете после .

Вы можете перебрать children, чтобы сделать его постоянным:

for child in module.children():
    prune.remove(child, "weight")

Вы можете определить parameters_to_prune, используя тот же logi c:

parameters_to_prune = [(child, "weight") for child in module.children()]

Или, если вы хотите обрезать только convolution слоев (или что-то еще):

parameters_to_prune = [
    (child, "weight")
    for child in module.children()
    if isinstance(child, torch.nn.Conv2d)
]

Преимущества

  • использует "способ обрезки PyTorch", чтобы было легче общаться ваше намерение по отношению к другим программистам
  • определение сокращения для каждого тензора, единственная ответственность вместо того, чтобы проходить через все
  • ограничиваться предопределенными способами
  • сокращение не является постоянным , следовательно, вы можете оправиться от него при необходимости. Модуль может быть сохранен с масками обрезки и исходными весами, поэтому он оставляет вам место для исправления возможной ошибки (например, threshold было слишком большим, и теперь все ваши веса равны нулю, что делает результаты бессмысленными)
  • работает с исходным весом во время вызовов forward, если вы не хотите окончательно перейти на сокращенную версию (простой вызов remove)

Недостатки

  • Удаление IMO API мог бы быть более понятным
  • Вы можете сделать это короче (как предусмотрено Shai)
  • может сбивать с толку тех, кто не знает, что такая вещь "определена" от PyTorch (все еще есть учебники и документы, поэтому я не думаю, что это серьезная проблема)
1 голос
/ 06 мая 2020

Вы можете работать непосредственно со значениями, сохраненными в state_dict:

sd = torch.load('saved_weights.pth')  # load the state dicd
for k in sd.keys():
  if not 'weight' in k:
    continue  # skip biases and other saved parameters
  w = sd[k]
  sd[k] = w * (w > thr)  # set to zero weights smaller than thr 
torch.save(sd, 'pruned_weights.pth')
...