Я пытаюсь сгенерировать случайный список целых чисел и выполнить некоторое арифметическое вычисление (x, +, -)
для случайно сгенерированного числа, но я продолжаю получать сообщение об ошибке.
from sympy.core.numbers import mod_inverse
from sympy.ntheory import sqrt_mod
from sympy.ntheory.residue_ntheory import nthroot_mod
import numpy as np
def findPrimitiveNthRoot(M, N):
roots = nthroot_mod(1, N, M, True)[1:]
for root in roots:
is_primitive = True
for k in range(1, N):
if pow(root, k, M) == 1:
is_primitive = False
if is_primitive:
return root
return None
class NTT:
def __init__(self, poly, M, N, ideal=True, ntt=False, w=None, phi=None):
if ntt:
self.initial_w_ntt(poly, M, N, ideal, w, phi)
else:
self.initial_wo_ntt(poly, M, N, ideal, w, phi)
def initial_wo_ntt(self, poly, M, N, ideal, w, phi):
self.mod = M
self.N = N
if w is None and phi is None:
self.w = findPrimitiveNthRoot(M, N)
self.phi = sqrt_mod(self.w, M)
else:
self.w = w
self.phi = phi
self.ideal = ideal
if ideal:
poly_bar = self.mulPhi(poly)
else:
poly_bar = poly
self.fft_poly = self.ntt(poly_bar)
def initial_w_ntt(self, poly, M, N, ideal, w, phi):
self.mod = M
self.N = N
self.w = w
self.phi = phi
self.ideal = ideal
self.fft_poly = poly
def __name__(self):
return "NTT"
def __str__(self):
poly = " ".join(str(coeff) for coeff in self.fft_poly)
return "NTT points [" + poly + "] modulus " + str(self.mod)
def __mul__(self, other):
assert type(other).__name__ == 'NTT' or type(other).__name__ == 'int', 'type error'
if type(other).__name__ == 'int':
mul_result = self.mulConstant(other)
else:
assert self.N == other.N, "points different"
assert self.mod == other.mod, "modulus different"
assert self.ideal == other.ideal
mul_result = []
for i, point in enumerate(self.fft_poly):
mul_result.append((point * other.fft_poly[i]) % self.mod
return NTT(mul_result, self.mod, self.N, self.ideal, True, self.w, self.phi)
def __add__(self, other):
assert self.N == other.N, 'points different'
assert self.mod == other.mod, 'modulus different'
assert self.ideal == other.ideal
add_result = []
for i, point in enumerate(self.fft_poly):
add_result.append((point + other.fft_poly[i]) % self.mod)
return NTT(add_result, self.mod, self.N, self.ideal, True, self.w, self.phi)
def __sub__(self, other):
assert self.N == other.N, 'points different'
assert self.mod == other.mod, 'modulus different'
assert self.ideal == other.ideal
sub_result = []
for i, point in enumerate(self.fft_poly):
sub_result.append((point - other.fft_poly[i]) % self.mod)
return NTT(sub_result, self.mod, self.N, self.ideal, True, self.w, self.phi)
def bitReverse(self, num, lens):
rev_num = 0
for i in range(0, lens):
if (num >> i) & 1:
rev_num |= 1 << (lens - 1 - i)
return rev_num
def orderReverse(self, poly, N_bit):
_poly = list(poly)
for i, coeff in enumerate(_poly):
rev_i = self.bitReverse(i, N_bit)
if rev_i <= i:
coeff ^= _poly[rev_i]
_poly[rev_i] ^= coeff
coeff ^= _poly[rev_i]
_poly[i] = coeff
return _poly
def ntt(self, poly, w=None):
if w is None:
w = self.w
N_bit = self.N.bit_length() - 1
rev_poly = self.orderReverse(poly, N_bit)
for i in range(0, N_bit):
points1, points2 = [], []
for j in range(0, int(self.N // 2)):
shift_bits = N_bit - 1 - i
P = (i >> shift_bits) << shift_bits
w_P = pow(w, P, self.mod)
odd = rev_poly[2 * i + 1] * w_P
even = rev_poly[2 * i]
points1.append((even + odd) % self.mod)
points2.append((even - odd) % self.mod)
points = points1 + points2
if i != N_bit:
rev_poly = points
return points
def intt(self):
inv_w = mod_inverse(self.w, self.mod)
inv_N = mod_inverse(self.N, self.mod)
poly = self.ntt(self.fft_poly, inv_w)
for i in range(0, self.N):
poly[i] = poly[i] * inv_N % self.mod
if self.ideal:
inv_phi = mod_inverse(self.phi, self.mod)
poly = self.mulPhi(poly,inv_phi)
return poly
def mulPhi(self, poly, phi=None):
if phi is None:
phi = self.phi
poly_bar = list(poly)
for i, coeff in enumerate(poly):
poly1 = (poly[i]) * pow(phi, i, self.mod)
poly11 = ''.join(map(str, poly1))
poly_bar[i] = int(self.mod) % int(poly11)
return poly_bar
def mulConstant(self, constant):
mul_result = []
for coeff in self.fft_poly:
result = coeff * constant % self.mod
mul_result.append(result)
return mul_result
Вычисление
pol = np.random.randint(0, 4, 10)
poly1 = np.reshape(pol, (2, 5))
poly = np.random.randint(1, 3, 10)
poly2 = np.reshape(poly, (2, 5))
modulus = 17
N = 4
fft_poly1 = NTT(poly1, modulus, N)
fft_poly2 = NTT(poly2, modulus, N)
print(type(fft_poly1).__name__)
mult_result = fft_poly1 * fft_poly2
add_result = fft_poly1 + fft_poly2
mult_const_result = fft_poly1 * 2
sub_result = fft_poly1 - fft_poly2
print('multiply:', mult_result.intt())
print('multiply 2:', mult_const_result.intt())
print('addition:', add_result.intt())
print('substraction:', sub_result.intt())
Сообщение об ошибке
Traceback (most recent call last):
File "/Users/mac/untitled/N.py", line 22, in <module>
print('multiply:', mult_result.intt())
File "/Users/mac/untitled/N.py", line 199, in intt
poly = self.mulPhi(poly, inv_phi)
File "/Users/mac/untitled/N.py", line 223, in mulPhi
poly11 = ''.join(map(str, poly1))
TypeError: 'int' object is not iterable