Найти среднее значение из перестановок в R - PullRequest
3 голосов
/ 10 апреля 2020

В сумке есть шарики со значениями от 1 до 3. Я нарисую все три мяча без замены наугад. Для первого шара мне нужно заплатить значение шара, умноженное на 1. Для второго шара мне нужно заплатить значение шара, умноженное на 2. Для третьего шара мне нужно заплатить значение шара умножить на 3. Например, если вы нарисовали 1,2,3, то ваш общий платеж составит (1 * 1) + (2 * 2) + (3 * 3) = 14. Я хочу найти среднее из всех возможных Всего платежей.

Итак, у меня был этот код:

library(gtools)

N<-1:3
perms3 <- data.frame(permutations(n = 3, r = 3, v = N))
perms3$total_payment <- perms3$X1 *1+ perms3$X2*2 + perms3$X3*3 
mean(perms3$total_payment)

Я хотел бы сделать общую функцию, которую я могу применить к любому числу N. Например, есть шары со значениями от 1 до 5 или от 1 до 10 и так далее. Я мог бы использовать приведенный выше код с небольшими манипуляциями для вычисления среднего значения общего платежа, например:

N<-1:5
perms5 <- data.frame(permutations(n = 5, r = 5, v = N))
perms5$total_payment <- perms5$X1 *1+ perms5$X2*2 +perms5$X3*3 +perms5$X4*4 +perms5$X5*5
mean(perms5$total_payment) 

Но я не хочу делать это каждый раз. Можете ли вы помочь мне решить эту проблему?

Ответы [ 5 ]

6 голосов
/ 10 апреля 2020

Это можно перевести в решение с постоянным временем, используя немного математики. Короче говоря, мы просто находим Ожидаемое значение .

TL; DR

sum(1:n) * (n + 1) / 2

, равное:

(n * (n + 1) / 2) * (n + 1) / 2   -->>   n * (n + 1)^2 / 4

constantTimeMean <- function(n) n * (n + 1)^2 / 4

constantTimeMean(5)
[1] 45

Пояснение

Пусть (x 1 , x 2 , ... x n ) - перестановка чисел 1 до n . Умножьте каждый x i на i и сложите примерно так:

x_1 * 1 + x_2 * 2 ... + x_n * n

Поскольку мы берем все перестановки, каждый индекс i имеет равную вероятность быть умноженным на каждое число от 1 до n . Также отметим, что если мы удалим коэффициенты, сумма каждой перестановки будет постоянной (то есть sum(1:n)). Таким образом, все, что нам нужно сделать, это вычислить среднее значение от 1 до n и умножить на сумму от 1 до n .

Закрытое выражение в виде суммы от 1 до n определяется как:

 (n * (n + 1) / 2)

Вместе со средним значением получаем :

n * (n + 1)^2 / 4

Это хорошо, потому что генерация всех перестановок выходит из-под контроля очень быстро. Например, что если мы установим N = 15 или даже N = 4321 ? Это facrorial(15) = 1.307674e+12 перестановок ... генерация уже невозможна (factorial(4321) возвращает Inf ... Используя пакет gmp, мы видим, что он действительно имеет более 13000 десятичных цифр: gmp::log10.bigz(gmp::factorialZ(4321)) ~= 13834.99). Тем не менее, с формулой выше, это не проблема:

system.time(print(constantTimeMean(15)))
[1] 960
user  system elapsed 
   0       0       0


system.time(print(constantTimeMean(4321)))
[1] 20178728641
user  system elapsed 
   0       0       0 
2 голосов
/ 10 апреля 2020

Если вам небезразлична скорость, вы можете попробовать реализацию Rfast:

# fastest previous proposition, for reference  
func <- function(N) {
    Ns <- seq_len(N)
    mean(gtools::permutations(n = N, r = N, v = Ns) %*% matrix(seq_len(N)))
}

# implementation using Rfast
func_u <- function(n){
    sn <- seq_len(n)
    mean(tcrossprod(Rfast::permutation(sn), t(sn)))
}

