Множественная точность векторизованного вычисления с плавающей точкой в ​​Python - PullRequest
0 голосов
/ 24 марта 2019

Я пытаюсь закодировать алгоритм Блахута-Аримото, чтобы оптимизировать искажение скорости для канала X-> Y.Алгоритм использует бета-множитель Лагранжа, который может варьироваться от очень маленького до бесконечности.Решение использует показатель степени бета, поэтому, если бета очень велика, она превышает точность с плавающей точкой в ​​numpy.

Я понимаю, что массивы numpy не совместимы с библиотеками с множественной точностью, поэтому мне нужно использовать библиотеку /обертка похожа на gmpy2.Однако это неэффективно, поэтому хотелось бы определить, когда это необходимо, вместо использования множественной точности по умолчанию.

Я смог использовать np.errstate для обнаружения недопустимого значения.Следующим шагом является использование gmpy2 для вычисления экспоненты.Параметры являются массивами.Есть ли способ рассчитать это параллельно?

Кроме того, я пытался использовать gmpy2, однако я не смог получить разумный ответ.Я просто не уверен, как это работает, и не смог найти документацию, чтобы прояснить ситуацию.

Я не настроен на конкретное решение, только выбрал gmpy2, так как он кажется самой современной оболочкой дляБиблиотека C mpfr.

def blahut_step(p_x, betaD):
    """
    Rate Distortion optimisation for communicating signal X -> Y
    using Lagrange Multiplier beta and distortion matrix D
    multi precision floating arithmetic may be required 
        p_x_next = p_k * exp[betaD].normalised

    BigFloat,gmpy2 - python wrapper for gnu mpfr library

    Params:
        p_x:    current estimate for distribution
        betaD:  beta * distortion matrix (nX, nY)

    Returns:
        next estimate of p_x
    """

    with np.errstate(invalid='raise'):
        try:
            result = np.multiply(p_x, np.exp(betaD))
            return result / np.sum(result, axis=1)[:, None] # normalise
        except FloatingPointError as err:
            # use multiple precision library such as gmpy2
            print("multiple precision required: {0}".format(err))
            exit(0)

Спасибо за любую помощь, даже если она просто указывает мне на направление матричных вычислений с многократной точностью, используя gmpy2.

...