Нейросеть Pytorch для классификации - PullRequest
0 голосов
/ 01 марта 2019

Задача состоит в том, чтобы обучить нейронную сеть выбрать значение v в списке для заданного ввода x , который минимизирует (x - v) .Например, в качестве входных данных у меня есть тензор измерения 100x1 целых чисел.Моя нейронная сеть должна вывести распределение вероятностей для этих значений.

Это работает правильно для тензора 1x1, но выдает одинаковое распределение для каждого примера на входе.Я хочу, чтобы он выводил правильное распределение для каждого примера отдельно.Как мне это исправить?

Вот код:

N, D_in, D_out, H = 20, 1, 5, 10
X = torch.LongTensor(100, 1).random_(0, 500).float()
values = torch.FloatTensor([0, 1, 3, 5, 100])

model = nn.Sequential(
    nn.Linear(D_in, H),
    nn.Sigmoid(),
    nn.Linear(H, D_out),
    nn.Softmax(dim=1)
)

def loss_function(x, y, probability):
    return torch.mean(torch.sum(probability*((x-y)**2), 1).div(probability.size(1)))

loss_fn = loss_function
learning_rate = 1e-3
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

model.train(True)
for epoch in range(200):
    pred = model(X)
    loss = loss_fn(X, values, pred)
    print('loss: ', loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

И это вход и выход для 20 примеров: Вход:

tensor([[ 16.],
         [ 28.],
         [188.],
         [ 16.],
         [379.],
         [366.],
         [366.],
         [421.],
         [ 13.],
         [ 86.],
         [ 38.],
         [135.],
         [173.],
         [153.],
         [223.],
         [133.],
         [316.],
         [108.],
         [133.],
         [295.]])

Выход:

tensor([[1.2926e-04, 1.3165e-04, 1.2761e-04, 1.3385e-04, 9.9948e-01],
         [9.9158e-05, 1.0086e-04, 9.8492e-05, 1.0170e-04, 9.9960e-01],
         [3.3826e-05, 3.4437e-05, 3.4443e-05, 3.2122e-05, 9.9987e-01],
         [1.2926e-04, 1.3165e-04, 1.2761e-04, 1.3385e-04, 9.9948e-01],
         [3.3173e-05, 3.3778e-05, 3.3758e-05, 3.1419e-05, 9.9987e-01],
         [3.3175e-05, 3.3780e-05, 3.3759e-05, 3.1421e-05, 9.9987e-01],
         [3.3175e-05, 3.3780e-05, 3.3759e-05, 3.1421e-05, 9.9987e-01],
         [3.3170e-05, 3.3775e-05, 3.3754e-05, 3.1416e-05, 9.9987e-01],
         [1.4026e-04, 1.4281e-04, 1.3817e-04, 1.4524e-04, 9.9943e-01],
         [4.4578e-05, 4.5341e-05, 4.5195e-05, 4.3355e-05, 9.9982e-01],
         [8.1636e-05, 8.2994e-05, 8.1468e-05, 8.2810e-05, 9.9967e-01],
         [3.5991e-05, 3.6629e-05, 3.6643e-05, 3.4392e-05, 9.9986e-01],
         [3.4158e-05, 3.4773e-05, 3.4784e-05, 3.2472e-05, 9.9986e-01],
         [3.4881e-05, 3.5505e-05, 3.5521e-05, 3.3232e-05, 9.9986e-01],
         [3.3425e-05, 3.4031e-05, 3.4026e-05, 3.1694e-05, 9.9987e-01],
         [3.6153e-05, 3.6794e-05, 3.6807e-05, 3.4561e-05, 9.9986e-01],
         [3.3191e-05, 3.3796e-05, 3.3777e-05, 3.1439e-05, 9.9987e-01],
         [3.9212e-05, 3.9897e-05, 3.9870e-05, 3.7751e-05, 9.9984e-01],
         [3.6153e-05, 3.6794e-05, 3.6807e-05, 3.4561e-05, 9.9986e-01],
         [3.3207e-05, 3.3812e-05, 3.3795e-05, 3.1457e-05, 9.9987e-01]],
        grad_fn=<SoftmaxBackward>)
...