Эффективный способ обрезать объем 3d в pytorch - PullRequest
0 голосов
/ 25 апреля 2020

Я написал класс, который случайно обрезает 3d-том. Требуется размер тома и crop_size.

Интересно, есть ли более эффективный и pythoni c способ переписать мой код

class RandomCrop3D():
    def __init__(self, crop_size):
        self.crop_sz = tuple(crop_size)

    def __call__(self, x):
        img_sz  = x.shape
        slice_hwd = [self._get_slice(i, k) for i, k in zip(img_sz, self.crop_sz)]
        return self._crop(x, *slice_hwd)

    @staticmethod
    def _get_slice(sz, crop_sz):
        try : 
            lower_bound = torch.randint(sz-crop_sz, (1,)).item()
            return lower_bound, lower_bound + crop_sz
        except: 
            return (None, None)

    @staticmethod
    def _crop(x, slice_h, slice_w, slice_d):
        return x[:, slice_h[0]:slice_h[1], slice_w[0]:slice_w[1], slice_d[0]:slice_d[1]]

пример для проверки:

volume_3d = torch.rand(3, 100, 100, 100)
rand_crop = RandomCrop3D((64, 64, 64))

rand_crop(volume_3d)
...