import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use("TkAgg")
import os, h5py
import PIL
#------------------------------
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
#------------------------------
from data_augmentation import *
#------------------------------
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
class NiftiDataset(Dataset):
def __init__(self,transformation_params,data_path, mode='train',transforms=None ):
"""
Parameters:
data_path (string): Root directory of the preprocessed dataset.
mode (string, optional): Select the image_set to use, ``train``, ``valid``
transforms (callable, optional): Optional transform to be applied
on a sample.
"""
self.data_path = data_path
self.mode = mode
self.images = []
self.labels = []
self.W_maps = []
self.centers = []
self.radiuss = []
self.pixel_spacings = []
self.transformation_params = transformation_params
self.transforms = transforms
#-------------------------------------------------------------------------------------
if self.mode == 'train':
self.data_path = os.path.join(self.data_path,'train_set')
elif self.mode == 'valid':
self.data_path = os.path.join(self.data_path,'validation_set')
#-------------------------------------------------------------------------------------
for _, _, f in os.walk(self.data_path):
for file in f:
hdf_file = os.path.join(self.data_path,file)
data = h5py.File(hdf_file,'r') # Dictionary
# Preprocessing of Input Image and Label
patch_img, patch_gt, patch_wmap = PreProcessData(file, data, self.mode, self.transformation_params)
#print(type(data))
self.images.append(patch_img) # 2D image
#print('image shape is : ',patch_img.shape)
self.labels.append(patch_gt) # 2D label
#print('label shape is : ',patch_img.shape)
self.W_maps.append(patch_wmap) # Weight_Map
# self.centers.append(data['roi_center'][:]) # [x,y]
# self.radiuss.append(data['roi_radii'][:]) # [R_min,R_max]
# self.pixel_spacings.append(data['pixel_spacing'][:]) # [x , y , z]
def __len__(self):
return len(self.images)
def __getitem__(self, index):
image = self.images[index]
label = self.labels[index]
W_map = self.W_maps[index]
if self.transforms is not None:
image, label, W_maps = self.transforms(image, label, W_map)
return image, label, W_map
#=================================================================================================
if __name__ == '__main__':
# Test Routinue to check your threaded dataloader
# ACDC dataset has 4 labels
n_labels = 4
path = './hdf5_files'
batch_size = 1
# Data Augmentation Parameters
# Set patch extraction parameters
size1 = (128, 128)
patch_size = size1
mm_patch_size = size1
max_size = size1
train_transformation_params = {
'patch_size': patch_size,
'mm_patch_size': mm_patch_size,
'add_noise': ['gauss', 'none1', 'none2'],
'rotation_range': (-5, 5),
'translation_range_x': (-5, 5),
'translation_range_y': (-5, 5),
'zoom_range': (0.8, 1.2),
'do_flip': (False, False),
}
valid_transformation_params = {
'patch_size': patch_size,
'mm_patch_size': mm_patch_size}
transformation_params = { 'train': train_transformation_params,
'valid': valid_transformation_params,
'n_labels': 4,
'data_augmentation': True,
'full_image': False,
'data_deformation': False,
'data_crop_pad': max_size}
#====================================================================
dataset = NiftiDataset(transformation_params=transformation_params,data_path=path,mode='train')
dataloader = DataLoader(dataset=dataset,batch_size=2,shuffle=True,num_workers=0)
dataiter = iter(dataloader)
data = dataiter.next()
images, labels,W_map = data
#===============================================================================
# Data Visualization
#===============================================================================
print('image: ',images.shape,images.type(),'label: ',labels.shape,labels.type(),
'W_map: ',W_map.shape,W_map.type())
img = transforms.ToPILImage()(images[0,0,:,:,0].float())
lbl = transforms.ToPILImage()(labels[0,0,:,:].float())
W_mp = transforms.ToPILImage()(W_map [0,0,:,:].float())
plt.subplot(1,3,1)
plt.imshow(img,cmap='gray',interpolation=None)
plt.title('image')
plt.subplot(1,3,2)
plt.imshow(lbl,cmap='gray',interpolation=None)
plt.title('label')
plt.subplot(1,3,3)
plt.imshow(W_mp,cmap='gray',interpolation=None)
plt.title('Weight Map')
plt.show()
Я заметил несколько странных такие вещи, как типы Tensor, различны, даже если изображения, метки и карты весов являются изображениями одного типа и размера. Отслеживание ошибок:
Traceback (most recent call last):
File "D:\Saudi_CV\Vibot\Smester_2\2_Medical Image analysis\Project_2020\OUR_Project\data_loader.py", line 118, in <module>
data = dataiter.next()
File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 345, in __next__
data = self._next_data()
File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 385, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 47, in fetch
return self.collate_fn(data)
File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 79, in default_collate
return [default_collate(samples) for samples in transposed]
File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 79, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 64, in default_collate
return default_collate([torch.as_tensor(b) for b in batch])
File "F:\Download_2019\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py", line 55, in default_collate
return torch.stack(batch, 0, out=out)
RuntimeError: Expected object of scalar type Double but got scalar type Long for sequence element 1 in sequence argument at position #1 'tensors'
[Finished in 19.9s with exit code 1]