Целочисленное переполнение в модульной функции возведения в степень - PullRequest
3 голосов
/ 03 января 2012

Я пишу powMod функцию, которую я должен использовать довольно интенсивно. Отправной точкой является пользовательская функция pow:

// Compute power using multiplication and square.
// pow (*) (^2) 1 x n = x^n
let pow mul sq one x n =   
    let rec loop x' n' acc =
       match n' with
       | 0 -> acc
       | _ -> let q = n'/2
              let r = n'%2
              let x2 = sq x'
              if r = 0 then
                 loop x2 q acc
              else
                 loop x2 q (mul x' acc)
    loop x n one

После проверки диапазона ввода я выбрал int64, потому что он достаточно большой, чтобы представлять вывод, и я могу избежать дорогостоящих вычислений с помощью bigint:

let mulMod m a b = (a*b)%m
let squareMod m a = mulMod m a a
let powMod m = pow (mulMod m) (squareMod m) 1L

Я предполагаю, что по модулю (m) больше, чем множители (a, b), и функции работают только с неотрицательными числами. Функция powMod подходит для большинства случаев; однако проблема заключается в функции mulMod, где a*b может превышать диапазон int64, но (a*b)%m не равен . Пример ниже демонстрирует проблему переполнения:

let a = (pown 2L 40) - 1L
let b = (pown 2L 32) - 1L
let p = powMod a b 2 // p = -8589934591L -- wrong

Есть ли способ избежать переполнения int64 без обращения к типу bigint?

Ответы [ 5 ]

2 голосов
/ 03 января 2012

Проблема, с которой вы столкнулись, заключается в том, что все ваши промежуточные вычисления неявно являются модом 2 64 , и, как правило, неверно, что

a · b mod m = (a· B mod 2 64 ) mod m

, это то, что вы рассчитываете.

Я не могу придумать простой способ сделать правильныйвычисление с использованием только 64-битных чисел, но вам не нужно переходить к bigints;если a и b имеют не более 64 бит, то их полный продукт имеет максимум 128 бит, поэтому вы можете отслеживать продукт в двух 64-битных целых числах (здесь связанных как пользовательская структура):

// bit width of a uint64, needed for mod calculation
let width = 
    let rec loop w = function
    | 0uL -> w
    | n -> loop (w+1) (n >>> 1)
    loop 0

[<Struct; CustomComparison; CustomEquality>]
type UInt128 =
    val hi : uint64
    val lo : uint64
    new (hi,lo) = { lo = lo; hi = hi }
    new (lo) = { lo = lo; hi = 0uL }
    static member (+)(x:UInt128, y:UInt128) =
        if x.lo > 0xffffffffuL - y.lo then
            UInt128(x.hi + y.hi + 1uL, x.lo + y.lo)
        else
            UInt128(x.hi + y.hi, x.lo + y.lo)
    static member (-)(x:UInt128, y:UInt128) =
        if y.lo > x.lo then
            UInt128(x.hi - y.hi - 1uL, x.lo - y.lo)
        else
            UInt128(x.hi - y.hi, x.lo - y.lo)

    static member ( * )(x:UInt128, y:UInt128) =
        let a1 = ((x.lo &&& 0xffffffffuL) * (y.lo &&& 0xffffffffuL)) >>> 32
        let a2 =  (x.lo &&& 0xffffffffuL) * (y.lo >>> 32)
        let a3 =  (x.lo >>> 32) * (y.lo &&& 0xffffffffuL)
        let sum = ((a1 + a2 + a3) >>> 32) + (x.lo >>> 32) * (y.lo >>> 32)
        let sum =
            if a2 > 0xffffffffffffffffuL - a1 || a1 + a2 > 0xffffffffffffffffuL - a3 then
                0x100000000uL + sum
            else
                sum
        UInt128(x.hi * y.lo + x.lo * y.hi + sum, x.lo * y.lo)

    static member (>>>)(x:UInt128, n) =
        UInt128(x.hi >>> n, x.lo >>> n)

    static member (<<<)(x:UInt128, n) =
        UInt128((x.hi <<< n) + (x.lo >>> (64 - n)), x.lo <<< n)

    interface System.IComparable with
        member x.CompareTo(y) =
            match y with
            | :? UInt128 as y ->
                match x.hi.CompareTo(y.hi) with
                | 0 -> x.lo.CompareTo(y.lo)
                | n -> n

    override x.Equals(y) = 
        match y with
        | :? UInt128 as y -> x.hi = y.hi && x.lo = y.lo
        | _ -> false

    override x.GetHashCode() = x.hi.GetHashCode() + x.lo.GetHashCode() * 7

    (* calculate mod via long-division *)
    static member (%)(x:UInt128, d) =
        let rec reduce (r:UInt128) d' =
            if r.hi = 0uL then r.lo % d
            else
                let r' = if r < d' then r else r - d'
                reduce r' (d' >>> 1)
        let shift = width x.hi + (64 - width d)
        reduce x (UInt128(0uL,d) <<< shift)

