Вам просто нужно правильно переставить тензор партии: от [B, C, W, H]
до [B, C, W * H]
:
batch = batch.view(batch.size(0), batch.size(1), -1)
Вот полный пример использования случайных данных:
Код:
import torch
from torch.utils.data import TensorDataset, DataLoader
data = torch.randn(64, 3, 28, 28)
labels = torch.zeros(64, 1)
dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=8)
nimages = 0
mean = 0.
std = 0.
for batch, _ in loader:
# Rearrange batch to be the shape of [B, C, W * H]
batch = batch.view(batch.size(0), batch.size(1), -1)
# Update total number of images
nimages += batch.size(0)
# Compute mean and std here
mean += batch.mean(2).sum(0)
std += batch.std(2).sum(0)
# Final step
mean /= nimages
std /= nimages
print(mean)
print(std)
Выход:
tensor([-0.0029, -0.0022, -0.0036])
tensor([0.9942, 0.9939, 0.9923])