opt.step () не обновляет веса модели - PullRequest
1 голос
/ 07 мая 2020
    print(self.global_model.state_dict())
    print("total_loss",total_loss)
    total_loss.backward()
    self.opt.step()
    print(self.global_model.state_dict())

вывод:

('dense1.weight', tensor([[ 0.3997, -0.1907,  0.1120,  0.3016],
        [ 0.1156,  0.0646,  0.1802,  0.3558],
        [ 0.0321,  0.2537,  0.0879,  0.2441],
        [-0.2952, -0.0886, -0.3235,  0.3006]])), ('dense1.bias', tensor([ 0.1927,  0.3048, -0.3551, -0.0302])), ('dense2.weig

total_loss.backward() tensor(2.5806, dtype=torch.float64, grad_fn=<MeanBackward0>)

('dense1.weight', tensor([[ 0.3997, -0.1907,  0.1120,  0.3016],
        [ 0.1156,  0.0646,  0.1802,  0.3558],
        [ 0.0321,  0.2537,  0.0879,  0.2441],
        [-0.2952, -0.0886, -0.3235,  0.3006]])), ('dense1.bias', tensor([ 0.192

Мы видим, что total_loss имеет какое-то значение, но не обновляет веса

self.opt = torch.optim.SGD(self.global_model.parameters(),lr = 0.01)

Обновление

Если я do

            print(self.local_model.state_dict())
            print("total_loss.backward()",total_loss)
            total_loss.backward()
            opt_2 = torch.optim.SGD(self.local_model.parameters(),lr = 0.01)
            opt_2.step()
            self.opt.step()
            print(self.local_model.state_dict())

Обновляет веса локальной модели. Но мне нужно применить этот градиент к другой модели. Итак, что мне нужно будет сделать?

...