Я написал класс, который случайно обрезает 3d-том. Требуется размер тома и crop_size.
Интересно, есть ли более эффективный и pythoni c способ переписать мой код
class RandomCrop3D():
def __init__(self, crop_size):
self.crop_sz = tuple(crop_size)
def __call__(self, x):
img_sz = x.shape
slice_hwd = [self._get_slice(i, k) for i, k in zip(img_sz, self.crop_sz)]
return self._crop(x, *slice_hwd)
@staticmethod
def _get_slice(sz, crop_sz):
try :
lower_bound = torch.randint(sz-crop_sz, (1,)).item()
return lower_bound, lower_bound + crop_sz
except:
return (None, None)
@staticmethod
def _crop(x, slice_h, slice_w, slice_d):
return x[:, slice_h[0]:slice_h[1], slice_w[0]:slice_w[1], slice_d[0]:slice_d[1]]
пример для проверки:
volume_3d = torch.rand(3, 100, 100, 100)
rand_crop = RandomCrop3D((64, 64, 64))
rand_crop(volume_3d)