model.parameters()
возвращает все parameters
вашего model
, включая embeddings
.
Таким образом, все эти parameters
ваших model
передаются optimizer
(строка ниже) и будут обучаться позже при вызове optimizer.step()
- так что да, ваши embeddings
обучен вместе со всеми другими parameters
сети.
(вы также можете заморозить определенные слои, установив, например, embedding.weight.requires_grad = False
, но здесь это не так).
# summing it up:
# this line specifies which parameters are trained with the optimizer
# model.parameters() just returns all parameters
# embedding class weights are also parameters and will thus be trained
optimizer = optim.SGD(model.parameters(), lr=0.001)
Вы можете видеть, что ваши веса встраивания также относятся к типу Parameter
, выполнив следующие действия:
import torch
embedding_maxtrix = torch.nn.Embedding(10, 10)
print(type(embedding_maxtrix.weight))
Это выведет тип весов, который будет Parameter
:
<class 'torch.nn.parameter.Parameter'>
Я не совсем уверен, что означает retrieve . Вы имеете в виду получение одного вектора, или вы хотите, чтобы вся матрица сохранила его, или что-то еще?
embedding_maxtrix = torch.nn.Embedding(5, 5)
# this will get you a single embedding vector
print('Getting a single vector:\n', embedding_maxtrix(torch.LongTensor([0])))
# of course you can do the same for a seqeunce
print('Getting vectors for a sequence:\n', embedding_maxtrix(torch.LongTensor([1, 2, 3])))
# this will give the the whole embedding matrix
print('Getting weights:\n', embedding_maxtrix.weight.data)
Выход:
Getting a single vector:
tensor([[-0.0144, -0.6245, 1.3611, -1.0753, 0.5020]], grad_fn=<EmbeddingBackward>)
Getting vectors for a sequence:
tensor([[ 0.9277, -0.1879, -1.4999, 0.2895, 0.8367],
[-0.1167, -2.2139, 1.6918, -0.3483, 0.3508],
[ 2.3763, -1.3408, -0.9531, 2.2081, -1.5502]],
grad_fn=<EmbeddingBackward>)
Getting weights:
tensor([[-0.0144, -0.6245, 1.3611, -1.0753, 0.5020],
[ 0.9277, -0.1879, -1.4999, 0.2895, 0.8367],
[-0.1167, -2.2139, 1.6918, -0.3483, 0.3508],
[ 2.3763, -1.3408, -0.9531, 2.2081, -1.5502],
[-0.5829, -0.1918, -0.8079, 0.6922, -0.2627]])
Надеюсь, это ответит на ваш вопрос, вы также можете взглянуть на документацию, там же вы найдете несколько полезных примеров.
https://pytorch.org/docs/stable/nn.html#torch.nn.Embedding