Повторное умножение вызовет некоторую хитрую числовую нестабильность, так как результаты ваших умножений требуют все больше и больше битов для представления.Я предлагаю вам перевести это в лог-пространство и использовать суммирование, а не умножение:
import math
def get_perplexity(test_set, model):
log_perplexity = 0
n = 0
for word in test_set:
n += 1
log_perplexity -= math.log(get_prob(model, word))
log_perplexity /= float(n)
return math.exp(log_perplexity)
Таким образом, все ваши логарифмы могут быть представлены в стандартном количестве битов, и вы не получите числовые увеличения и потериточности.Кроме того, вы можете ввести произвольную степень точности, используя модуль decimal
:
import decimal
def get_perplexity(test_set, model):
with decimal.localcontext() as ctx:
ctx.prec = 100 # set as appropriate
log_perplexity = decimal.Decimal(0)
n = 0
for word in test_set:
n += 1
log_perplexity -= decimal.Decimal(get_prob(model, word))).ln()
log_perplexity /= float(n)
return log_perplexity.exp()