Я недавно начал экспериментировать с интересной библиотекой python Jax , которая содержит усиленный Numpy, а также Automati c Дифференциатор. Я хотел создать грубый «дифференцируемый рендерер», написав функцию шейдера и потери в python, а затем используя AD Jax для нахождения градиента. Затем мы должны иметь возможность инвертировать рендеринг изображения, запустив градиентный спуск по этому градиенту потерь. Я сделал это довольно хорошо с простыми шейдерами, но у меня возникли проблемы при использовании логических выражений. Это код моего шейдера, который генерирует шаблон шахматной доски:
import jax.numpy as np
class CheckerShader:
def __init__(self, scale: float, color1: np.ndarray, color2: np.ndarray):
self.color1 = None
self.color2 = None
self.scale = None
self.scale_min = 0
self.scale_max = 20
self.color1 = color1
self.color2 = color2
self.scale = scale * 20
def checker(self, x: float, y: float) -> float:
xi = np.abs(np.floor(x))
yi = np.abs(np.floor(y))
first_col = np.mod(xi, 2) == np.mod(yi, 2)
return first_col
def shade(self, x: float, y: float):
x = x * self.scale
y = y * self.scale
first_col = self.checker(x, y)
if first_col:
return self.color1
else:
return self.color2
И это моя функция рендеринга, которая является первым местом, где JIT терпит неудачу:
import jax.numpy as np
import numpy as onp
import jax
def render(scale, c1, c2):
img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
sh = CheckerShader(scale, c1, c2)
jit_func = jax.jit(sh.shade)
for y in range(HEIGHT):
for x in range(WIDTH):
val = jit_func(x / WIDTH, y / HEIGHT)
img[y, x, :] = val
return img
Ошибка полученное сообщение:
TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
, и я думаю, это потому, что вы не можете запустить JIT для функции с логическими значениями, значения которых зависят от того, что не было решено во время компиляции. Но как я могу переписать его для работы с JIT? Без JIT это мучительно медленно.
Другой вопрос, который у меня есть, есть ли что-то, что я могу сделать, чтобы ускорить Numpy Джекса в целом? Рендеринг моего изображения (100x100 пикселей) с нормальным Numpy занимает несколько миллисекунд, но с Jax Numpy это занимает секунды! Спасибо: D