Как работают эти функции Python? Сегментация с помощью U-Net - PullRequest
0 голосов
/ 25 июня 2019

Я работаю над проектом и хочу лучше понять некоторый код из источника, который я обнаружил.Идея состоит в том, что я хочу сделать некоторую семантическую сегментацию, используя U-Net.Я понял почти все, что произошло в проекте, за исключением двух функций.

Первая функция связана с потреблением памяти (так сказал парень, который делал проект).Идея в том, что у меня есть сеть U-Net, и после того, как они выполняют последнюю операцию свертки, есть еще 2 операции, которые применяются к извилистому слою.После этих операций они применяют активацию и т. Д.

    conv6 = core.Reshape((2, patch_height * patch_width))(conv6)
    conv6 = core.Permute((2, 1))(conv6)

Хорошо, теперь, после этого, в основном учебном модуле, перед запуском функции model.fit, выходная модель преобразуется с помощью функции, котораяЯ говорил вам о том, что улучшается потребление памяти.Ниже у вас есть функция.

def function_unet_masks(masks):
    im_h = masks.shape[2]
    im_w = masks.shape[3]
    masks = np.reshape(masks, (masks.shape[0], im_h * im_w))
    new_masks = np.empty((masks.shape[0], im_h * im_w, 2))
    for i in range(masks.shape[0]):
        for j in range(im_h * im_w):
            if masks[i, j] == 0:
                new_masks[i, j, 0] = 1
                new_masks[i, j, 1] = 0
            else:
                new_masks[i, j, 0] = 0
                new_masks[i, j, 1] = 1
    return new_masks

Что означает код выше?Почему лучше при потреблении памяти?Я протестировал все без этой функции, и да, «потеря» тренировочной модели резко возрастает, очень большая, а также все занимает больше времени.

Теперь вторая проблема.Я делаю фазу обучения на основе патчей.Таким образом, я разделил тренировочный набор данных на маленькие патчи и изучаю все на основе патчей.После того, как у меня есть обученная модель, и я хочу сделать тест, предсказания также являются патчами.Таким образом, в последнем шаге мне нужно восстановить изображения на основе предсказанных исправлений.Проблема в том, что я не могу понять, почему они используют сумму и вероятность пикселей, чтобы вернуть окончательный массив с правильным порядком исправлений.Я понял, что патчи перекрываются, и это должно переделать изображение без наложений, или что-то в этом роде.Ниже у вас есть функция.

def reconstruct_overlapping_images(preds, img_h, img_w, stride_h, stride_w):
    assert (len(preds.shape) == 4)  # 4D arrays
    assert (preds.shape[1] == 1 or preds.shape[1] == 3)  # check the channel is 1 or 3
    patch_h = preds.shape[2]
    patch_w = preds.shape[3]
    N_patches_h = (img_h - patch_h) // stride_h + 1
    N_patches_w = (img_w - patch_w) // stride_w + 1
    N_patches_img = N_patches_h * N_patches_w
    print("N_patches_h: " + str(N_patches_h))
    print("N_patches_w: " + str(N_patches_w))
    print("N_patches_img: " + str(N_patches_img))
    assert (preds.shape[0] % N_patches_img == 0)
    N_full_imgs = preds.shape[0] // N_patches_img
    print("According to the dimension inserted, there are " + str(N_full_imgs) + " full images (of " + str(
        img_h) + "x" + str(img_w) + " each)")
    full_prob = np.zeros(
        (N_full_imgs, preds.shape[1], img_h, img_w))  # itialize to zero mega array with sum of Probabilities
    full_sum = np.zeros((N_full_imgs, preds.shape[1], img_h, img_w))

    k = 0  # iterator over all the patches
    for i in range(N_full_imgs):
        for h in range((img_h - patch_h) // stride_h + 1):
            for w in range((img_w - patch_w) // stride_w + 1):
                full_prob[i, :, h * stride_h:(h * stride_h) + patch_h, w * stride_w:(w * stride_w) + patch_w] += preds[
                    k]
                full_sum[i, :, h * stride_h:(h * stride_h) + patch_h, w * stride_w:(w * stride_w) + patch_w] += 1
                k += 1
    assert (k == preds.shape[0])
    assert (np.min(full_sum) >= 1.0)  # at least one
    final_avg = full_prob / full_sum
    print(final_avg.shape)
    assert (np.max(final_avg) <= 1.0)  # max value for a pixel is 1.0
    assert (np.min(final_avg) >= 0.0)  # min value for a pixel is 0.0
    return final_avg

Можете ли вы помочь мне, пожалуйста, в понимании этих функций и их использования?

Спасибо

...