Pytorch nn.Parameter обновляет только для первой эпохи - PullRequest
0 голосов
/ 22 апреля 2019

Я пытаюсь реализовать пользовательский модуль (это упрощенная версия), и моя проверка переменной self.param во время каждой итерации значение не меняется после первой итерации, даже если градиент имеет значение. Кто-нибудь, знакомый с pytorch, знает, почему это происходит?

class Custom_RNN(nn.Module):
    def __init__(self):
        super(Custom_RNN, self).__init__()
        self.hl1 = nn.Linear(2, 1024)
        self.hl2 = nn.Linear(1024, 512)
        self.out = nn.Linear(512, 1)
        self.param = nn.Parameter(torch.Tensor([250]), requires_grad=True)

    def forward(self, state_init, input, output):

        combined = torch.cat((state_init, input[0]), 1)
        out = torch.sigmoid(self.hl1(combined))
        out = torch.sigmoid(self.hl2(out))
        out = self.out(out)

        for t in range(input.shape[1]):

            # Predict SoC
            shifted_out = out + input[t]/self.param

            combined = torch.cat((shifted_out, input[t]), 1)
            out = torch.sigmoid(self.hl1(combined))
            out = torch.sigmoid(self.hl2(out))
            out = self.out(out)
            if first:
                loss = torch.pow(out - output[t], 2.0)
            else:
                loss = loss + torch.pow(out - output[t], 2.0)

        return loss

rnn = Custom_RNN()
optimiser = optim.Adam(rnn.parameters())

for epoch in range(epochs):
    optimiser.zero_grad()
    loss = rnn(initial_cond, data_in, data_out)
    rnn_param_before = rnn.param.item()
    loss.backward()
    optimiser.step()
    rnn_param_after = rnn.param.item()
    print(rnn_param_before - rnn_param_after)
    print(rnn.param.item(), rnn.param.grad)

В первую эпоху первый отпечаток получает число, не являющееся нулями, затем в каждую другую эпоху он равен 0,0, а второй оператор печати показывает, что значение остается неизменным, а градус всегда ненулевой.

...