Обратное распространение с PyTorch и автодифференциация - PullRequest
1 голос
/ 14 января 2020

У меня есть набор данных X, состоящий из 15 переменных и 64 наблюдений, и вектор-столбец Y из 64 значений, представляющих цель (метку). Я пытаюсь с помощью PyTorch подобрать параметры квадратичной функции так, чтобы она воспроизводила наблюдаемые значения (Y), но получаю ошибку. Набор данных приведён в конце поста в формате JSON для воспроизводимости.

Если бы у меня был только один пример, мой код мог бы выглядеть так:

# Single-example gradient-descent sketch: fit the entries of a 15x15 matrix W
# so that the sum of (X @ X^T) * W approaches the target Y.
# NOTE(review): `X` and `Y` are not defined in this snippet. Presumably X is a
# 2-D NumPy array (torch.transpose(X, 1, 0) below needs two dimensions) and Y
# is the target value for that example -- confirm against the loading code.
X = torch.from_numpy(X)
X.requires_grad = True
# Random 15x15 matrix with everything below the main diagonal zeroed out.
W = np.random.randn(15,15)
W = np.triu(W, k=0)
W = torch.from_numpy(W)
# W is the tensor the gradient is taken with respect to in the loop below.
W.requires_grad = True

# define parameters for gradient descent
max_iter=100
lr_rate = 1e-3

# we will do gradient descent for max_iter iterations
for i in range(max_iter):

        # compute the loss: Y minus the sum of the elementwise product of
        # (X @ X^T) and W. This is a signed residual, not a squared error.
        loss = Y - (X@torch.transpose(X, 1,0) * W).sum()
        # use torch.autograd.grad to compute the gradient
        # NOTE(review): `out` is never assigned in this snippet (the value
        # computed above is named `loss`), so the next two lines raise
        # NameError unless `out` is defined elsewhere.
        # The assignment rebinds W to a new tensor on every iteration.
        W = W - lr_rate*torch.autograd.grad(out, W)[0]
        print(f"{i}: {out}")

Не могли бы вы привести пример правильной реализации с использованием данных, которые я предоставляю ниже, которая позволит достичь заявленной цели (подгонка параметров к данным) в векторизованном виде?

Данные: X:

