Реализация БПФ над конечными полями - PullRequest
0 голосов
/ 11 сентября 2018

Я хотел бы реализовать умножение полиномов с использованием NTT.Я следовал Теоретико-числовое преобразование (целочисленное ДПФ) и, похоже, оно работает.

Теперь я хотел бы реализовать умножение многочленов над конечными полями Z_p[x], где p - произвольное простое числочисло.

Изменяет ли что-нибудь, что коэффициенты теперь ограничены p, по сравнению с прежним неограниченным случаем?

В частности, оригинальное NTT требовало найти простое число N какрабочий модуль, который больше (magnitude of largest element of input vector)^2 * (length of input vector) + 1, так что результат никогда не переполняется.Если результат все равно будет ограничен этим p простым числом, насколько малым может быть модуль?Обратите внимание, что p - 1 не обязательно должен иметь форму (some positive integer) * (length of input vector).

Редактировать: я скопировал исходный код по ссылке выше, чтобы проиллюстрировать проблему:

# 
# Number-theoretic transform library (Python 2, 3)
# 
# Copyright (c) 2017 Project Nayuki
# All rights reserved. Contact Nayuki for licensing.
# https://www.nayuki.io/page/number-theoretic-transform-integer-dft
#

import itertools, numbers

def find_params_and_transform(invec, minmod):
    check_int(minmod)
    mod = find_modulus(len(invec), minmod)
    root = find_primitive_root(len(invec), mod - 1, mod)
    return (transform(invec, root, mod), root, mod)

def check_int(n):
    if not isinstance(n, numbers.Integral):
        raise TypeError()

def find_modulus(veclen, minimum):
    check_int(veclen)
    check_int(minimum)
    if veclen < 1 or minimum < 1:
        raise ValueError()
    start = (minimum - 1 + veclen - 1) // veclen
    for i in itertools.count(max(start, 1)):
        n = i * veclen + 1
        assert n >= minimum
        if is_prime(n):
            return n

def is_prime(n):
    check_int(n)
    if n <= 1:
        raise ValueError()
    return all((n % i != 0) for i in range(2, sqrt(n) + 1))

def sqrt(n):
    check_int(n)
    if n < 0:
        raise ValueError()
    i = 1
    while i * i <= n:
        i *= 2
    result = 0
    while i > 0:
        if (result + i)**2 <= n:
            result += i
        i //= 2
    return result

