Как назначить новое значение переменной Pytorch без прерывания обратного распространения? - PullRequest
0 голосов
/ 17 декабря 2018

У меня есть переменная pytorch, которая используется в качестве обучаемого входа для модели.В какой-то момент мне нужно вручную переназначить все значения в этой переменной.

Как я могу сделать это, не разрывая соединения с функцией потерь?

Предположим, что текущие значения [1.2, 3.2, 43.2], и я просто хочу, чтобы они стали [1,2,3].

1 Ответ

0 голосов
/ 18 декабря 2018

Вы можете использовать атрибут data тензоров для изменения значений, поскольку изменения в data не влияют на график.
Таким образом, график все еще не поврежден, и модификации самого атрибута data не оказывают влияния на график no .(Операции и изменения на data не отслеживаются автоградами и поэтому не представлены на графике)

Поскольку вы не привели пример, этот пример основан на вашем заявлении о комментариях:
'Предположим, я хочу изменить вес слоя.'
Я использовал здесь обычные тензоры, но это работает так же для атрибутов weight.data и bias.data слоев.

Вот краткий пример:

import torch
import torch.nn.functional as F



# Test 1, random vector with CE
w1 = torch.rand(1, 3, requires_grad=True)
loss = F.cross_entropy(w1, torch.tensor([1]))
loss.backward()
print('w1.data', w1)
print('w1.grad', w1.grad)
print()

# Test 2, replacing values of w2 with w1, before CE
# to make sure that everything is exactly like in Test 1 after replacing the values
w2 = torch.zeros(1, 3, requires_grad=True)
w2.data = w1.data
loss = F.cross_entropy(w2, torch.tensor([1]))
loss.backward()
print('w2.data', w2)
print('w2.grad', w2.grad)
print()

# Test 3, replace data after computation
w3 = torch.rand(1, 3, requires_grad=True)
loss = F.cross_entropy(w3, torch.tensor([1]))
# setting values
# the graph of the previous computation is still intact as you can in the below print-outs
w3.data = w1.data
loss.backward()

# data were replaced with values from w1
print('w3.data', w3)
# gradient still shows results from computation with w3
print('w3.grad', w3.grad)

Вывод:

w1.data tensor([[ 0.9367,  0.6669,  0.3106]])
w1.grad tensor([[ 0.4351, -0.6678,  0.2326]])

w2.data tensor([[ 0.9367,  0.6669,  0.3106]])
w2.grad tensor([[ 0.4351, -0.6678,  0.2326]])

w3.data tensor([[ 0.9367,  0.6669,  0.3106]])
w3.grad tensor([[ 0.3179, -0.7114,  0.3935]])

Самая интересная часть здесь - w3.В момент вызова backward значения заменяются значениями w1.
Но градиенты рассчитываются на основе функции CE со значениями оригинала w3.Замененные значения не влияют на график.Таким образом, связь с графом не разорвана, замена имела нет влияния на граф.Я надеюсь, что это то, что вы искали!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...