Загрузка изображений FITS с помощью PyTorch - PullRequest
0 голосов
/ 08 мая 2018

Я пытаюсь создать CNN, используя PyTorch, но мои изображения нуждаются в импорте из формата FITS, а не в обычный .png или .jpeg и т. Д.

Есть ли способ сделать это легко с помощью torch.utils.data.DataLoader или есть место в исходном коде, где я могу вставить предложение, которое будет обрабатывать файлы FITS при загрузке?

Я посмотрел в документации, и самое подходящее, что я нашел, - это преобразователь ToPILImage, который преобразует тензор или ndarray в PIL-изображение.

В настоящее время я использую процедуру загрузки изображения следующим образом:

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision

batch_size = 4

transform = transforms.Compose(
                   [transforms.Resize((32,32)),
                    transforms.ToTensor(),
                    ])

trainset = dset.ImageFolder(root="Documents/Image_data",transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True)

Астропия: http://www.astropy.org/

Pytorch: https://pytorch.org/

torch.utils: https://pytorch.org/docs/master/data.html

ОБНОВЛЕНИЕ: Возможно, используя torchvision.datasets.DatasetFolder вместо DataLoader, вставка в мой собственный обработчик FITS будет работать?

При попытке использовать этот класс я получаю следующую ошибку:

AttributeError: module 'torchvision.datasets' has no attribute 'DatasetFolder'

Действительно ли DatasetFolder на данный момент поддерживается torchvision?

Ответы [ 3 ]

0 голосов
/ 09 мая 2018

После прочтения некоторой комбинации документов и кода, я не думаю, что вы обязательно захотите использовать ImageFolder, так как он ничего не знает о FITS.

Вместо этого вы должны попытаться использовать более общий класс DataSetFolder (который фактически является родительским классом ImageFolder). Вы бы передали ему список расширений, которые он должен обрабатывать (например, ['.fits'] и функцию «загрузчик», которая принимает файл FITS и, похоже, должна возвращать PIL.Image.

Вы можете даже создать свой собственный подкласс по примеру ImageFolder. Э.Г.

class FitsFolder(DatasetFolder):

    EXTENSIONS = ['.fits']

    def __init__(self, root, transform=None, target_transform=None,
                 loader=None):
        if loader is None:
            loader = self.__fits_loader

        super(FitsFolder, self).__init__(root, loader, self.EXTENSIONS,
                                         transform=transform,
                                         target_transform=target_transform)

    @staticmethod
    def __fits_loader(filename):
        data = fits.getdata(filename)
        return Image.fromarray(data)

Точные данные __fits_loader могут зависеть от данных ваших файлов FITS. В этом базовом примере просто используется высокоуровневая функция fits.getdata(), которая возвращает первый массив изображений в файле FITS (некоторые файлы FITS могут иметь много расширений со многими изображениями или иметь таблицы и т. Д.). Так что эта часть была бы за вами.

0 голосов
/ 03 июля 2018

Я столкнулся с той же проблемой, что и @ user8188120, несколько недель назад. Использование ответа @ Iguananaut прекрасно работает при чтении меток из структуры папок. Если кто-то сталкивается с этим и нуждается в чтении из CSV-файла, это также может работать:

labels = []
transform = transforms.Compose([
    # here go your transforms
    ])


class MyFitsDataset(data.Dataset):
    def __init__(self, csv_path):
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=None)
        # First column contains the image paths
        self.image_arr = np.asarray(self.data_info.iloc[:, 0])
        # the rest contain the labels
        self.label_arr = np.asarray(self.data_info.iloc[:, 1:])  # for multi-label
        self.label_arr = np.asarray(self.data_info.iloc[:, 1])  # for single-label
        labels.append(self.label_arr)
        self.data_len = len(self.data_info.index)

    def __getitem__(self, index):
        single_image_name = self.image_arr[index]

        data = pyfits.open(single_image_name, axes=2)
        data = data[0].data.astype('float32')
        data = data.reshape(IMG_WIDTH, IMG_HEIGHT, CHANNELS)

        img = transform(data)

        # Get label(class) of the image based on the pandas column
        single_image_label = self.label_arr[index]

        return (img, single_image_label)

    def __len__(self):
        return self.data_len

Это также позволяет избежать использования класса DatasetFolder, который по-прежнему недоступен в новейшей версии PyTorch. Надеюсь, это кому-нибудь поможет.

0 голосов
/ 08 мая 2018

Вы можете экспортировать изображение FITS в любой формат, поддерживаемый pyplot.imsave () , используя этот метод:

from astropy.io import fits
import matplotlib.pyplot as plt

image_data = fits.getdata(r"/path/to/image.fits")
plt.imsave("/path/to/image.png", image_data, cmap="gray")
...