pytorch 4d numpy массив с применением преобразований внутри пользовательского набора данных - PullRequest
0 голосов
/ 16 июня 2020

Внутри моего пользовательского набора данных я хочу применить transforms.Compose() к массиву NumPy.

Мои изображения находятся в формате массива NumPy с формой (num_samples, width, height, channels).

Как я могу применить следующие преобразования к полному массиву numpy?

img_transform = transforms.Compose([ transforms.Scale((224,224)), transforms.ToTensor(), transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32]) ])

Мои попытки заканчиваются несколькими ошибками, поскольку преобразования принимают изображение PIL, а не 4-d NumPy массив.

from torchvision import transforms
import numpy as np
import torch

img_transform = transforms.Compose([
        transforms.Scale((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
    ])

a = np.random.randint(0,256, (299,299,3))
print(a.shape)

img_transform(a)

1 Ответ

1 голос
/ 16 июня 2020

Все преобразования torchvision работают с отдельными изображениями, а не с группами изображений, поэтому массив 4D использовать нельзя.

Отдельные изображения, заданные как массивы NumPy, как в вашем примере кода, можно использовать путем преобразования их в образ PIL. Вы можете просто добавить transforms.ToPILImage в начало конвейера преобразования, так как он преобразует либо тензор, либо массив NumPy в изображение PIL.

img_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.46, 0.48, 0.51], [0.32, 0.32, 0.32])
    ])

Примечание: transforms.Scale устарел в пользу transforms.Resize.

В вашем примере вы использовали np.random.randint, который по умолчанию использует тип int64, но изображения должны быть uint8. Библиотеки, такие как OpenCV, возвращают массивы uint8 при загрузке изображения.

a = np.random.randint(0,256, (299,299,3), dtype=np.uint8)
...