Вам необходимо применить балансировку размеров к градиентам. Взятый из курса Стэнфорда cs231n , он сводится к двум простым модификациям:
Учитывая и , у нас будет:
,
Вот код, который я использовал для проверки правильности расчета градиента. Вы должны быть в состоянии обновить свой код соответственно.
import torch
torch.random.manual_seed(0)
x_1, x_2 = torch.zeros(size=(1, 8)).normal_(0, 0.01), torch.zeros(size=(1, 8)).normal_(0, 0.01)
y = torch.zeros(size=(1, 8)).normal_(0, 0.01)
h_0 = torch.zeros(size=(1, 16)).normal_(0, 0.01)
weight_ih = torch.zeros(size=(8, 16)).normal_(mean=0, std=0.01).requires_grad_(True)
weight_hh = torch.zeros(size=(16, 16)).normal_(mean=0, std=0.01).requires_grad_(True)
weight_ho = torch.zeros(size=(16, 8)).normal_(mean=0, std=0.01).requires_grad_(True)
h_1 = x_1.mm(weight_ih) + h_0.mm(weight_hh)
h_2 = x_2.mm(weight_ih) + h_1.mm(weight_hh)
g_2 = h_2.sigmoid()
j_2 = g_2.mm(weight_ho)
y_predicted = j_2.sigmoid()
loss = 0.5 * (y - y_predicted).pow(2).sum()
loss.backward()
delta_1 = -1 * (y - y_predicted) * y_predicted * (1 - y_predicted)
delta_2 = delta_1.mm(weight_ho.t()) * (g_2 * (1 - g_2))
delta_3 = delta_2.mm(weight_hh.t())
# 16 x 8
weight_ho_grad = g_2.t() * delta_1
# 16 x 16
weight_hh_grad = h_1.t() * delta_2 + (h_0.t() * delta_3)
# 8 x 16
weight_ih_grad = x_2.t() * delta_2 + x_1.t() * delta_3
atol = 1e-10
assert torch.allclose(weight_ho.grad, weight_ho_grad, atol=atol)
assert torch.allclose(weight_hh.grad, weight_hh_grad, atol=atol)
assert torch.allclose(weight_ih.grad, weight_ih_grad, atol=atol)