'{"embed_item_dim":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":-1.0,"5":1.0,"6":-1.0,"7":1.0,"8":-1.0,"9":1.0,"10":-1.0,"11":1.0,"12":-1.0,"13":1.0,"14":-1.0,"15":1.0,"16":-1.0,"17":1.0,"18":-1.0,"19":1.0,"20":-1.0,"21":1.0,"22":-1.0,"23":1.0,"24":-1.0,"25":1.0,"26":-1.0,"27":1.0,"28":-1.0,"29":1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":-1.0,"35":1.0,"36":-1.0,"37":1.0,"38":-1.0,"39":1.0,"40":-1.0,"41":1.0,"42":-1.0,"43":1.0,"44":-1.0,"45":1.0,"46":-1.0,"47":1.0,"48":-1.0,"49":1.0,"50":-1.0,"51":1.0,"52":-1.0,"53":1.0,"54":-1.0,"55":1.0,"56":-1.0,"57":1.0,"58":-1.0,"59":1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0},"embed_category_dim":{"0":-1.0,"1":-1.0,"2":1.0,"3":1.0,"4":-1.0,"5":-1.0,"6":1.0,"7":1.0,"8":-1.0,"9":-1.0,"10":1.0,"11":1.0,"12":-1.0,"13":-1.0,"14":1.0,"15":1.0,"16":-1.0,"17":-1.0,"18":1.0,"19":1.0,"20":-1.0,"21":-1.0,"22":1.0,"23":1.0,"24":-1.0,"25":-1.0,"26":1.0,"27":1.0,"28":-1.0,"29":-1.0,"30":1.0,"31":1.0,"32":-1.0,"33":-1.0,"34":1.0,"35":1.0,"36":-1.0,"37":-1.0,"38":1.0,"39":1.0,"40":-1.0,"41":-1.0,"42":1.0,"43":1.0,"44":-1.0,"45":-1.0,"46":1.0,"47":1.0,"48":-1.0,"49":-1.0,"50":1.0,"51":1.0,"52":-1.0,"53":-1.0,"54":1.0,"55":1.0,"56":-1.0,"57":-1.0,"58":1.0,"59":1.0,"60":-1.0,"61":-1.0,"62":1.0,"63":1.0},"embed_shop_dim":{"0":-1.0,"1":-1.0,"2":-1.0,"3":-1.0,"4":1.0,"5":1.0,"6":1.0,"7":1.0,"8":-1.0,"9":-1.0,"10":-1.0,"11":-1.0,"12":1.0,"13":1.0,"14":1.0,"15":1.0,"16":-1.0,"17":-1.0,"18":-1.0,"19":-1.0,"20":1.0,"21":1.0,"22":1.0,"23":1.0,"24":-1.0,"25":-1.0,"26":-1.0,"27":-1.0,"28":1.0,"29":1.0,"30":1.0,"31":1.0,"32":-1.0,"33":-1.0,"34":-1.0,"35":-1.0,"36":1.0,"37":1.0,"38":1.0,"39":1.0,"40":-1.0,"41":-1.0,"42":-1.0,"43":-1.0,"44":1.0,"45":1.0,"46":1.0,"47":1.0,"48":-1.0,"49":-1.0,"50":-1.0,"51":-1.0,"52":1.0,"53":1.0,"54":1.0,"55":1.0,"56":-1.0,"57":-1.0,"58":-1.0,"59":-1.0,"60":1.0,"61":1.0,"62":1.0,"63":1.0},"categorical_dim":{"0":-1.0,"1":-1.0,"2":-1.0,"3":-1.0,"4":-1.0,"5":-1.0,"6":-1.0,"7":-1.0,"8":1.0,"9":1.0,"10":1.0,"11":1.0,"12":1.0,"13":1.0,
"14":1.0,"15":1.0,"16":-1.0,"17":-1.0,"18":-1.0,"19":-1.0,"20":-1.0,"21":-1.0,"22":-1.0,"23":-1.0,"24":1.0,"25":1.0,"26":1.0,"27":1.0,"28":1.0,"29":1.0,"30":1.0,"31":1.0,"32":-1.0,"33":-1.0,"34":-1.0,"35":-1.0,"36":-1.0,"37":-1.0,"38":-1.0,"39":-1.0,"40":1.0,"41":1.0,"42":1.0,"43":1.0,"44":1.0,"45":1.0,"46":1.0,"47":1.0,"48":-1.0,"49":-1.0,"50":-1.0,"51":-1.0,"52":-1.0,"53":-1.0,"54":-1.0,"55":-1.0,"56":1.0,"57":1.0,"58":1.0,"59":1.0,"60":1.0,"61":1.0,"62":1.0,"63":1.0},"categorical_dropout":{"0":-1.0,"1":-1.0,"2":-1.0,"3":-1.0,"4":-1.0,"5":-1.0,"6":-1.0,"7":-1.0,"8":-1.0,"9":-1.0,"10":-1.0,"11":-1.0,"12":-1.0,"13":-1.0,"14":-1.0,"15":-1.0,"16":1.0,"17":1.0,"18":1.0,"19":1.0,"20":1.0,"21":1.0,"22":1.0,"23":1.0,"24":1.0,"25":1.0,"26":1.0,"27":1.0,"28":1.0,"29":1.0,"30":1.0,"31":1.0,"32":-1.0,"33":-1.0,"34":-1.0,"35":-1.0,"36":-1.0,"37":-1.0,"38":-1.0,"39":-1.0,"40":-1.0,"41":-1.0,"42":-1.0,"43":-1.0,"44":-1.0,"45":-1.0,"46":-1.0,"47":-1.0,"48":1.0,"49":1.0,"50":1.0,"51":1.0,"52":1.0,"53":1.0,"54":1.0,"55":1.0,"56":1.0,"57":1.0,"58":1.0,"59":1.0,"60":1.0,"61":1.0,"62":1.0,"63":1.0},"numerical_dim":{"0":-1.0,"1":-1.0,"2":-1.0,"3":-1.0,"4":-1.0,"5":-1.0,"6":-1.0,"7":-1.0,"8":-1.0,"9":-1.0,"10":-1.0,"11":-1.0,"12":-1.0,"13":-1.0,"14":-1.0,"15":-1.0,"16":-1.0,"17":-1.0,"18":-1.0,"19":-1.0,"20":-1.0,"21":-1.0,"22":-1.0,"23":-1.0,"24":-1.0,"25":-1.0,"26":-1.0,"27":-1.0,"28":-1.0,"29":-1.0,"30":-1.0,"31":-1.0,"32":1.0,"33":1.0,"34":1.0,"35":1.0,"36":1.0,"37":1.0,"38":1.0,"39":1.0,"40":1.0,"41":1.0,"42":1.0,"43":1.0,"44":1.0,"45":1.0,"46":1.0,"47":1.0,"48":1.0,"49":1.0,"50":1.0,"51":1.0,"52":1.0,"53":1.0,"54":1.0,"55":1.0,"56":1.0,"57":1.0,"58":1.0,"59":1.0,"60":1.0,"61":1.0,"62":1.0,"63":1.0},"numerical_dropout":{"0":-1.0,"1":1.0,"2":1.0,"3":-1.0,"4":1.0,"5":-1.0,"6":-1.0,"7":1.0,"8":-1.0,"9":1.0,"10":1.0,"11":-1.0,"12":1.0,"13":-1.0,"14":-1.0,"15":1.0,"16":-1.0,"17":1.0,"18":1.0,"19":-1.0,"20":1.0,"21":-1.0,"22":-1.0,"23":1.0,"24":-1.0,"25":1.0,"26":1.0,"27":-1.0,"28":1.0,"
29":-1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":1.0,"35":-1.0,"36":1.0,"37":-1.0,"38":-1.0,"39":1.0,"40":-1.0,"41":1.0,"42":1.0,"43":-1.0,"44":1.0,"45":-1.0,"46":-1.0,"47":1.0,"48":-1.0,"49":1.0,"50":1.0,"51":-1.0,"52":1.0,"53":-1.0,"54":-1.0,"55":1.0,"56":-1.0,"57":1.0,"58":1.0,"59":-1.0,"60":1.0,"61":-1.0,"62":-1.0,"63":1.0},"mixed_dim1":{"0":-1.0,"1":1.0,"2":1.0,"3":-1.0,"4":-1.0,"5":1.0,"6":1.0,"7":-1.0,"8":1.0,"9":-1.0,"10":-1.0,"11":1.0,"12":1.0,"13":-1.0,"14":-1.0,"15":1.0,"16":-1.0,"17":1.0,"18":1.0,"19":-1.0,"20":-1.0,"21":1.0,"22":1.0,"23":-1.0,"24":1.0,"25":-1.0,"26":-1.0,"27":1.0,"28":1.0,"29":-1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":1.0,"35":-1.0,"36":-1.0,"37":1.0,"38":1.0,"39":-1.0,"40":1.0,"41":-1.0,"42":-1.0,"43":1.0,"44":1.0,"45":-1.0,"46":-1.0,"47":1.0,"48":-1.0,"49":1.0,"50":1.0,"51":-1.0,"52":-1.0,"53":1.0,"54":1.0,"55":-1.0,"56":1.0,"57":-1.0,"58":-1.0,"59":1.0,"60":1.0,"61":-1.0,"62":-1.0,"63":1.0},"mixed_dropout1":{"0":-1.0,"1":1.0,"2":1.0,"3":-1.0,"4":-1.0,"5":1.0,"6":1.0,"7":-1.0,"8":-1.0,"9":1.0,"10":1.0,"11":-1.0,"12":-1.0,"13":1.0,"14":1.0,"15":-1.0,"16":1.0,"17":-1.0,"18":-1.0,"19":1.0,"20":1.0,"21":-1.0,"22":-1.0,"23":1.0,"24":1.0,"25":-1.0,"26":-1.0,"27":1.0,"28":1.0,"29":-1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":1.0,"35":-1.0,"36":-1.0,"37":1.0,"38":1.0,"39":-1.0,"40":-1.0,"41":1.0,"42":1.0,"43":-1.0,"44":-1.0,"45":1.0,"46":1.0,"47":-1.0,"48":1.0,"49":-1.0,"50":-1.0,"51":1.0,"52":1.0,"53":-1.0,"54":-1.0,"55":1.0,"56":1.0,"57":-1.0,"58":-1.0,"59":1.0,"60":1.0,"61":-1.0,"62":-1.0,"63":1.0},"mixed_dim2":{"0":-1.0,"1":1.0,"2":1.0,"3":-1.0,"4":-1.0,"5":1.0,"6":1.0,"7":-1.0,"8":-1.0,"9":1.0,"10":1.0,"11":-1.0,"12":-1.0,"13":1.0,"14":1.0,"15":-1.0,"16":-1.0,"17":1.0,"18":1.0,"19":-1.0,"20":-1.0,"21":1.0,"22":1.0,"23":-1.0,"24":-1.0,"25":1.0,"26":1.0,"27":-1.0,"28":-1.0,"29":1.0,"30":1.0,"31":-1.0,"32":1.0,"33":-1.0,"34":-1.0,"35":1.0,"36":1.0,"37":-1.0,"38":-1.0,"39":1.0,"40":1.0,"41":-1.0,"42":-1.0,"43":1.0,"44":1.0,"45":-
1.0,"46":-1.0,"47":1.0,"48":1.0,"49":-1.0,"50":-1.0,"51":1.0,"52":1.0,"53":-1.0,"54":-1.0,"55":1.0,"56":1.0,"57":-1.0,"58":-1.0,"59":1.0,"60":1.0,"61":-1.0,"62":-1.0,"63":1.0},"mixed_dropout2":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":1.0,"5":-1.0,"6":1.0,"7":-1.0,"8":1.0,"9":-1.0,"10":1.0,"11":-1.0,"12":-1.0,"13":1.0,"14":-1.0,"15":1.0,"16":-1.0,"17":1.0,"18":-1.0,"19":1.0,"20":1.0,"21":-1.0,"22":1.0,"23":-1.0,"24":1.0,"25":-1.0,"26":1.0,"27":-1.0,"28":-1.0,"29":1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":-1.0,"35":1.0,"36":1.0,"37":-1.0,"38":1.0,"39":-1.0,"40":1.0,"41":-1.0,"42":1.0,"43":-1.0,"44":-1.0,"45":1.0,"46":-1.0,"47":1.0,"48":-1.0,"49":1.0,"50":-1.0,"51":1.0,"52":1.0,"53":-1.0,"54":1.0,"55":-1.0,"56":1.0,"57":-1.0,"58":1.0,"59":-1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0},"mixed_dim3":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":1.0,"5":-1.0,"6":1.0,"7":-1.0,"8":-1.0,"9":1.0,"10":-1.0,"11":1.0,"12":1.0,"13":-1.0,"14":1.0,"15":-1.0,"16":1.0,"17":-1.0,"18":1.0,"19":-1.0,"20":-1.0,"21":1.0,"22":-1.0,"23":1.0,"24":1.0,"25":-1.0,"26":1.0,"27":-1.0,"28":-1.0,"29":1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":-1.0,"35":1.0,"36":1.0,"37":-1.0,"38":1.0,"39":-1.0,"40":-1.0,"41":1.0,"42":-1.0,"43":1.0,"44":1.0,"45":-1.0,"46":1.0,"47":-1.0,"48":1.0,"49":-1.0,"50":1.0,"51":-1.0,"52":-1.0,"53":1.0,"54":-1.0,"55":1.0,"56":1.0,"57":-1.0,"58":1.0,"59":-1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0},"mixed_dropout3":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":1.0,"5":-1.0,"6":1.0,"7":-1.0,"8":-1.0,"9":1.0,"10":-1.0,"11":1.0,"12":1.0,"13":-1.0,"14":1.0,"15":-1.0,"16":-1.0,"17":1.0,"18":-1.0,"19":1.0,"20":1.0,"21":-1.0,"22":1.0,"23":-1.0,"24":-1.0,"25":1.0,"26":-1.0,"27":1.0,"28":1.0,"29":-1.0,"30":1.0,"31":-1.0,"32":1.0,"33":-1.0,"34":1.0,"35":-1.0,"36":-1.0,"37":1.0,"38":-1.0,"39":1.0,"40":1.0,"41":-1.0,"42":1.0,"43":-1.0,"44":-1.0,"45":1.0,"46":-1.0,"47":1.0,"48":1.0,"49":-1.0,"50":1.0,"51":-1.0,"52":-1.0,"53":1.0,"54":-1.0,"55":1.0,"56":1.0,"57":-1.0,"58":1.0,"59":-1.0,"60":-1.0,"61":1
.0,"62":-1.0,"63":1.0},"last_layer_dim":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":-1.0,"5":1.0,"6":-1.0,"7":1.0,"8":1.0,"9":-1.0,"10":1.0,"11":-1.0,"12":1.0,"13":-1.0,"14":1.0,"15":-1.0,"16":1.0,"17":-1.0,"18":1.0,"19":-1.0,"20":1.0,"21":-1.0,"22":1.0,"23":-1.0,"24":-1.0,"25":1.0,"26":-1.0,"27":1.0,"28":-1.0,"29":1.0,"30":-1.0,"31":1.0,"32":-1.0,"33":1.0,"34":-1.0,"35":1.0,"36":-1.0,"37":1.0,"38":-1.0,"39":1.0,"40":1.0,"41":-1.0,"42":1.0,"43":-1.0,"44":1.0,"45":-1.0,"46":1.0,"47":-1.0,"48":1.0,"49":-1.0,"50":1.0,"51":-1.0,"52":1.0,"53":-1.0,"54":1.0,"55":-1.0,"56":-1.0,"57":1.0,"58":-1.0,"59":1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0},"last_layer_dropout":{"0":-1.0,"1":1.0,"2":-1.0,"3":1.0,"4":-1.0,"5":1.0,"6":-1.0,"7":1.0,"8":1.0,"9":-1.0,"10":1.0,"11":-1.0,"12":1.0,"13":-1.0,"14":1.0,"15":-1.0,"16":-1.0,"17":1.0,"18":-1.0,"19":1.0,"20":-1.0,"21":1.0,"22":-1.0,"23":1.0,"24":1.0,"25":-1.0,"26":1.0,"27":-1.0,"28":1.0,"29":-1.0,"30":1.0,"31":-1.0,"32":1.0,"33":-1.0,"34":1.0,"35":-1.0,"36":1.0,"37":-1.0,"38":1.0,"39":-1.0,"40":-1.0,"41":1.0,"42":-1.0,"43":1.0,"44":-1.0,"45":1.0,"46":-1.0,"47":1.0,"48":1.0,"49":-1.0,"50":1.0,"51":-1.0,"52":1.0,"53":-1.0,"54":1.0,"55":-1.0,"56":-1.0,"57":1.0,"58":-1.0,"59":1.0,"60":-1.0,"61":1.0,"62":-1.0,"63":1.0}}'

