Предложения по оптимизации простого Scala foldLeft для нескольких значений? - PullRequest
5 голосов
/ 02 февраля 2012

Я заново реализую некоторый код (простой алгоритм байесовского вывода, но это не очень важно) с Java на Scala.Я хотел бы реализовать его максимально эффективным способом, сохраняя при этом код чистым и функциональным, максимально избегая изменчивости.

Вот фрагмент кода Java:

    // initialize
    double lP  = Math.log(prior);
    double lPC = Math.log(1-prior);

    // accumulate probabilities from each annotation object into lP and lPC
    for (Annotation annotation : annotations) {
        float prob = annotation.getProbability();
        if (isValidProbability(prob)) {
            lP  += logProb(prob);
            lPC += logProb(1 - prob);
        }
    } 

Довольно просто, верно?Поэтому я решил использовать Scala foldLeft и методы map для первой попытки.Поскольку у меня есть два значения, по которым я накапливаю, аккумулятор является кортежем:

    val initial  = (math.log(prior), math.log(1-prior))
    val probs    = annotations map (_.getProbability)
    val (lP,lPC) = probs.foldLeft(initial) ((r,p) => {
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

К сожалению, этот код работает примерно в 5 раз медленнее, чем Java (с использованием простой и неточной метрики; просто называется кодом 10000).раз в цикле).Один дефект довольно очевиден;мы просматриваем списки дважды, один раз в вызове map, а другой в foldLeft.Итак, вот версия, которая просматривает список один раз.

    val (lP,lPC) = annotations.foldLeft(initial) ((r,annotation) => {
      val  p = annotation.getProbability
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

Это лучше!Он работает примерно в 3 раза хуже, чем код Java.Моя следующая догадка заключалась в том, что, вероятно, существуют некоторые затраты, связанные с созданием всех новых кортежей на каждом этапе сгиба.Поэтому я решил попробовать версию, которая проходит через список дважды, но без создания кортежей.

    val lP = annotations.foldLeft(math.log(prior)) ((r,annotation) => {
       val  p = annotation.getProbability
       if(isValidProbability(p)) r + logProb(p) else r
    })
    val lPC = annotations.foldLeft(math.log(1-prior)) ((r,annotation) => {
      val  p = annotation.getProbability
      if(isValidProbability(p)) r + logProb(1-p) else r
    })

Это работает примерно так же, как и в предыдущей версии (в 3 раза медленнее, чем в версии Java).Не удивительно, но я был полон надежд.

Итак, мой вопрос: есть ли более быстрый способ реализовать этот фрагмент Java в Scala, сохраняя код Scala в чистоте, избегая ненужной изменчивости и следуя идиомам Scala?Я ожидаю использовать этот код в конечном итоге в параллельной среде, поэтому ценность сохранения неизменности может перевесить более медленную производительность в одном потоке.

Ответы [ 5 ]

4 голосов
/ 02 февраля 2012

Во-первых, часть вашего штрафа может исходить от типа коллекции, которую вы используете.Но в большинстве случаев это, вероятно, создание объекта, которого вы на самом деле не избегаете, выполняя цикл дважды, поскольку числа должны быть заключены в квадрат.

Вместо этого вы можете создать изменяемый класс, который накапливает значения для вас:

class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
  def *=(p: Double) = {
    if (isValidProbability(p)) {
      lp += logProb(p)
      lpc += logProb(1-p)
    }
    this  // Pass self on so we can fold over the operation
  }
  def toTuple = (lp, lpc)
}

Теперь, хотя вы можете использовать это небезопасно, вам не нужно.На самом деле, вы можете просто сложить его.

annotations.foldLeft(new LogOdds()) { (r,ann) => r *= ann.getProbability } toTuple

Если вы используете этот шаблон, вся изменчивая небезопасность будет спрятана внутри сгиба;оно никогда не ускользает.

Теперь вы не можете сделать параллельное сгибание, но вы можете сделать агрегат, который похож на сгиб с дополнительной операцией для объединения частей.Таким образом, вы добавляете метод

def **(lo: LogOdds) = new LogOdds(lp + lo.lp, lpc + lo.lpc)

к LogOdds, а затем

annotations.aggregate(new LogOdds())(
  (r,ann) => r *= ann.getProbability,
  (l,r) => l**r
).toTuple

и все будет хорошо.

(Не стесняйтесь использовать не-математические символы для этого, но, поскольку вы в основном умножаете вероятности, символ умножения, скорее всего, дает интуитивное представление о том, что происходит, чем включаетProbability или что-то подобное.)

3 голосов
/ 02 февраля 2012

Вы можете реализовать хвостовой рекурсивный метод, который будет преобразован компилятором в цикл while, следовательно, он должен быть таким же быстрым, как и версия Java.Или вы можете просто использовать цикл - нет закона против него, если он просто использует локальные переменные в методе (см., Например, обширное использование в исходном коде коллекций Scala).

def calc(lst: List[Annotation], lP: Double = 0, lPC: Double = 0): (Double, Double) = {
  if (lst.isEmpty) (lP, lPC)
  else {
    val prob = lst.head.getProbability
    if (isValidProbability(prob)) 
      calc(lst.tail, lP + logProb(prob), lPC + logProb(1 - prob))
    else 
      calc(lst.tail, lP, lPC)
  }
}

Преимуществосвертывание заключается в том, что он распараллеливается, что может привести к тому, что он будет быстрее, чем версия Java на многоядерном компьютере (см. другие ответы).

2 голосов
/ 02 февраля 2012

Во-первых, давайте обратимся к проблеме производительности: нет способа реализовать это так быстро, как Java, кроме как с помощью циклов while .По сути, JVM не может оптимизировать цикл Scala в той степени, в которой он оптимизирует цикл Java.Причины этого даже вызывают беспокойство у людей из JVM, поскольку они мешают параллельной работе библиотек.

Теперь, возвращаясь к производительности Scala, вы также можете использовать .view, чтобы избежать созданияновая коллекция в шаге map, но я думаю, что шаг map всегда приведет к снижению производительности.Дело в том, что вы конвертируете коллекцию в один параметризованный на Double, который должен быть упакован и распакован.

Однако есть один возможный способ его оптимизации: сделать его параллельным.Если вы вызываете .par на annotations для создания параллельной коллекции, вы можете использовать fold:

val parAnnot = annotations.par
val lP = parAnnot.map(_.getProbability).fold(math.log(prior)) ((r,p) => {
   if(isValidProbability(p)) r + logProb(p) else r
})
val lPC = parAnnot.map(_.getProbability).fold(math.log(1-prior)) ((r,p) => {
  if(isValidProbability(p)) r + logProb(1-p) else r
})

Чтобы избежать отдельного шага map, используйте aggregate вместо fold, как рекомендует Рекс.

Для получения бонусных баллов вы можете использовать Future, чтобы оба вычисления выполнялись параллельно.Я подозреваю, что вы получите лучшую производительность, если вернете кортежи и запустите их за один раз.Вам нужно будет сравнить этот материал, чтобы увидеть, что работает лучше.

В параллельных коллекциях он может окупиться первым filter за действительные аннотации.Или, возможно, collect.

val parAnnot = annottions.par.view map (_.getProbability) filter (isValidProbability(_)) force;

или

val parAnnot = annotations.par collect { case annot if isValidProbability(annot.getProbability) => annot.getProbability }

В любом случае, эталон.

2 голосов
/ 02 февраля 2012

В качестве дополнительного примечания: вы можете избежать идиоматического обхода списка вдвое, используя view:

val probs = annotations.view.map(_.getProbability).filter(isValidProbability)

val (lP, lPC) = ((logProb(prior), logProb(1 - prior)) /: probs) {
   case ((pa, ca), p) => (pa + logProb(p), ca + logProb(1 - p))
}

Это, вероятно, не даст вам лучшей производительности, чем ваша третья версия, но мне кажется, что она более изящна.

1 голос
/ 02 февраля 2012

В настоящее время невозможно взаимодействовать с библиотекой коллекций Scala без бокса.Так что примитивы double в Java будут постоянно упаковываться и распаковываться в операции fold, даже если вы не заключали их в Tuple2 (который специализирован - ноКонечно, вы уже платите за производительность, создавая новые объекты каждый раз).

...