PyTorch Dataloader - список не вызывается ошибка при перечислении - PullRequest
0 голосов
/ 30 января 2019

При переборе по загрузчику данных PyTorch, например,

# define dataset, dataloader
train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)
trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=64)

# define model, optimizer, loss
# not included - irrelevant to the question

for ii, (inputs, labels) in enumerate(trainloader):

    # Move input and label tensors to the GPU
    inputs, labels = inputs.to(device), labels.to(device)

    start = time.time()

    outputs = model.forward(inputs)
    loss = criterion(outputs, labels)
    loss.backward()

я получаю TypeError: 'list' object is not callable в этой строке

for ii, (inputs, labels) in enumerate(trainloader):

Что за глупость я забыл?

1 Ответ

0 голосов
/ 30 января 2019

Вы не забыли вызвать transforms.Compose в своем списке преобразований?

В этой строке

train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)

параметр transform ожидает вызываемый объект, а не список.

Так, например, это неправильно:

train_transforms = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]

Это должно выглядеть так

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...