Заморозить только некоторые строки объекта torch.nn.Embedding - PullRequest
0 голосов
/ 10 марта 2020

Я довольно новичок в ie для Pytorch, и я пытаюсь реализовать своего рода процедуру "посттренинга" встраивания.

У меня есть словарь с набором предметов, и я выучил один вектор для каждого из них. Я сохраняю изученные векторы в объекте nn.Embedding. Теперь я хотел бы добавить новый элемент в словарь, не обновляя уже выученные векторы. Внедрение для нового элемента будет инициализировано случайным образом, а затем обучено, сохраняя все остальные вложения замороженными.

Я знаю, что для предотвращения обучения nn.Embedding мне нужно установить значение False его requires_grad переменная. Я также нашел этот другой вопрос , который похож на мой. Наилучший ответ предполагает, что

  1. либо хранит замороженные векторы и вектор для обучения в различных nn.Вмещающие объекты, первый с requires_grad = False, а второй с requires_grad = True

  2. или сохраните замороженные векторы и новый в одном и том же объекте nn.Embedding, вычисляя градиент по всем векторам, но убывая его только по измерениям вектора нового элемента. Это, однако, приводит к значительному ухудшению характеристик (чего, конечно, я хочу избежать).

Моя проблема заключается в том, что мне действительно нужно сохранить вектор для нового элемента в тот же объект nn.Embedding, что и замороженные векторы старых элементов. Причина этого ограничения заключается в следующем: при построении моей функции потерь с вложениями элементов (старых и новых) мне нужно искать векторы на основе идентификаторов элементов, а по причинам производительности мне нужно использовать Python нарезка. Другими словами, учитывая список идентификаторов элементов item_ids, мне нужно сделать что-то вроде vecs = embedding[item_ids]. Если бы я использовал два разных элемента nn.Embedding для старых элементов и нового и нового, мне нужно было бы использовать явный for-l oop с условиями if-else, что привело бы к ухудшению производительности.

Есть ли способ, которым я могу это сделать?

1 Ответ

1 голос
/ 10 марта 2020

Если вы посмотрите на реализацию nn.Embedding , она использует функциональную форму встраивания в прямой проход. Поэтому я думаю, что вы могли бы реализовать пользовательский модуль, который делает что-то вроде этого:

import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F

weights_freeze = torch.rand(10, 5)  # Don't make parameter
weights_train = Parameter(torch.rand(2, 5))
weights = torch.cat((weights_freeze, weights_train), 0)

idx = torch.tensor([[11, 1, 3]])
lookup = F.embedding(idx, weights)

# Desired result
print(lookup)
lookup.sum().backward()
# 11 corresponds to idx 1 in weights_train so this has grad
print(weights_train.grad)
...