Показать примеры дополненных изображений в PyTorch - PullRequest
1 голос
/ 15 марта 2019

Я хочу показать несколько образцов дополненных тренировочных образов.

Мое преобразование включает стандарт ImageNet transforms.Normalize, например:

train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])

Однако из-за Normalise изображения отображаются странными цветами.

В этом ответе говорится, что мне нужен доступ к исходному изображению, что затруднительно, когда преобразования применяются во время загрузки:

image_datasets['train'] = datasets.ImageFolder(train_dir, transform=train_transforms)

Как бы я мог показать несколько образцов увеличенных изображений в их обычных цветах, используя нормализованные для расчета?

Ответы [ 4 ]

2 голосов
/ 15 марта 2019

Я предлагаю два варианта:

  1. Создайте отдельный этап «преобразования», который отображает изображение и передает его дальше без изменений.Бесплатный бонус заключается в том, что вы можете вставить его на любом этапе в списке трансформации.
    import cv2
    import numpy as np
    def TransformShow(name="img", wait=100):
        def transform_show(img):
            cv2.imshow(name, np.array(img))
            cv2.waitKey(wait)
            return img
        return transform_show

Вставьте этот «трансформатор» перед ToTensor():

                                       transforms.RandomHorizontalFlip(),
                                       TransformShow("window_name", delay_in_ms),
                                       transforms.ToTensor(),

Используйте нольdelay_in_ms для ожидания нажатия клавиши.

Я использую OpenCV здесь для отображения изображений.Это также можно сделать только с помощью Pillow / PIL, но мне не понравилось, как он справляется с этим.

Отменить нормализацию и отобразить изображение.
def show_image(img, name="img", wait=100):
    mean = np.array([0.485, 0.456, 0.406])
    std =  np.array([0.229, 0.224, 0.225])
    cv2.imshow(name, img.cpu().numpy().transpose((1,2,0)) * std + mean)
    cv2.waitKey(wait)

, а затем вызвать его как

        show_image(data[0], "unaug", 1)
Последний подход может быть аппроксимирован быстрым двухслойным, но с несколько искаженными цветами:
    cv2.imshow("approx", data[0].cpu().numpy().transpose((1,2,0)) * 0.225 + 0.45)
    cv2.waitKey(10)
1 голос
/ 15 марта 2019

Я уже сталкивался с той же проблемой. Моим решением было создать другой torch.Dataset с дополнением данных, но без нормализации.

Здесь Я создаю Dataset. Здесь У меня есть класс, который реализует дополнения. У меня есть два члена: self.tf_augment и self.tf_transform. В первом случае применяется только увеличение данных, а во втором - увеличение данных и нормализация.

0 голосов
/ 18 марта 2019

Чтобы ответить на мой собственный вопрос, я предложил следующее:

Example output

# Undo transforms.Normalize
def denormalise(image):
    image = image.numpy().transpose(1, 2, 0)  # PIL images have channel last
    mean = [0.485, 0.456, 0.406]
    stdd = [0.229, 0.224, 0.225]
    image = (image * stdd + mean).clip(0, 1)
    return image


example_rows = 2
example_cols = 5

sampler = torch.utils.data.RandomSampler(image_datasets['train'],
                                         num_samples=example_rows * example_cols)

# Get a batch of images and labels  
images, indices = next(iter(sampler)) 

plt.rcParams['figure.dpi'] = 120  # Increase size of pyplot plots

# Show a grid of example images    
fig, axes = plt.subplots(example_rows, example_cols, figsize=(9, 5)) #  sharex=True, sharey=True)
axes = axes.flatten()
for ax, image, index in zip(axes, images, indices):
    ax.imshow(denormalise(image))
    ax.set_axis_off()
    ax.set_title(class_names[index], fontsize=7)

fig.subplots_adjust(wspace=0.02, hspace=0)
fig.suptitle('Augmented training set images', fontsize=20)
plt.show()

Это основано на коде учебного пособия по переводу PyTorch , ноотображает заголовок над каждым изображением и в целом выглядит намного лучше.

0 голосов
/ 16 марта 2019

Просто отмените операцию нормализации, то есть умножьте на стандартное отклонение и добавьте значение.

Пожалуйста, ознакомьтесь с методом imshow в уроках по pytorch: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#visualize-a-few-images

...