let mulMod m a b =
    UInt128(a) * UInt128(b) % m

(* squareMod, powMod basically as before: *)
let squareMod m a = mulMod m a a  
let powMod m = pow (mulMod m) (squareMod m) 1uL  

let a = (pown 2uL 40) - 1uL  
let b = (pown 2uL 32) - 1uL  
let p = powMod a b 2

Сказав, что, поскольку bigint s даст вам правильный ответ, почему бы просто не использовать bigint s, чтобы выполнить промежуточный расчет и конвертировать в long в конце (который гарантированно будет без потерьпреобразование с учетом диапазона м)?Я подозреваю, что снижение производительности при использовании bigints должно быть приемлемым для большинства приложений (по сравнению с головной болью поддержки ваших собственных математических процедур).

2 голосов
/ 03 января 2012

Я почти ничего не знаю о f #, однако, я думаю, вы могли бы применить тот факт, что:

Если b нечетно и n такое, что b = 2n + 1

a * b mod(m) = 2 * a * n + a mod(m)
             = 2 * (a*n mod(m)) + a mod(m)

и аналогично, если bдаже.Очевидно, вы можете повторять это столько раз, сколько необходимо для a или n, пока не получите продукт, который будет встраиваться в int64.Я думаю, что все еще возможно получить переполнение, если m> maxint64 / 2.

2 голосов
/ 03 января 2012

Согласно Википедии следующие формулы эквивалентны.Ваш код использует первый, переход на второй должен решить проблему переполнения.

c = (a x b) mod(m)
c = (a x (b mod(m))) mod(m) 

Надеюсь, это поможет.

На основе ваших комментариев ниже - если a <= m и b <= m и m> sqrt (maxint64), тогда я не уверен, что решение возможно без перехода к большему объему памяти.Для больших значений m b mod m вернет b, поэтому использование приведенной выше формулы эквивалентности бесполезно.

Хорошей новостью является то, что вы должны иметь возможность ограничить изменения одной строкой и повторно ввести значениевернемся к 64 битам [поскольку мы знаем, что (a * b)% c не должно переполняться], прежде чем продолжить вычисления.Это ограничивает затраты (с точки зрения производительности выполнения) до максимально возможной части кода.

0 голосов
/ 03 января 2012

Не ответ на ваш вопрос, но написание функции таким образом сделает ее более универсальной и удобной в использовании, а также, казалось бы, более эффективной:

let inline pow x n =
    let zero = LanguagePrimitives.GenericZero
    let rec loop x acc = function
        | n when n = zero -> acc
        | n ->
            let q = n >>> 1
            let acc = if n = (q <<< 1) then acc else x * acc
            loop (x * x) acc q
    loop x LanguagePrimitives.GenericOne n;;

for x = 0 to 1000000 do
    pow 3UL 31UL |> ignore

Кроме того, я полагаю, длинная победа без знакаэтого не достаточно?

Редактировать: следующий алгоритм в 3 раза быстрее, чем приведенный выше для большого bigint, поскольку выполняет меньшее умножение - может помочь вам выбрать в пользу bigint:

let inline pow2 x n =
    let zero = LanguagePrimitives.GenericZero
    let one = LanguagePrimitives.GenericOne
    let rec loop x data = function
        | c when c <<< 1 <= n ->
            let c = c <<< 1
            let x = x * x
            loop x (Map.add -c x data) c
        | c -> reduce x data (n - c)
    and reduce acc data = function
        | c when c = zero -> acc
        | c ->
            let next, value = data |> Seq.pick (fun (KeyValue (n, v)) -> if -n <= c then Some (-n, v) else None)
            reduce (acc * value) data (c - next)
    loop x (Map [-1, x]) one;;

for x = 10000 downto 9000 do
    pow2 7I x |> ignore
0 голосов
/ 03 января 2012

Я не очень хорошо знаю F #, но что-то вроде (псевдокод)

pmod a b n // a^b % n
pmod a 0 n = 1
pmod a 1 n = a%n
pmod a b n = match b%2
   | 0 -> ((pmod (a) (b/2) n) ^ 2) % n
   | 1 -> ((pmod (a) (b-1) n) * a ) % n

pmod еще не в порядке, но его следует использовать в качестве вспомогательной функции, поэтому

PowMod a b n = pmod (a%n) b n

Вы можете видеть, что этот результат будет неправильным, как только квадрат результата будет больше, чем uint64, поэтому n должен соответствовать uint32

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...