Алгоритм Карацубы слишком много рекурсии - PullRequest
6 голосов
/ 14 августа 2011

Я пытаюсь реализовать алгоритм умножения Карацубы в c ++, но сейчас я просто пытаюсь заставить его работать в python.

Вот мой код:

def mult(x, y, b, m):
    if max(x, y) < b:
        return x * y

    bm = pow(b, m)
    x0 = x / bm
    x1 = x % bm
    y0 = y / bm
    y1 = y % bm

    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0

Чего я не понимаю: как следует создавать z2, z1 и z0? Является ли использование функции mult рекурсивно правильным? Если так, я где-то напортачу, потому что рекурсия не останавливается.

Может кто-то указать, где ошибка?

Ответы [ 5 ]

5 голосов
/ 15 августа 2011

Примечание: приведенный ниже ответ непосредственно касается вопроса ОП о чрезмерной рекурсии, но он не пытается обеспечить правильный алгоритм Карацубы.Другие ответы гораздо более информативны в этом отношении.

Попробуйте эту версию:

def mult(x, y, b, m):
    bm = pow(b, m)

    if min(x, y) <= bm:
        return x * y

    # NOTE the following 4 lines
    x0 = x % bm
    x1 = x / bm
    y0 = y % bm
    y1 = y / bm

    z0 = mult(x0, y0, b, m)
    z2 = mult(x1, y1, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval)
    return retval

Самая серьезная проблема с вашей версией состоит в том, что ваши вычисления x0 и x1, ииз y0 и y1 перевернуты.Кроме того, вывод алгоритма не выполняется, если x1 и y1 равны 0, поскольку в этом случае шаг факторизации становится недействительным.Следовательно, вы должны избегать этой возможности, гарантируя, что x и y больше, чем b ** m.

EDIT: исправлена ​​опечатка в коде;добавлены пояснения

РЕДАКТИРОВАТЬ2:

Чтобы быть более ясным, комментируя непосредственно к вашей первоначальной версии:

def mult(x, y, b, m):
    # The termination condition will never be true when the recursive 
    # call is either
    #    mult(z2, bm ** 2, b, m)
    # or mult(z1, bm, b, m)
    #
    # Since every recursive call leads to one of the above, you have an
    # infinite recursion condition.
    if max(x, y) < b:
        return x * y

    bm = pow(b, m)

    # Even without the recursion problem, the next four lines are wrong
    x0 = x / bm  # RHS should be x % bm
    x1 = x % bm  # RHS should be x / bm
    y0 = y / bm  # RHS should be y % bm
    y1 = y % bm  # RHS should be y / bm

    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
4 голосов
/ 15 августа 2011

Целью умножения Карацубы является улучшение алгоритма умножения «разделяй и властвуй», совершая 3 рекурсивных вызова вместо четырех. Поэтому единственные строки в вашем скрипте, которые должны содержать рекурсивный вызов умножения, - это те, которые присваивают z0, z1 и z2. Все остальное ухудшит сложность. Вы не можете использовать pow для вычисления b m , если вы еще не определили умножение (и тем более возведение в степень), либо.

Для этого алгоритм критически использует тот факт, что он использует систему позиционной записи. Если у вас есть представление x числа в базе b , то x * b m просто получается смещением цифр это представление m раз влево. Эта операция переключения по существу «свободна» в любой системе позиционного обозначения. Это также означает, что если вы хотите реализовать это, вы должны воспроизвести эту позиционную нотацию и «свободный» сдвиг. Либо вы решили вычислять в base b = 2 и использовать битовые операторы python (или битовые операторы с заданным десятичным, шестнадцатеричным, ... base, если они есть на вашей тестовой платформе), или вы решите реализовать для образовательных целей что-то, что работает для произвольной b , и вы воспроизводите эту позиционную арифметику с чем-то вроде строк, массивов или списков .

У вас есть решение со списками . Мне нравится работать со строками в python, поскольку int(s, base) даст вам целое число, соответствующее строке s, представленной как представление чисел в базе base: это облегчает тестирование. Я выложил строго комментированную реализацию на основе строк в виде гисти здесь , включая примитивы от строки к номеру и числа к строке для хорошей меры.

Вы можете проверить это, предоставив дополненные строки с основанием и их (равной) длиной в качестве аргументов mult:

In [169]: mult("987654321","987654321",10,9)

Out[169]: '966551847789971041'

Если вы не хотите выяснять длину строк или считать длину строки, функция заполнения может сделать это за вас:

In [170]: padding("987654321","2")

