Нахождение среднего и стандартного отклонения по каналам изображения PyTorch - PullRequest
4 голосов
/ 06 февраля 2020

Скажем, у меня есть серия изображений в виде тензоров с размерами (B x C x W x H), где B - размер партии, C - количество каналов в изображении, а W и H - ширина и высота изображения соответственно. Я хочу использовать функцию transforms.Normalize() для нормализации моих изображений относительно среднего значения и стандартного отклонения набора данных по C каналам изображений , что означает, что я хочу получить результирующий тензор в форме 1 х C. Есть ли простой способ сделать это?

Я пробовал torch.view(C, -1).mean(1) и torch.view(C, -1).std(1), но я получаю ошибку:

view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Редактировать

После изучения того, как view() работает в PyTorch, я понимаю, почему мой подход не работает; однако я все еще не могу понять, как получить среднее значение для каждого канала и стандартное отклонение.

Ответы [ 2 ]

2 голосов
/ 06 февраля 2020

Вам просто нужно правильно переставить тензор партии: от [B, C, W, H] до [B, C, W * H]:

batch = batch.view(batch.size(0), batch.size(1), -1)

Вот полный пример использования случайных данных:

Код:

import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=8)

nimages = 0
mean = 0.
std = 0.
for batch, _ in loader:
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    mean += batch.mean(2).sum(0) 
    std += batch.std(2).sum(0)

# Final step
mean /= nimages
std /= nimages

print(mean)
print(std)

Выход:

tensor([-0.0029, -0.0022, -0.0036])
tensor([0.9942, 0.9939, 0.9923])
1 голос
/ 22 марта 2020

Обратите внимание, что отклонения добавляют, а не стандартные отклонения. Смотрите подробное объяснение здесь: https://apcentral.collegeboard.org/courses/ap-statistics/classroom-resources/why-variances-add-and-why-it-matters

Вот модифицированный код:

nimages = 0
mean = 0.0
var = 0.0
for i_batch, batch_target in enumerate(trainloader):
    batch = batch_target[0]
    # Rearrange batch to be the shape of [B, C, W * H]
    batch = batch.view(batch.size(0), batch.size(1), -1)
    # Update total number of images
    nimages += batch.size(0)
    # Compute mean and std here
    mean += batch.mean(2).sum(0) 
    var += batch.var(2).sum(0)

mean /= nimages
var /= nimages
std = torch.sqrt(var)

print(mean)
print(std)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...