как оценить и получить точность нейронной сети прямой связи с pytorch - PullRequest
0 голосов
/ 06 октября 2019

Я начал использовать Pytorch, и сейчас я работаю над проектом, в котором я использую простую нейронную сеть с прямой связью для линейной регрессии. Проблема в том, что в Pytorch я не нашел ничего, что позволило бы мне получить Точность Модели линейной регрессии, как в Keras или в SKlearn. в керасе это было бы просто, установив metrics=["accuracy"] внутри функции компиляции. Я искал в документации и на официальном сайте Pytorch, но ничего не нашел. кажется, что этот API не существует в Pytorch. Я знаю, что могу наблюдать потерю во время тренировки, или я могу просто получить тестовую потерю и, основываясь на ней, я могу знать, уменьшилась потеря или нет, но я хочу использовать ту структуру Keras, где я получаю значение потери, а также значение точности,путь Кераса выглядит более понятным. Я также попытался реализовать функцию точности, используя r2_score из sklearn, но он дал мне странные значения:

criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

def train(model, optimizer, loss_fn):
    def train_step(x, y):
        model.train()
        optimizer.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()
        return loss.item()
    return train_step

def fit(epochs=100):
    train_func = train(model, optimizer, criterion)
    count, total = 0, 0
    loss_list, accuracy_list, iters = [], [], []
    for e in range(epochs):
        for X, y in train_loader:
            loss = train_func(X, y)
            count += 1
            total += len(y)
            if count % 50 == 0:
                print("loss= ", loss)
                loss_list.append(loss)
                iters.append(total)

            if count % 100 == 0:
                model.eval()   # im not sure if we can do this in pytorch. I mean evaluating the model while training! it would be great if you tell me whether this is ok or not
                out = model(X)
                out = out.detach().numpy()
                y = y.detach().numpy()
                accuracy = r2_score(y, out)   # r2_score is the scikit learn r2 score function.
                print("accuracy = ", accuracy)   # here i get wierd values and it doesn't get better over time, in contrast the loss decreased over time
                accuracy_list.append(accuracy)

    return iters, loss_list, accuracy_list

Я знаю, как реализовать функцию точности в случае проблемы классификации, потому что она использует дискретные значения. это ясно для меня, потому что реализация проста и понятна. Я должен только посмотреть, какое правильное предсказание сделала модель, а затем рассчитать точность. но в этом случае у меня есть непрерывные значения, поэтому я сам не смог реализовать эту функцию, и меня удивило, что у Pytorch нет встроенной функции для этого. может, кто-то может сказать мне, как реализовать это или где найти его реализацию?

Другое дело, где использовать оценку и где установить модель в режиме оценки, вызвав функцию eval. должен ли я использовать его во время обучения, как я делал в своем Кодексе, или я должен тренироваться, а затем тестировать после тренировки, и если я тестирую во время тренировки, я должен вызвать функцию eval, как я это делал, или это повлияет на тренировку, когда цикл вернется к тренировкеРежим? Еще одна вещь, которую я не нашел в Pytorch, это Крестная валидация. как мне реализовать это в pytorch, если для него нет API, как в Keras?

1 Ответ

0 голосов
/ 06 октября 2019
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Смотрите здесь для получения дополнительной информации: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...