Работает хорошо, но ничему не учится. Потери никогда не уменьшаются. Кроме того, возникает предупреждение времени выполнения: деление на ноль, встречающееся по мощности.
phi = np.array([0, 0.00724873038246, 0.0126062332691, 0.0179158239032, 0.0237780447692, 0.0301819694947, 0.037412389363, 0.0453824653196, 0.0541722382861, 0.0638498928693, 0.0748241181818, 0.0870018575601, 0.100969984201, 0.116901498992, 0.135225153396, 0.155865741879, 0.179184255267, 0.205004290279, 0.233612351761, 0.265481420095, 0.307980863])
kf = np.array([188.0535454, 200.042702173, 212.031762644, 224.020823148, 236.009883356, 247.998943263, 259.988002683, 271.977061559, 283.96611989, 295.955177493, 307.944233912, 319.933288966, 331.922341978, 343.911392444, 355.900439242, 367.889482121, 379.878519645, 391.867551768, 403.856577188, 415.845593859, 427.834570885])
PHI_train = torch.FloatTensor(phi)
KF_train = torch.FloatTensor(kf)
def ludwig_func(x, a, b, c):
return a+b*(x**c)
class Reg(torch.nn.Module):
def __init__(self,n_feature,n_hidden,n_output):
super(Reg, self).__init__()
self.layer1 = torch.nn.Linear(1,16)
self.layer2 = torch.nn.Linear(16,8)
#self.layer4 = torch.nn.Linear(32,16)
self.predict = torch.nn.Linear(8,3)
def forward(self,x):
x = F.relu(self.layer1(x))
x = F.relu(self.layer2(x))
x = self.predict(x)
return x
reg = Reg(n_feature = 1, n_hidden = 16, n_output = 3)
optimizer = torch.optim.SGD(reg.parameters(),lr=0.01)
loss_func = torch.nn.MSELoss()
for epoch in range(200):
PRM = reg(PHI_train) #Tensor
prm = PRM.detach().numpy() #Array
for i in range(len(ludwig_prm)): #length of prm = 21
a = prm[i,0]
b = prm[i,1]
c = prm[i,2]
#d = prm[i,3]
preds = ludwig_func(phi,a,b,c)#Array
np.set_printoptions(precision=10)
preds = np.nan_to_num(preds)
#print(preds)
Preds = torch.FloatTensor(preds)#Tensor
Preds[Preds == float('inf')] = -0.2e+10
#print(Preds)
loss = loss_func(Preds,KF_train) #Tensor
loss[loss == float('inf')] = 0.1e+29
loss = Variable(loss,requires_grad = True)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Epoch:',epoch+1,'MESloss:',loss.item())
Я долго боролся с этим и не знаю почему. В интернете я тоже ничего не нашел. Это будет так приятно, что кто-то может помочь.