Отображение дополненного изображения с помощью Pytorch - PullRequest
1 голос
/ 13 марта 2019

Этот фрагмент в основном предоставлен @ptrblck на форуме Pytorch для увеличения данных на некоторых изображениях.

Задача - сегментация, поэтому я предполагаю, что изображение и соответствующая ему маска должны быть дополнены.

Мне интересно, как я могу отобразить некоторые изображения и соответствовать маске после преобразования, чтобы узнать, как они выглядят?

Вот сценарий:

import torch
from torch.utils.data.dataset import Dataset  # For custom data-sets
import torchvision.transforms as transforms
import torchvision.transforms.functional as tf
from PIL import Image
import numpy 
import glob
import matplotlib.pyplot as plt
from split_dataset import test_loader
import os

class CustomDataset(Dataset):
    def __init__(self, image_paths, target_paths, transform_images, transform_masks):   

    self.image_paths = image_paths
    self.target_paths = target_paths

    self.transform_images = transform_images
    self.transform_masks = transform_masks


    self.transformm = transforms.Compose([transforms.Lambda(lambda x: tf.rotate(x, 10)),
                                          transforms.Lambda(lambda x: tf.affine(x, angle=0,
                                      translate=(0, 0),
                                      scale=0.2,
                                      shear=0.2))
                                        ])

    self.transform = transforms.ToTensor()

    self.mapping = {
        0: 0,
        255: 1              
    }

def mask_to_class(self, mask):
    for k in self.mapping:
        mask[mask==k] = self.mapping[k]
    return mask

def __getitem__(self, index):

    image = Image.open(self.image_paths[index])
    mask = Image.open(self.target_paths[index])

    if any([img in self.image_paths[index] for img in self.transform_images]):
        print('applying special transformation')
        image = self.transformm(image) #augmentation

    if any([msk in self.target_paths[index] for msk in self.transform_masks]):
        print('applying special transformation')
        image = self.transformm(mask) #augmentation

    t_image = image.convert('L')
    t_image = self.transform(t_image) # transform to tensor for image
    mask = self.transform(mask) # transform to tensor for mask


    mask = torch.from_numpy(numpy.array(mask, dtype=numpy.uint8)) 
    mask = self.mask_to_class(mask)
    mask = mask.long()

    return t_image, mask, self.image_paths[index], self.target_paths[index] 

def __len__(self):  # return count of sample we have

    return len(self.image_paths)


image_paths = glob.glob("D:\\Neda\\Pytorch\\U-net\\my_data\\imagesResized\\*.png")
target_paths = glob.glob("D:\\Neda\\Pytorch\\U-net\\my_data\\labelsResized\\*.png")


transform_images = ['image_981.png', 'image_982.png','image_983.png', 'image_984.png', 'image_985.png',
                    'image_986.png','image_987.png','image_988.png','image_989.png','image_990.png',
                    'image_991.png']  # apply special transformation only on these images
print(transform_images)
#['image_991.png', 'image_991.png']

transform_masks = ['image_labeled_981.png', 'image_labeled_982.png','image_labeled_983.png', 'image_labeled_984.png',
                    'image_labeled_985.png', 'image_labeled_986.png','image_labeled_987.png','image_labeled_988.png',
                    'image_labeled_989.png','image_labeled_990.png',
                    'image_labeled_991.png'] 

dataset = CustomDataset(image_paths, target_paths, transform_images, transform_masks)

for transform_images in dataset:

    #print(transform_images)        
    transform_images = Image.open(os.path.join(image_paths, transform_images))
    transform_images = numpy.array(transform_images)

    transform_masks = Image.open(os.path.join(target_paths, transform_masks))
    transform_masks = numpy.array(transform_masks)


    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=1, sharex=True, sharey=True, figsize = (6,6))

    img1 = ax1.imshow(transform_images, cmap='gray')
    ax1.axis('off')   

    img2 = ax2.imshow(transform_masks)
    ax1.axis('off')        
    plt.show() 

в настоящее время это вызывает ошибку

путь = os.fspath (путь)
Ошибка типа: ожидаемый объект str, bytes или os.PathLike, не кортеж

1 Ответ

0 голосов
/ 13 марта 2019

glob.glob возвращает список имен путей, соответствующих вводу.Вы используете это, как будто это путь.Вы можете взять базовый путь и присоединить его к имени вашего изображения.Я бы также предложил не использовать имя переменной transform_images в цикле for.Я переименовал его в current_image и current_mask соответственно.

Вот пересмотренный код:

basePath = 'D:\\Neda\\Pytorch\\U-net\\my_data\\imagesResized\\'
image = Image.open(os.path.join(basePath, current_image))

[...]

targetPath = 'D:\\Neda\\Pytorch\\U-net\\my_data\\labelsResized\\'
mask = Image.open(os.path.join(targetPath, current_mask))
...