Функция, загружающая обучение, проверку и набор тестов, возвращает ошибку имени. - PullRequest
0 голосов
/ 08 декабря 2018

Вот моя функция:

from torchvision import *
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler

def get_train_valid_test_loader(data_dir,
                       batch_size,
                       random_seed,
                       augment=True,
                       train_size = 0.6,
                       valid_size=0.2,
                       test_size = 0.2,
                       shuffle=True,
                       show_sample=False,
                       num_workers=0,
                       pin_memory=True):

#     error_msg = "[!] valid_size should be in the range [0, 1]."
#     assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

normalize = torchvision.transforms.Normalize(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2023, 0.1994, 0.2010],
)


# load the dataset
dataset_loader = torch.utils.data.DataLoader(data_dir)
trainset = dataset_loader

#slice the dataset into train, validation, and test partitions
num_train = len(dataset_loader)
indices = list(range(num_train))
train_split = int(np.floor(train_size * num_train))
validation_split = int(np.floor(valid_size * num_train)) + 1
test_split = int(np.floor(test_size * num_train)) + 1



if shuffle:
    np.random.seed(random_seed)
    np.random.shuffle(indices)

train_idx = indices[:train_split]
valid_idx = indices[train_split:validation_split]
test_idx =  indices[validation_split:test_split]

train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)


#LOAD ALL THE SETS 

#TRAIN SET

#load the trainset

trainset = torch.utils.data.DataLoader(
    dataset_loader, batch_size=batch_size, sampler=train_sampler,
    num_workers=num_workers, pin_memory=pin_memory,
)

#Train: Data Augmentation, Resizing, and Normalization 

if augment:
    trainset = transforms.Compose([
        torchvision.transforms.Resize(224), #resize all the pictures 
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomRotation(45, resample=False, 
                                              expand=False, center=None),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomRotation(30, resample=False, 
                                              expand=False, center=None),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomGrayscale(p=0.1),
        torchvision.transforms.RandomVerticalFlip(p=1),
        torchvision.transforms.RandomHorizontalFlip(p=1),
        torchvision.transforms.RandomRotation(90, resample=False, 
                                              expand=False, center=None),
        torchvision.transforms.ToTensor(),
        normalize,
    ])
else:
    trainset = torchvision.transforms.Compose([
        torchvision.transforms.Resize(224), #resize all the pictures 
        torchvision.transforms.ToTensor(),
        normalize
    ])


#VALIDATION SET

#load the validation set 

validset = torch.utils.data.DataLoader(
    dataset_loader, batch_size=batch_size, sampler=valid_sampler,
    num_workers=num_workers, pin_memory=pin_memory
)


#Validation: Data Augmentation, Resizing, and Normalization   

validset = torchvision.transforms.Compose([
        torchvision.transforms.Resize(224), #resize all the pictures 
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomRotation(45, resample=False, 
                                              expand=False, center=None),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomRotation(30, resample=False, 
                                              expand=False, center=None),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomGrayscale(p=0.1),
        torchvision.transforms.RandomVerticalFlip(p=1),
        torchvision.transforms.RandomHorizontalFlip(p=1),
        torchvision.transforms.RandomRotation(90, resample=False, 
                                              expand=False, center=None),
        torchvision.transforms.ToTensor(),
        normalize,
])


#TEST SET

#load the test set

testset = torch.utils.data.DataLoader(
    dataset_loader, batch_size=batch_size, sampler=test_sampler, 
    num_workers=num_workers, pin_memory=pin_memory,
)

#Test: Data Augmentation, Resizing, and Normalization 

testset = torchvision.transforms.Compose([
        torchvision.transforms.Resize(224), #resize all the pictures 
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomRotation(45, resample=False, 
                                              expand=False, center=None),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomRotation(30, resample=False, 
                                              expand=False, center=None),
        torchvision.transforms.RandomCrop(224,pad_if_needed=True),
        torchvision.transforms.RandomGrayscale(p=0.1),
        torchvision.transforms.RandomVerticalFlip(p=1),
        torchvision.transforms.RandomHorizontalFlip(p=1),
        torchvision.transforms.RandomRotation(90, resample=False, 
                                              expand=False, center=None),
        torchvision.transforms.ToTensor(),
        normalize,
])

return trainset, validset, testset

Здесь я вызываю функцию:

 if __name__ == "__main__":

    random_seed = random.seed(1000)
    the_dir = '/content/drive/My Drive/Deep Learning/data'
    get_train_valid_test_loader(the_dir,32, random_seed, True)

    print(len(trainset))
    print(len(validset))
    print(len(testset))


    main()

Я понятия не имею, почему я получаю следующую ошибку:

 NameError                         Traceback (most recent call   last)
 <ipython-input-36-8b4c3d896768> in <module>()
  6   get_train_valid_test_loader(the_dir,32, random_seed, True)
  7 
  8   print(len(trainset)). <----------- Error
  9   print(len(validset))
 10   print(len(testset))

 NameError: name 'trainset' is not defined

Я думал, что использовал "trainset", прежде чем даже определить его, но, похоже, это не так.Я пытался присвоить "dataset_loader" для "transet", но все еще не дал результатов.Я также пытался загрузить "trainset", используя ту же самую строку, которую я использовал для моего dataset_loader.

1 Ответ

0 голосов
/ 09 декабря 2018

Вызов функции возвращает значения в return функции, но это не имеет ничего общего с именами этих значений, как определено в области действия функции.Вот гораздо более простой случай, чтобы воспроизвести вашу ошибку:

def SimpleExample(a):
    return a

if __name__ == '__main__':
    SimpleExample(1)
    print(a)

И, конечно, ожидаемая ошибка:

NameError: name 'a' is not defined

Process finished with exit code 1

Итак, если вы хотите захватить значение, возвращаемое функцией, вынеобходимо сделать это явно, и затем вы можете вернуть эти значения в любое имя переменной:

def SimpleExample (a): вернуть

if __name__ == '__main__':
    a = SimpleExample(1)
    print(a)
    b = SimpleExample(2)
    print(b)
    x = SimpleExample([0, 0, 0])
    print(x)

И тогда все будет хорошо:

1
2
[0, 0, 0]

Process finished with exit code 0

Возвращаясь к вашему делу, попробуйте простое:

if __name__ == "__main__":

    random_seed = random.seed(1000)
    the_dir = '/content/drive/My Drive/Deep Learning/data'
    trainset, validset, testset = get_train_valid_test_loader(the_dir,32, random_seed, True)

    print(len(trainset))
    print(len(validset))
    print(len(testset))

Удачи!

...