microbenchmark::microbenchmark(
    f_3 = func(3),
    u3 = func_u(3),
    f_7 = func(7),
    u7 = func_u(7)
)
#> Unit: microseconds
#>  expr       min          lq        mean      median          uq        max
#>   f_3   168.345    187.7160    661.2309    217.8845    244.7845  44466.821
#>    u3    35.434     45.3930    127.6996     52.6240     90.3450   6398.212
#>   f_7 47170.752 111422.4390 112419.3058 113008.3590 114360.2590 126243.638
#>    u7   234.751    271.7305    882.8380    298.1155    336.3765  41195.978
#>  neval cld
#>    100  a 
#>    100  a 
#>    100   b
#>    100  a

Создано в 2020-04-09 пакетом Представить (v0. 3,0)

2 голосов
/ 10 апреля 2020

Альтернатива функции RonakShah.

func <- function(N) {
  Ns <- seq_len(N)
  mean(gtools::permutations(n = N, r = N, v = Ns) %*% matrix(Ns))
}
func(3)
# [1] 12
func(5)
# [1] 45

Этот метод имеет преимущество в том, что он заботится о умножении матрицы, которое вы используете. Улучшения скорости могут иметь тенденцию выравниваться для больших образцов. Мы также можем добавить предложение R.Schifini (в get_mean_b ниже), чтобы использовать apply, хотя в целом rowSums быстрее, чем более универсальный c apply использует:

microbenchmark::microbenchmark(
  ronak_3  = get_mean(3),
  ronak_3b = get_mean_b(3),
  akrun_3  = akrun(3),
  r2_3     = func(3),
  ronak_5  = get_mean(5),
  ronak_5b = get_mean_b(5),
  akrun_5  = akrun(5),
  r2_5     = func(5),
  ronak_7  = get_mean(7),
  ronak_7b = get_mean_b(7),
  akrun_7  = akrun(7),
  r2_7     = func(7)
)
# Unit: microseconds
#      expr       min         lq       mean     median         uq        max neval
#   ronak_3   438.001   577.5010   684.8250   639.3510   752.7010   1769.601   100
#  ronak_3b   241.901   310.0005   386.5211   352.0010   423.1515   1202.001   100
#   akrun_3   202.601   274.4510   484.4809   297.0005   365.2010  13570.301   100
#      r2_3    87.601   110.4510   132.0599   125.3505   150.9010    218.000   100
#   ronak_5  1338.101  1689.3010  2085.9439  1774.6510  1949.9510  25789.601   100
#  ronak_5b  1208.101  1545.5000  1813.0931  1643.9015  1831.6510   5187.100   100
#   akrun_5  1004.301  1291.5010  1459.4920  1376.2010  1526.7010   3422.901   100
#      r2_5   924.601  1097.8510  1334.1570  1161.7510  1308.2010   5304.501   100
#   ronak_7 35273.101 46720.0505 59103.9000 54075.6015 64263.3005 118192.401   100
#  ronak_7b 43330.700 56615.3005 70568.5350 62788.4515 74308.0505 213410.001   100
#   akrun_7 34402.701 44957.6015 57026.5051 52982.6010 62273.2010 131092.001   100
#      r2_7 35018.401 43930.4510 58400.5710 51515.6510 61678.9510 167691.602   100
2 голосов
/ 10 апреля 2020

Вы можете написать функцию для вычисления этого.

library(gtools)

get_mean <- function(n) {
   perms <- data.frame(permutations(n = n, r = n, v = seq_len(n)))
   mean(rowSums(perms * as.list(seq_len(n))))
}

get_mean(3)
#[1] 12

get_mean(5)
#[1] 45
1 голос
/ 10 апреля 2020

Мы можем использовать crossprod

get_mean <- function(n) {
    perms <- data.frame(permutations(n = n, r = n, v = seq_len(n)))
     mean(crossprod(t(perms), seq_len(n)))

 }
get_mean(3)
#[1] 12
get_mean(5)
#[1] 45
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...