Восстановление изображений из патчей с помощью Python - PullRequest
1 голос
/ 04 июля 2019

Я изменяю код stuff_patches_3D, чтобы восстановить трехмерное изображение ребра из перекрывающихся патчей, но нахожу результаты не совсем правильными.

Сначала я использую следующий код для извлечения патчей:

def extract_patches(arr, patch_shape=24, extraction_step=8):
    # From: scikit-learn/sklearn/feature_extraction/image.py
    """Extracts patches of any n-dimensional array in place using strides.

    Parameters
    ----------
    arr : ndarray. n-dimensional array of which patches are to be extracted
    patch_shape : integer or tuple of length arr.ndim
    extraction_step : integer or tuple of length arr.ndim

    Returns
    -------
    patches : strided ndarray

    """
    print('Extract Patch...')
    arr_ndim = arr.ndim

    if isinstance(patch_shape, numbers.Number):
        patch_shape = tuple([patch_shape] * arr_ndim)
    if isinstance(extraction_step, numbers.Number):
        extraction_step = tuple([extraction_step] * arr_ndim)

    patch_strides = arr.strides

    slices = tuple(slice(None, None, st) for st in extraction_step)
    indexing_strides = arr[slices].strides

    patch_indices_shape = ((np.array(arr.shape) - np.array(patch_shape)) //
                           np.array(extraction_step)) + 1

    shape = tuple(list(patch_indices_shape) + list(patch_shape))
    strides = tuple(list(indexing_strides) + list(patch_strides))

    patches = as_strided(arr, shape=shape, strides=strides)

    return patches

Затем я использую модифицированные коды из поста Google для восстановления патчей:

def stuff_patches_3D(img_org, patches,xstep=12,ystep=12,zstep=12):
    out_shape = img_org.shape

    print('Recover image...')
    out = np.zeros(out_shape, patches.dtype)
    denom = np.zeros(out_shape, patches.dtype)
    patch_shape = patches.shape[-3:]

    patches_shape = ((out.shape[0]-patch_shape[0])//xstep+1, (out.shape[1]-patch_shape[1])//ystep+1,
                    (out.shape[2]-patch_shape[2])//zstep+1, patch_shape[0], patch_shape[1], patch_shape[2])
    patches_strides = (out.strides[0]*xstep, out.strides[1] * ystep,out.strides[2] * zstep, out.strides[0], out.strides[1],out.strides[2])

    patches_6D = np.lib.stride_tricks.as_strided(out, patches_shape, patches_strides)
    denom_6D = np.lib.stride_tricks.as_strided(denom, patches_shape, patches_strides)

    grid_inds = tuple(x.ravel() for x in np.indices(patches_6D.shape))
    np.add.at(patches_6D, grid_inds, patches.ravel())
    np.add.at(denom_6D, grid_inds, 1)

    # in case there are 0 elements in denom
    inds = denom != 0
    img_recover = np.zeros(out.shape, dtype=out.dtype)
    img_recover[inds] = out[inds]/denom[inds]

    return img_recover

Я использую plt.imshow () восстановленное изображение out_new, которое выглядит как исходное изображение. Но при использовании кодов inds = img_recover! = Img_org и num = np.sum (inds) для проверки восстановленного изображения я нахожу num намного больше 0. Фактически, img_org имеет форму 399 * 196 * 299, то есть 23382996 вокселей и num = 1317528 .

Я не могу найти причину. Любая помогает? Заранее спасибо!

...