Сетка динамического программирования в Python - PullRequest
0 голосов
/ 24 сентября 2018

Я хочу минимизировать функцию стоимости по сетке в Python.У меня есть две переменные x и y, которые можно вычислить как

x[i+1,j+1], y[i+1,j+1] = f(x[i,j], x[i+1,j], x[i,j+1], foo[i,j], bar[i,j])

Другими словами, точка сетки (i + 1, j + 1) зависит от двух ядер foo и bar, и этососедние узлы (i, j + 1) (i + 1, j) и (i, j).Пример с игрушкой можно увидеть ниже

import numpy as np

N = 20
ivec = np.arange(N)
jvec = np.arange(N)

# Kernels
foo = np.sin(ivec[:,None] * jvec[None,:])
bar = np.cos(ivec[:,None] + jvec[None,:])

# We want to find the total cost for traversing over the matrix
d = np.zeros((N,N))

# And store the optimal path
indices = np.zeros((N,N), "int")

for i in range(N-1):
    for j in range(N-1):

        # Compute all posibilities for reaching current node
        dd = [
            d[i+1,j] + foo[i,j],
            d[i,j+1] + bar[i,j],
            d[i,j] + foo[i,j] * bar[i,j]
        ]

        # And find and store the minimim path
        indices[i+1,j+1] = np.argmin(dd)
        d[i+1,j+1] = dd[indices[i+1,j+1]]

print(d[-1,-1])

Однако это очень неэффективное решение.Тем более что N может быть сколь угодно большим.Поэтому мой вопрос: какой самый / более эффективный способ вычислить это?Использование итераторов (я пробовал np.nditer без особого успеха), или использование Numba, или есть какие-то хитрые хитрости, которые можно сделать в Numpy?Я начал изучать ufuncs и ufunc.accumulate с Numpy, но не могу сразу увидеть решение.

Обратите внимание, что в foo, bar и dd будет сложнее, чем в игрушкепример.

1 Ответ

0 голосов
/ 25 сентября 2018

Как вы упомянули, использование numba может быть самым простым способом сделать ваш код быстрее, когда N велико.

import numba
import numpy as np

@numba.jit(nopython=True)
def dp(foo, bar):
    N = foo.shape[0]

    # We want to find the total cost for traversing over the matrix
    d = np.zeros((N,N))

    # And store the optimal path
    indices = np.zeros((N,N), dtype=np.int64)

    for i in range(N-1):
        for j in range(N-1):

            # Compute all posibilities for reaching current node
            dd = np.array([
                d[i+1,j] + foo[i,j],
                d[i,j+1] + bar[i,j],
                d[i,j] + foo[i,j] * bar[i,j]
            ])

            # And find and store the minimim path
            indices[i+1,j+1] = np.argmin(dd)
            d[i+1,j+1] = dd[indices[i+1,j+1]]
    return d[-1,-1]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...