def find_primitive_root(degree, totient, mod):
    check_int(degree)
    check_int(totient)
    check_int(mod)
    if not (1 <= degree <= totient < mod):
        raise ValueError()
    if totient % degree != 0:
        raise ValueError()
    gen = find_generator(totient, mod)
    root = pow(gen, totient // degree, mod)
    assert 0 <= root < mod
    return root

def find_generator(totient, mod):
    check_int(totient)
    check_int(mod)
    if not (1 <= totient < mod):
        raise ValueError()
    for i in range(1, mod):
        if is_generator(i, totient, mod):
            return i
    raise ValueError("No generator exists")

def is_generator(val, totient, mod):
    check_int(val)
    check_int(totient)
    check_int(mod)
    if not (0 <= val < mod):
        raise ValueError()
    if not (1 <= totient < mod):
        raise ValueError()
    pf = unique_prime_factors(totient)
    return pow(val, totient, mod) == 1 and all((pow(val, totient // p, mod) != 1) for p in pf)

def unique_prime_factors(n):
    check_int(n)
    if n < 1:
        raise ValueError()
    result = []
    i = 2
    end = sqrt(n)
    while i <= end:
        if n % i == 0:
            n //= i
            result.append(i)
            while n % i == 0:
                n //= i
            end = sqrt(n)
        i += 1
    if n > 1:
        result.append(n)
    return result

def transform(invec, root, mod):
    check_int(root)
    check_int(mod)
    if len(invec) >= mod:
        raise ValueError()
    if not all((0 <= val < mod) for val in invec):
        raise ValueError()
    if not (1 <= root < mod):
        raise ValueError()

    outvec = []
    for i in range(len(invec)):
        temp = 0
        for (j, val) in enumerate(invec):
            temp += val * pow(root, i * j, mod)
            temp %= mod
        outvec.append(temp)
    return outvec

def inverse_transform(invec, root, mod):
    outvec = transform(invec, reciprocal(root, mod), mod)
    scaler = reciprocal(len(invec), mod)
    return [(val * scaler % mod) for val in outvec]

def reciprocal(n, mod):
    check_int(n)
    check_int(mod)
    if not (0 <= n < mod):
        raise ValueError()
    x, y = mod, n
    a, b = 0, 1
    while y != 0:
        a, b = b, a - x // y * b
        x, y = y, x % y
    if x == 1:
        return a % mod
    else:
        raise ValueError("Reciprocal does not exist")

def circular_convolve(vec0, vec1):
    if not (0 < len(vec0) == len(vec1)):
        raise ValueError()
    if any((val < 0) for val in itertools.chain(vec0, vec1)):
        raise ValueError()
    maxval = max(val for val in itertools.chain(vec0, vec1))
    minmod = maxval**2 * len(vec0) + 1
    temp0, root, mod = find_params_and_transform(vec0, minmod)
    temp1 = transform(vec1, root, mod)
    temp2 = [(x * y % mod) for (x, y) in zip(temp0, temp1)]
    return inverse_transform(temp2, root, mod)

vec0 = [24, 12, 28, 8, 0, 0, 0, 0]
vec1 = [4, 26, 29, 23, 0, 0, 0, 0]

print(circular_convolve(vec0, vec1))

def modulo(vec, prime):
    return [x % prime for x in vec]

print(modulo(circular_convolve(vec0, vec1), 31))

Печать:

[96, 672, 1120, 1660, 1296, 876, 184, 0]
[3, 21, 4, 17, 25, 8, 29, 0]

Однако, когда я изменяю minmod = maxval**2 * len(vec0) + 1 на minmod = maxval + 1, он перестает работать:

[14, 16, 13, 20, 25, 15, 20, 0]
[14, 16, 13, 20, 25, 15, 20, 0]

Какой самый маленький minmod (N в ссылке выше) быть для того, чтобы работать как положено?

1 Ответ

0 голосов
/ 15 сентября 2018

Если ваш ввод n целых чисел связан с каким-то простым q (любое mod q не просто простое будет одинаковым) Вы можете использовать его как max value +1, но будьте осторожны, вы не можете использовать его как штрих p для NTT , потому что NTT штрих p обладает особыми свойствами. Все они здесь:

поэтому наше максимальное значение для каждого входа составляет q-1, но во время вычислений вашей задачи (Свертка по результатам 2 NTT ) величина результатов первого уровня может возрасти до n.(q-1), но, как мы делаем в свертке на них входная величина конечного iNTT возрастет до:

m = n.((q-1)^2)

Если вы выполняете операции на NTT с, отличные от m, уравнение может измениться.

Теперь давайте вернемся к p, поэтому в двух словах вы можете использовать любое простое число p, которое поддерживает это:

p mod n == 1
p > m

и существует 1 <= r,L < p такое, что:

p mod (L-1) = 0
r^(L*i) mod p == 1 // i = { 0,n }
r^(L*i) mod p != 1 // i = { 1,2,3, ... n-1 }

Если все это выполнено, тогда p является n-м корнем единицы и может использоваться для NTT . Чтобы найти такое простое число, а также r,L посмотрите на ссылку выше (есть код C ++, который находит такое).

Например, во время умножения строки мы берем 2 строки, делаем NTT , затем сворачиваем результат и iNTT возвращаем результат (который является суммой обоих входных размеров). Так, например:

                                99999999999999999999999999999999
                               *99999999999999999999999999999999
----------------------------------------------------------------
9999999999999999999999999999999800000000000000000000000000000001

q = 10 и оба операнда 9 ^ 32, поэтому n=32, следовательно, m = 9*9*32 = 2592 и найденное простое число p = 2689. Как вы видите результат совпадения, переполнение не происходит. Однако, если я использую любое меньшее простое число, которое все еще соответствует всем другим условиям, результат не будет совпадать. Я использовал это специально, чтобы максимально растянуть значения NTT (все значения равны q-1, а размеры равны той же степени 2)

В случае, если ваш NTT быстрый и n не является степенью 2, вам нужно обнулить блок до ближайшей более высокой или равной степени 2 для каждого NTT . Но это не должно влиять на значение m, поскольку нулевое значение не должно увеличивать величину значений. Мое тестирование доказывает, что для свертки вы можете использовать:

m = (n1+n2).((q-1)^2)/2

, где n1,n2 - исходные размеры входных данных перед нулевой величиной.

Для получения дополнительной информации о реализации NTT вы можете проверить мой в C ++ (значительно оптимизирован):

Итак, чтобы ответить на ваши вопросы:

  1. да, вы можете воспользоваться тем, что вводом является mod q, но вы не можете использовать q как p !!!

  2. Вы можете использовать minmod = n * (maxval + 1) только для одного NTT (или первого уровня NTT), но, поскольку вы объединяете их в цепочку во время использования NTT, вы не можете использовать это для заключительного этапа INTT !! !

Однако, как я уже упоминал в комментариях, проще всего использовать максимально возможную p, которая соответствует типу данных, который вы используете, и может использоваться для всех поддерживаемых входных мощностей двух размеров .

Что в основном делает ваш вопрос неуместным. Единственный случай, который я могу придумать, где это невозможно / желательно, - это числа с произвольной точностью, где нет максимального предела «нет». Существует много проблем с производительностью, связанных с переменной p, так как поиск p действительно медленный (может быть даже медленнее, чем сам NTT ), а также переменная p отключает многие оптимизации производительности модульного арифметика должна была сделать NTT очень медленным.

...