Пользовательский слой CoreML: нормализация по пикселям с металлическими шейдерами - PullRequest
0 голосов
/ 21 сентября 2018

Я конвертирую Nvidia Progressive Growing of GAN Generator в coreML.Мне удалось перенести все в coreML, за исключением слоя Pixelwise Normalization (Lambda), который я планирую реализовать в качестве пользовательского слоя coreML в Swift / Metal.

В TensorFlow.Keras я внедрил пиксельную норму как

def pixelwise_norm(a):
    return a / tf.sqrt(tf.reduce_mean(a * a, axis=3, keep_dims=True) + 1e-8)

Теперь я почти никогда не работал с шейдерами / металлом, но следуя инструкциям здесь: http://machinethink.net/blog/coreml-custom-layers/, У меня есть пользовательский слой, настроенный для использования металла для операций обратной связи.Я использую MTLComputePipelineState, который (вызывает? Кодирует?) Следующий шейдер для операций слоя:

#include <metal_stdlib>
using namespace metal;


kernel void pixelwise_norm(
              texture2d_array<half, access::read> inTexture [[texture(0)]],
              texture2d_array<half, access::write> outTexture [[texture(1)]],
              ushort3 gid [[thread_position_in_grid]])
{
    if (gid.x >= outTexture.get_width() ||
        gid.y >= outTexture.get_height()) {
        return;
    }

    const float4 x = float4(inTexture.read(gid.xy, gid.z));
    const float4 y = 0.0000001f + (x / sqrt(pow(x,2)));
    outTexture.write(half4(y), gid.xy, gid.z);
}

У меня проблемы с вычислением металлического эквивалента "redu_mean", сейчас этот шейдер реализует~ tenorflow ~ операция вроде

return a / tf.sqrt((a * a) + 1e-8) 

У кого-нибудь есть указатели?Спасибо

1 Ответ

0 голосов
/ 21 сентября 2018

Если я правильно читаю, для каждого пикселя в карте объектов это делит этот пиксель на норму L2 по каналам этого пикселя?

В этом случае вам нужно использовать цикл for длячитать каналы для этого пикселя, суммировать эти числа и делить на количество каналов.(Этот цикл нужно выполнять только в том случае, если количество каналов превышает 4.)

Также обратите внимание, что ваш 1e-8 должен быть внутри sqrt () или хотя бы в знаменателе.

...