Если вы хотите сделать это без вычисления модульного обратного, вы можете вычислить его рекурсивно, используя:
1 + A + A 2 + A 3 + ... + A k
= 1 + (A + A 2 ) (1 + A 2 + (A 2 ) 2 + ... + (A 2 ) k / 2-1 )
Это для четных k. Для нечетного k:
1 + A + A 2 + A 3 + ... + A k
= (1 + A) (1 + A 2 + (A 2 ) 2 + ... + (A 2 ) (k-1) / 2 )
Поскольку k делится на 2 в каждом рекурсивном вызове, результирующий алгоритм имеет сложность O (log k). В java:
static int modSumAtoAk(int A, int k, int mod)
{
return (modSum1ToAk(A, k, mod) + mod-1) % mod;
}
static int modSum1ToAk(int A, int k, int mod)
{
long sum;
if (k < 5) {
//k is small -- just iterate
sum = 0;
long x = 1;
for (int i=0; i<=k; ++i) {
sum = (sum+x) % mod;
x = (x*A) % mod;
}
return (int)sum;
}
//k is big
int A2 = (int)( ((long)A)*A % mod );
if ((k%2)==0) {
// k even
sum = modSum1ToAk(A2, (k/2)-1, mod);
sum = (sum + sum*A) % mod;
sum = ((sum * A) + 1) % mod;
} else {
// k odd
sum = modSum1ToAk(A2, (k-1)/2, mod);
sum = (sum + sum*A) % mod;
}
return (int)sum;
}
Обратите внимание, что я очень тщательно следил за тем, чтобы каждый продукт выполнялся в 64-битном формате, и уменьшался на модуль после каждого.
Приложив немного математики, все вышеперечисленное можно преобразовать в итеративную версию, не требующую хранения:
static int modSumAtoAk(int A, int k, int mod)
{
// first, we calculate the sum of all 1... A^k
// we'll refer to that as SUM1 in comments below
long fac=1;
long add=0;
//INVARIANT: SUM1 = add + fac*(sum 1...A^k)
//this will remain true as we change k
while (k > 0) {
//above INVARIANT is true here, too
long newmul, newadd;
if ((k%2)==0) {
//k is even. sum 1...A^k = 1+A*(sum 1...A^(k-1))
newmul = A;
newadd = 1;
k-=1;
} else {
//k is odd.
newmul = A+1L;
newadd = 0;
A = (int)(((long)A) * A % mod);
k = (k-1)/2;
}
//SUM1 = add + fac * (newadd + newmul*(sum 1...Ak))
// = add+fac*newadd + fac*newmul*(sum 1...Ak)
add = (add+fac*newadd) % mod;
fac = (fac*newmul) % mod;
//INVARIANT is restored
}
// k == 0
long sum1 = fac + add;
return (int)((sum1 + mod -1) % mod);
}