Преобразования применяются при вызове метода __getitem__
для Dataset
. Например, посмотрите на метод __getitem__
класса набора данных MNIST
: https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
Метод __getitem__
вызывается, когда вы индексируете свой экземпляр MNIST
для обучающего набора, например ::
trainset[0]
Для получения дополнительной информации о __getitem__
: https://docs.python.org/3.6/reference/datamodel.html#object.getitem
Причина, по которой Resize
и RandomHorizontalFlip
должны быть до ToTensor
, заключается в том, что они действуют на PIL Images , и все наборы данных в Pytorch для согласованности загружают данные как PIL Image
с. На самом деле вы можете видеть, что здесь они вынуждают это поведение:
img = Image.fromarray(img.numpy(), mode='L')
Когда у вас есть PIL Image
соответствующего индекса, преобразования применяются с
if self.transform is not None:
img = self.transform(img)
ToTensor
преобразует PIL Image
в torch.Tensor
и Normalize
вычитает среднее значение и делит на стандартное отклонение, которое вы предоставляете.
В конце концов некоторые преобразования применяются к метке с
if self.target_transform is not None:
target = self.target_transform(target)
Наконец обработанное изображение и обработанная метка возвращаются. Все это происходит за один trainset[key]
звонок.
import torch
from torchvision.transforms import *
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
transform_train = Compose([Resize(28, interpolation=2),
RandomHorizontalFlip(p=0.5),
ToTensor(),
Normalize([0.], [1.])])
trainset = MNIST(root='./data', train=True, download=True,
transform=transform_train)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
print(trainset[0][0].size(), trainset[0][0].min(), trainset[0][0].max())
показывает
(torch.Size([1, 28, 28]), tensor(0.), tensor(1.))