Как нормализовать изображения в PyTorch - PullRequest
1 голос
/ 02 марта 2020
transform = transforms.Compose([
    transforms.ToTensor()
])

trainset = torchvision.datasets.ImageFolder(root='C:/Users/beomseokpark/Desktop/CNN/train_data', transform = transform)
data_loader = DataLoader(dataset = trainset, batch_size = 8, shuffle = True, num_workers=2)

with torch.no_grad():
    for num, data in enumerate(trainset):
        imgs, label = data

Я загрузил изображения с помощью ImageFolder в библиотеку torchvision, и как я могу получить среднее и стандартное значение для каждого канала моих изображений?

Может кто-нибудь помочь мне?

Ответы [ 2 ]

1 голос
/ 02 марта 2020

Существует подход «ленивый человек»: вы можете просто подключить nn.BatchNorm2d в качестве самого первого слоя вашей сети. С соответствующими momentum и track_running_stats=True этот слой будет оценивать среднее значение и дисперсию ваших данных.

В качестве альтернативы, вы можете вычислить среднее значение и дисперсию, используя

mu = torch.zeros((3,), dtype=torch.float)
sig = torch.zeros((3,), dtype=torch.float)
n = 0
with torch.no_grad():
    for num, data in enumerate(trainset):
        imgs, _ = data 
        mu += torch.sum(imgs, dim=(0, 2, 3))
        sig += torch.sum(imgs**2, dim=(0, 2, 3))
        n += imgs.numel() // imgs.shape[0]
n  = float(n)
mu = mu / n  # mean
sig = sig / n - (mu ** 2)
1 голос
/ 02 марта 2020
import torch as t

batch_size = 8
imgs = t.empty(batch_size, 3, 128, 128).normal_()

t.nn.Flatten(start_dim=1)(imgs.permute(1, 0, 2, 3)).mean(dim=1)
t.nn.Flatten(start_dim=1)(imgs.permute(1, 0, 2, 3)).std(dim=1).shape

факел. Размер ([3])

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...