Целочисленный объект не повторяется - PullRequest
0 голосов
/ 17 октября 2019

Я пытаюсь сгенерировать случайный список целых чисел и выполнить некоторое арифметическое вычисление (x, +, -) для случайно сгенерированного числа, но я продолжаю получать сообщение об ошибке.

from sympy.core.numbers import mod_inverse
from sympy.ntheory import sqrt_mod
from sympy.ntheory.residue_ntheory import nthroot_mod
import numpy as np


def findPrimitiveNthRoot(M, N):
    roots = nthroot_mod(1, N, M, True)[1:] 
    for root in roots: 
        is_primitive = True
        for k in range(1, N):
            if pow(root, k, M) == 1:
                is_primitive = False
        if is_primitive:
            return root
    return None


class NTT:
    def __init__(self, poly, M, N, ideal=True, ntt=False, w=None, phi=None):
        if ntt:  
            self.initial_w_ntt(poly, M, N, ideal, w, phi)
        else:  
            self.initial_wo_ntt(poly, M, N, ideal, w, phi)

    def initial_wo_ntt(self, poly, M, N, ideal, w, phi):
        self.mod = M 
        self.N = N      
        if w is None and phi is None:   
            self.w = findPrimitiveNthRoot(M, N)  
            self.phi = sqrt_mod(self.w, M)  
        else:
            self.w = w
            self.phi = phi
        self.ideal = ideal
        if ideal:
            poly_bar = self.mulPhi(poly)
        else:
            poly_bar = poly
        self.fft_poly = self.ntt(poly_bar)

    def initial_w_ntt(self, poly, M, N, ideal, w, phi):
        self.mod = M 
        self.N = N   
        self.w = w 
        self.phi = phi
        self.ideal = ideal
        self.fft_poly = poly

    def __name__(self):
        return "NTT"

    def __str__(self):
        poly = " ".join(str(coeff) for coeff in self.fft_poly)
        return "NTT points [" + poly + "] modulus " + str(self.mod)

    def __mul__(self, other):
        assert type(other).__name__ == 'NTT' or type(other).__name__ == 'int', 'type error'
        if type(other).__name__ == 'int':
            mul_result = self.mulConstant(other)
        else:
            assert self.N == other.N, "points different"
            assert self.mod == other.mod, "modulus different"
            assert self.ideal == other.ideal
            mul_result = []
            for i, point in enumerate(self.fft_poly):    
                mul_result.append((point * other.fft_poly[i]) % self.mod
        return NTT(mul_result, self.mod, self.N, self.ideal, True, self.w, self.phi)

    def __add__(self, other):
        assert self.N == other.N, 'points different'
        assert self.mod == other.mod, 'modulus different'
        assert self.ideal == other.ideal
        add_result = []
        for i, point in enumerate(self.fft_poly):
            add_result.append((point + other.fft_poly[i]) % self.mod)
        return NTT(add_result, self.mod, self.N, self.ideal, True, self.w, self.phi)

    def __sub__(self, other):
        assert self.N == other.N, 'points different'
        assert self.mod == other.mod, 'modulus different'
        assert self.ideal == other.ideal
        sub_result = []
        for i, point in enumerate(self.fft_poly):
            sub_result.append((point - other.fft_poly[i]) % self.mod)
        return NTT(sub_result, self.mod, self.N, self.ideal, True, self.w, self.phi)

    def bitReverse(self, num, lens):
        rev_num = 0
        for i in range(0, lens):
            if (num >> i) & 1:
                rev_num |= 1 << (lens - 1 - i)
        return rev_num

    def orderReverse(self, poly, N_bit):
        _poly = list(poly)
        for i, coeff in enumerate(_poly):
            rev_i = self.bitReverse(i, N_bit)
            if rev_i <= i:   
                coeff ^= _poly[rev_i]
                _poly[rev_i] ^= coeff
                coeff ^= _poly[rev_i]
                _poly[i] = coeff
        return _poly

    def ntt(self, poly, w=None):
        if w is None:
            w = self.w
        N_bit = self.N.bit_length() - 1
        rev_poly = self.orderReverse(poly, N_bit)
        for i in range(0, N_bit):
            points1, points2 = [], []
            for j in range(0, int(self.N // 2)):
                shift_bits = N_bit - 1 - i
                P = (i >> shift_bits) << shift_bits 
                w_P = pow(w, P, self.mod)   
                odd = rev_poly[2 * i + 1] * w_P  
                even = rev_poly[2 * i]   
                points1.append((even + odd) % self.mod)
                points2.append((even - odd) % self.mod)
                points = points1 + points2
            if i != N_bit:
                rev_poly = points
        return points

    def intt(self):
        inv_w = mod_inverse(self.w, self.mod)
        inv_N = mod_inverse(self.N, self.mod)
        poly = self.ntt(self.fft_poly, inv_w)
        for i in range(0, self.N):
            poly[i] = poly[i] * inv_N % self.mod
        if self.ideal:
            inv_phi = mod_inverse(self.phi, self.mod)
            poly = self.mulPhi(poly,inv_phi)
        return poly

    def mulPhi(self, poly, phi=None):      
        if phi is None:
            phi = self.phi
        poly_bar =  list(poly)
        for i, coeff in enumerate(poly):
            poly1 = (poly[i]) * pow(phi, i, self.mod)
            poly11 = ''.join(map(str, poly1))
            poly_bar[i] = int(self.mod) % int(poly11)  
        return poly_bar

    def mulConstant(self, constant):
        mul_result = []
        for coeff in self.fft_poly:
            result = coeff * constant % self.mod
            mul_result.append(result)
        return mul_result

Вычисление

pol = np.random.randint(0, 4, 10)
poly1 = np.reshape(pol, (2, 5))
poly = np.random.randint(1, 3, 10)
poly2 = np.reshape(poly, (2, 5))
modulus = 17
N = 4
fft_poly1 = NTT(poly1, modulus, N)
fft_poly2 = NTT(poly2, modulus, N)
print(type(fft_poly1).__name__)
mult_result = fft_poly1 * fft_poly2
add_result = fft_poly1 + fft_poly2
mult_const_result = fft_poly1 * 2
sub_result = fft_poly1 - fft_poly2
print('multiply:', mult_result.intt())
print('multiply 2:', mult_const_result.intt())
print('addition:', add_result.intt())
print('substraction:', sub_result.intt())    

Сообщение об ошибке

Traceback (most recent call last):
  File "/Users/mac/untitled/N.py", line 22, in <module>
    print('multiply:', mult_result.intt())
  File "/Users/mac/untitled/N.py", line 199, in intt
    poly = self.mulPhi(poly, inv_phi)
  File "/Users/mac/untitled/N.py", line 223, in mulPhi
    poly11 = ''.join(map(str, poly1))
TypeError: 'int' object is not iterable
...