Проблемы с ограничениями JIT JIT и Numpy - PullRequest
0 голосов
/ 06 марта 2020

Я недавно начал экспериментировать с интересной библиотекой python Jax , которая содержит усиленный Numpy, а также Automati c Дифференциатор. Я хотел создать грубый «дифференцируемый рендерер», написав функцию шейдера и потери в python, а затем используя AD Jax для нахождения градиента. Затем мы должны иметь возможность инвертировать рендеринг изображения, запустив градиентный спуск по этому градиенту потерь. Я сделал это довольно хорошо с простыми шейдерами, но у меня возникли проблемы при использовании логических выражений. Это код моего шейдера, который генерирует шаблон шахматной доски:

import jax.numpy as np

class CheckerShader:

    def __init__(self, scale: float, color1: np.ndarray, color2: np.ndarray):
        self.color1 = None
        self.color2 = None
        self.scale = None
        self.scale_min = 0
        self.scale_max = 20
        self.color1 = color1
        self.color2 = color2
        self.scale = scale * 20

    def checker(self, x: float, y: float) -> float:
        xi = np.abs(np.floor(x))
        yi = np.abs(np.floor(y))

        first_col = np.mod(xi, 2) == np.mod(yi, 2)
        return first_col

    def shade(self, x: float, y: float):
        x = x * self.scale
        y = y * self.scale

        first_col = self.checker(x, y)
        if first_col:
            return self.color1
        else:
            return self.color2

И это моя функция рендеринга, которая является первым местом, где JIT терпит неудачу:

import jax.numpy as np
import numpy as onp
import jax

def render(scale, c1, c2):
    img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
    sh = CheckerShader(scale, c1, c2)
    jit_func = jax.jit(sh.shade)

    for y in range(HEIGHT):
        for x in range(WIDTH):
            val = jit_func(x / WIDTH, y / HEIGHT)
            img[y, x, :] = val

    return img

Ошибка полученное сообщение:

TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

, и я думаю, это потому, что вы не можете запустить JIT для функции с логическими значениями, значения которых зависят от того, что не было решено во время компиляции. Но как я могу переписать его для работы с JIT? Без JIT это мучительно медленно.

Другой вопрос, который у меня есть, есть ли что-то, что я могу сделать, чтобы ускорить Numpy Джекса в целом? Рендеринг моего изображения (100x100 пикселей) с нормальным Numpy занимает несколько миллисекунд, но с Jax Numpy это занимает секунды! Спасибо: D

1 Ответ

1 голос
/ 09 марта 2020

Заменить

if first_col:
    return self.color1
else:
    return self.color2

на

return np.where(first_col, self.color1, self.color2)
...