Эффективно извлечь патч из изображения и этикетки - PullRequest
0 голосов
/ 25 апреля 2018

У меня есть проект сегментации.У меня есть изображения и ярлыки, на которых хранится правда о сегментацииИзображения большие и содержат много «пустых» областей.Я хочу вырезать патчи из изображения и метки, чтобы в патче была ненулевая метка.

Мне нужно, чтобы он был максимально эффективным .

Я написал следующий код, но он очень медленный.любые улучшения будут высоко оценены.

import numpy as np
import matplotlib.pyplot as plt
Позволяет создавать фиктивные данные
img = np.random.rand(300,200,3)
img[240:250,120:200]=0

mask = np.zeros((300,200))
mask[220:260,120:300]=0.7
mask[250:270,140:170]=0.3

f, axarr = plt.subplots(1,2, figsize = (10, 5))
axarr[0].imshow(img)
axarr[1].imshow(mask)[![enter image description here][1]][1]
plt.show()

given image and label

Мой неэффективный код:

IM_SIZE = 60     # Patch size

x_min, y_min = 0,0
x_max = img.shape[0] - IM_SIZE
y_max = img.shape[1] - IM_SIZE
xd, yd, x, y = 0,0,0,0

if (mask.max() > 0):
    xd, yd = np.where(mask>0)

    x_min = xd.min()
    y_min = yd.min()
    x_max = min(xd.max()- IM_SIZE-1, img.shape[0] - IM_SIZE-1)
    y_max = min(yd.max()- IM_SIZE-1, img.shape[1] - IM_SIZE-1)

    if (y_min >= y_max):

        y = y_max
        if (y + IM_SIZE >= img.shape[1] ): 
            print('Error')

    else:
        y = np.random.randint(y_min,y_max)

    if (x_min>=x_max):

        x = x_max
        if (x+IM_SIZE >= img.shape[0] ):
            print('Error')

    else:
        x = np.random.randint(x_min,x_max )
print(x,y)    
img = img[x:x+IM_SIZE, y:y+IM_SIZE,:]
mask = mask[x:x+IM_SIZE, y:y+IM_SIZE]

f, axarr = plt.subplots(1,2, figsize = (10, 5))
axarr[0].imshow(img)
axarr[1].imshow(mask)
plt.show()

enter image description here

1 Ответ

0 голосов
/ 26 апреля 2018

Снимок результата, предоставленного профилировщиком строк, выглядит следующим образом: enter image description here

Большую часть времени используется mask.max () (который можно изменить на np.max (маска) для некоторого ускорения) и np.where (маска> 0).

Если вам нужно каждый раз использовать функцию where для другой маски, взгляните на numberxpr .Или вы можете использовать joblib , чтобы сохранить результаты для x / y_min / max для данной маски, запустив много таких случаев параллельно.

Перестановка функции с помощью numba.jit дает мне лучшеРезультаты:

@jit
def temp(mask):
    xd, yd = np.where(mask>0)

    x_min = np.min(xd)
    y_min = np.min(yd)
    x_max = min(np.max(xd)- IM_SIZE-1, img.shape[0] - IM_SIZE-1)
    y_max = min(np.max(yd)- IM_SIZE-1, img.shape[1] - IM_SIZE-1)
    return x_min,x_max,y_min,y_max

def solver_new(img):
    IM_SIZE = 60     # Patch size

    x_min, y_min = 0,0
    x_max = img.shape[0] - IM_SIZE
    y_max = img.shape[1] - IM_SIZE
    xd, yd, x, y = 0,0,0,0

    if (np.max(mask) > 0):
        x_min,x_max,y_min,y_max = temp(mask)
        if (y_min >= y_max):

            y = y_max
            if (y + IM_SIZE >= img.shape[1] ): 
                print('Error')

        else:
            y = np.random.randint(y_min,y_max)

        if (x_min>=x_max):

            x = x_max
            if (x+IM_SIZE >= img.shape[0] ):
                print('Error')

        else:
            x = np.random.randint(x_min,x_max )
    return x,y

Поскольку размеры изображений и патчей невелики, результаты не слишком значимы, поскольку кэширование оказывает большое влияние на время.Я получаю примерно ~ 200us за реализацию, опубликованную в вопросе, и ~ 90us за ту, которая размещена здесь.

...