Входы Pytorch для nn.CrossEntropyLoss () - PullRequest
0 голосов
/ 26 декабря 2018

Я пытаюсь выполнить логистическую регрессию в PyTorch на простом наборе данных с меткой 0,1.Критерий или потеря определяется как: criterion = nn.CrossEntropyLoss().Модель: model = LogisticRegression(1,2)

У меня есть точка данных, представляющая собой пару: dat = (-3.5, 0), первый элемент - точка данных, а второй - соответствующая метка.Затем я преобразовываю первый элемент ввода в тензор: tensor_input = torch.Tensor([dat[0]]).Затем я применяю модель к tenor_input: outputs = model(tensor_input).Затем я конвертирую метку в тензор: tensor_label = torch.Tensor([dat[1]]).Теперь, когда я пытаюсь это сделать, дело обрывается: loss = criterion(outputs, tensor_label).Это дает и ошибку: RuntimeError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

import torch
import torch.nn as nn

class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, num_classes) 

    def forward(self, x):
        out = self.linear(x)
        return out

model = LogisticRegression(1,2)
criterion = nn.CrossEntropyLoss()
dat = (-3.5,0)
tensor_input = torch.Tensor([dat[0]])
outputs = binary_model(tensor_input)
tensor_label = torch.Tensor([dat[1]])
loss = criterion(outputs, tensor_label)

Я не могу при жизни понять это.

Ответы [ 2 ]

0 голосов
/ 26 декабря 2018

Чтобы выполнить логистическую регрессию в PyTorch, вам нужно 3 вещи:

  • Метки (цели), закодированные как 0 или 1;
  • Активация сигмоида на последнем слое, поэтому числовыходы будут 1;
  • Двоичная перекрестная энтропия как функция потерь.

Вот минимальный пример:

import torch
import torch.nn as nn


class LogisticRegression(nn.Module):
    def __init__(self, n_inputs, n_outputs):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(n_inputs, n_outputs)
        self.sigmoid = nn.Sigmoid()


    def forward(self, x):
        x = self.linear(x)
        return self.sigmoid(x)


# Init your model
# Attention!!! your num_output will be 1, because logistic function returns one value in range (0, 1) 
model = LogisticRegression(n_inputs=1, n_outputs=1)
# Define Binary Cross Entropy Loss:
criterion = nn.BCELoss()

# dummy data
data = (42.0, 0)
tensor_input = torch.Tensor([data[0]])
tensor_label = torch.Tensor([data[1]])

outputs = model(tensor_input)

loss = criterion(outputs, tensor_label)

print(loss.item())
0 голосов
/ 26 декабря 2018

По большей части документация PyTorch проделывает потрясающую работу по объяснению различных функций;они обычно включают ожидаемые входные измерения, а также некоторые простые примеры.Вы можете найти описание для nn.CrossEntropyLoss() здесь .

Чтобы перейти к вашему конкретному примеру, давайте начнем с рассмотрения ожидаемого входного измерения:

Ввод: (N, C) где C = количество классов.[...]

Чтобы добавить к этому, N обычно относится к размеру партии (количеству образцов).Чтобы сравнить это с тем, что у вас есть:

outputs.shape
>>> torch.Size([2])

Т.е. в настоящее время у нас есть только входное измерение (2,), а не (1,2), как ожидается PyTorch.Мы можем облегчить это, добавив «поддельное» измерение к нашему текущему тензору, просто используя .unsqueeze(), например, так:

outputs = binary_model(tensor_input).unsqueeze(dim=0)
outputs.shape
>>> torch.Size([1,2])

Теперь, когда мы получили это, давайте посмотрим на ожидаемый ввод для целей:

Цель: (N) [...]

Итак, мы уже получили правильную форму для этого.Если мы попробуем это, мы все равно столкнемся с ошибкой:

RuntimeError: Expected object of scalar type Long but got scalar type Float 
              for argument #2 'target'.

Опять же, сообщение об ошибке довольно выразительно.Проблема здесь в том, что тензоры PyTorch (по умолчанию) интерпретируются как torch.FloatTensors, но вместо этого входные данные должны быть целыми числами (или Long).Мы можем просто сделать это, указав точный тип во время создания тензора:

tensor_label = torch.LongTensor([dat[1]])

Я использую PyTorch 1.0 под Linux.

...