скопировать конструкцию из тензора: ПРЕДУПРЕЖДЕНИЕ ПОЛЬЗОВАТЕЛЯ - PullRequest
1 голос
/ 21 октября 2019

Я создаю случайный тензор из нормального распределения, и так как этот тензор служит весом в NN, для добавления атрибутов require_grad я использую torch.tensor (), как показано ниже:

import torch 

input_dim, hidden_dim = 3, 5

norm = torch.distributions.normal.Normal(loc=0, scale=0.01)
W = norm.sample((input_dim, hidden_dim))
W = torch.tensor(W, requires_grad=True)

Iя получаю предупреждение об ошибке, как показано ниже:

    UserWarning: To copy construct from a tensor, 
    it is recommended to use sourceTensor.clone().detach() or 
sourceTensor.clone().detach().requires_grad_(True), 
rather than torch.tensor(sourceTensor).

Есть ли альтернативный способ достижения вышеуказанного? Спасибо

1 Ответ

1 голос
/ 21 октября 2019

Вы можете просто установить W.requires_grad на True

import torch 

input_dim, hidden_dim = 3, 5

norm = torch.distributions.normal.Normal(loc=0, scale=0.01)
W = norm.sample((input_dim, hidden_dim))
W.requires_grad = True
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...