Я новичок из Pytorch ie, и у меня возникли проблемы с torch.cat.
В моем сценарии у меня есть обученная модель, которая в основном представляет собой матрицу векторов, и я хочу сделать следующее:
- добавить новый вектор в обученную модель;
- объединить его с матрицей ранее существующих векторов;
- выполнить несколько итераций, обучая новый только вектор, сохраняя старые, уже обученные векторы «замороженными»
Вот минимальный пример моих вычислений и моей проблемы:
import torch
from torch import optim
from torch.nn.parameter import Parameter
# the old, already-trained vectors (2 vecs with 5 elements each)
old_vecs = torch.rand(2, 5)
# the new vector to train (1 vec with 5 elements).
# This has to be a Parameter, so it can be passed to an Optimizer to have it trained
new_vec = Parameter(torch.rand(1, 5))
# concatenate old_vecs and new_vec
all_vecs = torch.cat((old_vecs, new_vec), 0)
# create the Optimizer with the new_vec to train
optimizer = optim.SGD([new_vec], lr=1e-1)
for epoch in range(5):
loss = all_vecs.sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
Путем печати значений векторы до и после обучения, я получил следующие показания:
new_vec before training:
Parameter containing: tensor([[0.4151, 0.6478, 0.5142, 0.2373, 0.5643]], requires_grad=True)
new_vec after training:
Parameter containing: tensor([[-0.0849, 0.1478, 0.0142, -0.2627, 0.0643]], requires_grad=True)
all_vecs[2] after training:
tensor([0.4151, 0.6478, 0.5142, 0.2373, 0.5643], grad_fn=<SelectBackward>)
Ясно, что new_ve c обновляется, но all_vecs [2] нет. По-видимому, при запуске torch.cat генерируется новый тензор all_vecs с независимым содержимым . Поэтому vec_new и all_vecs [2] имеют независимые значения, и поскольку мой единственный параметр - new_ve c, all_vecs [2] обрабатывается так же, как старые замороженные векторы.
Это проблема, поскольку моя потеря вычисляется из all_vecs, а не из new_ve c. Поскольку all_vecs нельзя изменить, потеря и ее градиенты никогда не уменьшатся (и это, конечно, повлияет на способ обучения new_ve c).
Я мог бы легко решить проблему, обновив значения вручную для all_vecs [2] на каждой итерации:
for epoch in range(5):
with torch.no_grad():
all_vecs[2] = new_vec
loss = all_vecs.sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
Но, честно говоря, это похоже на плохой трюк, и мне нужно, чтобы мой код был как можно более "чистым". Какие-либо предложения? Есть ли альтернатива torch.cat, которая не выделяет новую память для значений тензора?