Я пытался построить нейронную сеть с 4 входными узлами / функциями и только одной выходной функцией (0/1). Я написал этот код, и он работает, но во время обучения модель возвращает NaN. Я тоже отлаживал и весы и уклоны в порядке пока они не пройдут модель.
Из того, что я искал до сих пор, это может быть проблемой при передаче данных.
Мои входные данные: tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 1.5340e+00],
[1.5000e+01, 1.0000e-01, 2.4210e+00, 3.0000e+01],
[3.0000e+00, 2.2000e-01, 2.2000e-01, 4.5000e+01],
...,
[1.0000e+00, 2.0000e-02, 2.0000e-02, 1.5000e+01],
[6.0000e+00, 2.0000e-01, 2.0000e-01, 1.5000e+01],
[1.7000e+01, 5.2400e-01, 5.2400e-01, 2.0000e+00]], dtype=torch.float64)
import torch
from torchvision import datasets, transforms
import pandas as pd
import numpy as np
from torch.autograd import Variable
# Import tensor dataset & data loader
from torch.utils.data import TensorDataset, DataLoader
from torch import nn, optim
import torch.nn.functional as F
file = pd.read_csv('ks-projects-201801.csv')
array = np.array(file.values)
result = np.empty(len(array))
input_data = np.empty((len(array), 4))
for i in range(len(array)):
input_data[i] = np.array([array[i][10], array[i][12]/1000, array[i][13]/1000, array[i][14]/1000])
if array[i][9] == 'successful':
result[i] = 1
else:
result[i] = 0
input_node = Variable(torch.from_numpy(input_data))
output = torch.from_numpy(result)
print(input_node)
print(output)
train_ds = TensorDataset(input_node.squeeze(), output.squeeze())
batch_size = 5
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
Это актуальная модель и обучение
model = nn.Linear(4, 1)
print(model.weight)
print(model.bias)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.003)
epochs = 5
model = model.double()
for e in range(epochs):
running_loss = 0
for xb, yb in train_dl:
optimizer.zero_grad()
res = model(xb)
loss = criterion(res, yb)
loss.backward()
optimizer.step()
running_loss += loss.item()
else:
print(f"model : {loss}")
Это распечатывает модель: nan для каждой эпохи и заканчивается. Я очень плохо знаком с pytorch и не знаю, как справиться с этой проблемой.