Out[170]: ('987654321', '000000002', 9)

И, конечно, он работает с b>10:

In [171]: mult('987654321', '000000002', 16, 9)

Out[171]: '130eca8642'

(Проверьте с wolfram alpha )

4 голосов
/ 15 августа 2011

Обычно большие числа хранятся в виде массивов целых чисел. Каждое целое число представляет одну цифру. Этот подход позволяет умножить любое число на степень основания с простым смещением влево массива.

Вот моя реализация на основе списка (может содержать ошибки):

def normalize(l,b):
    over = 0
    for i,x in enumerate(l):
        over,l[i] = divmod(x+over,b)
    if over: l.append(over)
    return l
def sum_lists(x,y,b):
    l = min(len(x),len(y))
    res = map(operator.add,x[:l],y[:l])
    if len(x) > l: res.extend(x[l:])
    else: res.extend(y[l:])
    return normalize(res,b)
def sub_lists(x,y,b):
    res = map(operator.sub,x[:len(y)],y)
    res.extend(x[len(y):])
    return normalize(res,b)
def lshift(x,n):
    if len(x) > 1 or len(x) == 1 and x[0] != 0:
        return [0 for i in range(n)] + x
    else: return x
def mult_lists(x,y,b):
    if min(len(x),len(y)) == 0: return [0]
    m = max(len(x),len(y))
    if (m == 1): return normalize([x[0]*y[0]],b)
    else: m >>= 1
    x0,x1 = x[:m],x[m:]
    y0,y1 = y[:m],y[m:]
    z0 = mult_lists(x0,y0,b)
    z1 = mult_lists(x1,y1,b)
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b)
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m)
    t2 = lshift(z1,m*2)
    return sum_lists(sum_lists(z0,t1,b),t2,b)

sum_lists и sub_lists возвращает ненормированный результат - одна цифра может быть больше базового значения. normalize функция решила эту проблему.

Все функции ожидают получения списка цифр в обратном порядке. Например, 12 в базе 10 следует записать как [2,1]. Возьмем квадрат 9987654321.

» a = [1,2,3,4,5,6,7,8,9]
» res = mult_lists(a,a,10)
» res.reverse()
» res
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1]
1 голос
/ 14 августа 2011

Я полагаю, что идея этого метода заключается в том, что члены z i вычисляются с использованием рекурсивного алгоритма, но результаты не объединяются таким образом.Поскольку вы хотите получить чистый результат

z0 B^2m + z1 B^m + z2

Предполагая, что вы выбрали подходящее значение B (скажем, 2), вы можете вычислить B ^ m без каких-либо умножений.Например, при использовании B = 2 вы можете вычислить B ^ m, используя сдвиги битов, а не умножения.Это означает, что последний шаг может быть выполнен без каких-либо умножений.

Еще одна вещь - я заметил, что вы выбрали фиксированное значение m для всего алгоритма.Как правило, вы реализуете этот алгоритм так, чтобы значение m всегда было таким, чтобы B ^ m составляло половину числа цифр в x и y, когда они записаны в базе B. Если вы используете степени двух, это будет сделановыбрав m = ceil ((log x) / 2).

Надеюсь, это поможет!

0 голосов
/ 23 сентября 2018

В Python 2.7: сохраните этот файл как Karatsuba.py

   def karatsuba(x,y):
        """Karatsuba multiplication algorithm.
        Return the product of two numbers in an efficient manner
        @author Shashank
        date: 23-09-2018

        Parameters
        ----------
        x : int
            First Number 
        y : int
            Second Number   

        Returns
        -------
        prod : int
               The product of two numbers 

        Examples
        --------
        >>> import Karatsuba.karatsuba
        >>> a = 1234567899876543211234567899876543211234567899876543211234567890
        >>> b = 9876543211234567899876543211234567899876543211234567899876543210
        >>> Karatsuba.karatsuba(a,b)
        12193263210333790590595945731931108068998628253528425547401310676055479323014784354458161844612101832860844366209419311263526900
        """
        if len(str(x)) == 1 or len(str(y)) == 1:
            return x*y
        else:
            n = max(len(str(x)), len(str(y)))
            m = n/2

            a = x/10**m
            b = x%10**m
            c = y/10**m
            d = y%10**m

            ac = karatsuba(a,c)                             #step 1
            bd = karatsuba(b,d)                             #step 2
            ad_plus_bc = karatsuba(a+b, c+d) - ac - bd      #step 3
            prod = ac*10**(2*m) + bd + ad_plus_bc*10**m     #step 4
            return prod
...