Раствор в O (m log n):
def geometric(n,b,m):
T=1
e=b%m
total = 0
while n>0:
if n&1==1:
total = (e*total + T)%m
T = ((e+1)*T)%m
e = (e*e)%m
n = n//2
return total
def efficient_solve(n, m):
ans = 0
for x in range(1, min(n, m) + 1):
k = pow(x, m, m)
s = pow(x, x, m)
times = (n // m) + (x <= n % m)
ans += s * geometric(times, k, m)
ans = ans % m
return ans
геометрический вычисляет геометрический ряд по модулю m, взятый из https://stackoverflow.com/a/42033401/3308055
Пояснение
N слишком велико, нам нужен способ для вычисления результатов нескольких сумм за одну операцию.
Обратите внимание, что x ^ x % m = (mk + i) ^ (mk + i) % m
при i
(mk + i) ^ (mk + i) % m = (mk + i) * (mk + i) * (mk + i) * ... * (mk + i)
(мк + i) раз
Если бы мы начали распространять это, почти все результаты имели бы как минимум 1 мк как фактор, а mk * whatever % m
будет равно 0.
Единственный результат без коэффициента mk будет i * i * i * i * ... * i
(mk + i) раз. То есть i^(mk + i)
.
Так что, если n = 5 и m = 3, вместо решения 1^1 + 2^2 + 3^3 + 4^4 + 5^5 % 3
мы можем решить 1 ^ (0 + 1) + 2 ^ (0 + 2) + 0 ^ (3 + 0) + 1 ^ (3 + 1) + 2 ^ (3 + 2) % m
.
Это хорошо, но нам все еще нужно выполнить O (n) операций. Давайте попробуем сгруппировать некоторые из этих сумм. Мы будем группировать на основе i % m
, у нас есть 3 группы:
1 ^ (0 + 1) + 1 ^ (3 + 1)
2 ^ (0 + 2) + 2 ^ (3 + 2)
0 ^ (3 + 0)
Как мы можем эффективно рассчитать результат каждой группы?
Обратите внимание, что для каждой группы у нас одна и та же база, и показатель степени увеличивается на m для каждой суммы. Если мы знаем результат первой суммы (1 ^ (0 + 1)
), как изменится следующая сумма (1 ^ (3 + 1)
) в отношении% m?
1 ^ (3 + 1) % m = (1 ^ 1 % m) * (1 ^ 3 % m) % m
. Если n было выше, и у нас было 1 ^ (6 + 1)
в этой группе, 1 ^ (6 + 1) % m = (1 ^ 1 % m) * (1 ^ 3 % m) * (1 ^ 3 % m) % m
. Обратите внимание, что для каждой следующей суммы в той же группе нам просто нужно добавить результат (1 ^ 3 % m)
. В целом, нам нужно добавить base ^ m % m
.
Сколько сумм у нас в каждой группе? Ну, у нас будет 1 для каждого n times = (n // m) + (x <= n % m).
Давайте назовем x
индекс группы, который также будет основой показателей. У нас будет min (n, m) групп
Давайте назовем k
результат x ^ m % m
. Давайте назовем s
результат x ^ x % m
.
Результат решения всех сумм для этой группы будет:
s + s * k + s * k^2 + s * k^3 ... + s * k^(times - 1)
Это эквивалентно:
s * (1 + k + k^2 + k^3 ... + k^(times - 1))
И у нас есть геометрический ряд, который мы можем эффективно вычислить. Благодаря этому у нас есть все необходимое для расчета ответа на задачу.