Можно ли заморозить только определенные веса вложения в слое в pytorch? - PullRequest
0 голосов
/ 28 февраля 2019

При использовании встраивания GloVe в задачи NLP некоторые слова из набора данных могут отсутствовать в GloVe.Поэтому мы создаем случайные веса для этих неизвестных слов.

Можно ли замораживать веса, полученные из GloVe, и обучать только новые экземпляры весов?

Я только знаю, что мы можем установить: model.embedding.weight.requires_grad = False

Но это делает новые слова неисследуемыми ..

Или существуют более эффективные способы извлечения семантики слов ..

1 Ответ

0 голосов
/ 02 марта 2019

1.Разделите вложения на два отдельных объекта

Один из подходов состоит в том, чтобы использовать два отдельных вложения , одно для подготовки , другое для обучения .

* 1008.* GloVe следует заморозить, в то время как тот, для которого нет предтренированного представления, будет взят из обучаемого слоя.

Если вы отформатируете свои данные, что для предтренированных представлений токенов они находятся в меньшем диапазоне, чем токеныбез представления GloVe это может быть сделано.Допустим, ваши предварительно обученные индексы находятся в диапазоне [0, 300], а показатели без представления - [301, 500].Я хотел бы пойти с чем-то вроде этого:

import numpy as np
import torch


class YourNetwork(torch.nn.Module):
    def __init__(self, glove_embeddings: np.array, how_many_tokens_not_present: int):
        self.pretrained_embedding = torch.nn.Embedding.from_pretrained(glove_embeddings)
        self.trainable_embedding = torch.nn.Embedding(
            how_many_tokens_not_present, glove_embeddings.shape[1]
        )
        # Rest of your network setup

    def forward(self, batch):
        # Which tokens in batch do not have representation, should have indices BIGGER
        # than the pretrained ones, adjust your data creating function accordingly
        mask = batch > self.pretrained_embedding.shape[0]

        # You may want to optimize it, you could probably get away without copy, though
        # I'm not currently sure how
        pretrained_batch = batch.copy()
        pretrained_batch[mask] = 0

        embedded_batch = self.pretrained_embedding[pretrained_batch]

        # Every token without representation has to be brought into appropriate range
        batch -= self.pretrained_embedding.shape[0]
        # Zero out the ones which already have pretrained embedding
        batch[~mask] = 0
        non_pretrained_embedded_batch = self.trainable_embedding(batch)

        # And finally change appropriate tokens from placeholder embedding created by
        # pretrained into trainable embeddings.
        embedded_batch[mask] = non_pretrained_embedded_batch[mask]

        # Rest of your code
        ...

Допустим, ваши предварительно обученные индексы находятся в диапазоне [0, 300], а без представления - [301, 500].

2.Нулевые градиенты для указанных токенов.

Этот немного хитрый, но я думаю, что он довольно лаконичен и прост в реализации.Итак, если вы получаете индексы токенов, которые не имеют представления GloVe, вы можете явно обнулить их градиент после backprop, чтобы эти строки не обновлялись.

import torch

embedding = torch.nn.Embedding(10, 3)
X = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

values = embedding(X)
loss = values.mean()

# Use whatever loss you want
loss.backward()

# Let's say those indices in your embedding are pretrained (have GloVe representation)
indices = torch.LongTensor([2, 4, 5])

print("Before zeroing out gradient")
print(embedding.weight.grad)

print("After zeroing out gradient")
embedding.weight.grad[indices] = 0
print(embedding.weight.grad)

И вывод второго подхода:

Before zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0833, 0.0833, 0.0833],
        [0.0417, 0.0417, 0.0417],
        [0.0833, 0.0833, 0.0833],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417]])
After zeroing out gradient
tensor([[0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0417, 0.0417, 0.0417]])
...