Когда я использовал train_loader (pytorch), глобальные переменные выполнялись много раз. Почему? - PullRequest
0 голосов
/ 04 мая 2020

Я хочу обучить сеть (L eNet) с набором данных CIFAR10. Но я нашел некоторые проблемы. Если я использовал train_loader (pytorch) в глобальной области, py будет многократно выполнять коды глобальной области . Кто-нибудь может сказать мне, почему?

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.optim as optim
import os
import torch.nn.functional as f

CLASSES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
DEVICE = torch.device('0' if torch.cuda.is_available() else "cpu")

data_home = 'F:\\work'
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])
train_set = torchvision.datasets.CIFAR10(root=os.path.join(data_home, 'dataset/CIFAR10'), train=True, download=True, transform=train_transform)
test_set = torchvision.datasets.CIFAR10(root=os.path.join(data_home, 'dataset/CIFAR10'), train=False, download=True, transform=test_transform)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=1)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=True, num_workers=1)

print("print something")


def run():
    model = LeNet()
    model = model.to(DEVICE)
    optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.5)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(50):
        for images, targets in train_loader:
            images, targets = images.to(DEVICE), targets.to(DEVICE)

            output = model(images)
            optimizer.zero_grad()
            loss = criterion(output, targets)
            loss.backward()
            optimizer.step()
    # for i in range(10):
    #     a = data_home


if __name__ == "__main__":
    run()

ниже результат результат

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...