Y

'{"0":2.0561221309,"1":2.0649733606,"2":2.0733728925,"3":2.0594125771,"4":2.0949032045,"5":2.0294939058,"6":2.0436441327,"7":2.1209041954,"8":2.0496001055,"9":2.148755921,"10":2.0937250525,"11":2.0629058135,"12":2.0641746866,"13":2.0592979107,"14":2.1166172412,"15":2.1125198086,"16":2.0525522671,"17":2.0687485594,"18":2.0649582587,"19":2.0818384718,"20":2.0839422046,"21":2.043783441,"22":2.05290516,"23":2.0565277924,"24":2.0550897444,"25":2.0663609971,"26":2.0895415003,"27":2.0706054531,"28":2.0639581304,"29":2.0889003421,"30":2.0436977626,"31":2.1350170653,"32":2.0395688425,"33":2.079368626,"34":2.0439947954,"35":2.072433023,"36":2.050665861,"37":2.037977855,"38":2.0527567514,"39":2.050903715,"40":2.0381965719,"41":2.0673631206,"42":2.085004701,"43":2.0458497661,"44":2.0540644062,"45":2.050330556,"46":2.0859451303,"47":2.0323004844,"48":2.05113558,"49":2.046360857,"50":2.0572361143,"51":2.0659940765,"52":2.0583657215,"53":2.0520969623,"54":2.0683284923,"55":2.0491708591,"56":2.0932832342,"57":2.0416396082,"58":2.0703974941,"59":2.0464359665,"60":2.0591405783,"61":2.0527808995,"62":2.0670555565,"63":2.0898413706}'

1 Ответ

0 голосов
/ 14 января 2020

Я думаю, что ваша функция потери неверна. В настоящее время вы использовали

loss = Y - (X@torch.transpose(X, 1,0) * W).sum()

В идеале, поскольку мы используем SGD и пытаемся минимизировать потери, мы хотим, чтобы потери были минимальными тогда и только тогда, когда Y равно f(x). Тем не менее, в вашем случае потери сводятся к минимуму, когда f(x) максимально велико, так как тогда Y-f(x) будет максимально малым. Обратите внимание, что потеря равна 0 только тогда, когда y=f(x), иначе она будет положительной или отрицательной. Быстрое исправление может заключаться в том, чтобы просто возвести потери в квадрат, чтобы они всегда были неотрицательными и равнялись 0 только тогда, когда y=f(x).

loss = (Y - (X@torch.transpose(X, 1,0) * W).sum())**2

Также, если Y — вектор, а X — матрица (то есть образцов больше одного), можно суммировать потери или вычислять их среднее значение:

loss = ((Y - (X@torch.transpose(X, 1,0) * W).sum())**2).mean()

...