Я пытаюсь изучить вложение с помощью пользовательской функции градиента. В основном у меня есть вложение
my_embedding = torch.nn.Embedding(
num_embeddings=n,
embedding_dim=d)
и оптимизатор, например
my_optimizer = torch.optim.SGD(
my_embedding.parameters(),
lr=0.01)
Я вычисляю некоторый градиент вручную (обозначается my_gradient
ниже), и я делаю my_embedding.weight.grad = my_gradient
.
Однако, когда я делаю my_optimizer.step()
, это ничего не обновляет.
Я что-то пропустил ? (в этом случае я мог бы напрямую пропустить оптимизатор, но мне хотелось бы понять, почему он не работает)