Проблема дополнения CodeSprint 2 выполняется слишком медленно - PullRequest
3 голосов
/ 12 марта 2012

На оригинальном InterviewStreet Codesprint возникает вопрос о подсчете количества единиц в представлении дополнения двух чисел между a и b включительно. Мне удалось пройти все тесты на точность с помощью итерации, но я смог пройти только два за правильный промежуток времени. Был подсказка, в которой упоминалось нахождение рекуррентного отношения, поэтому я переключился на рекурсию, но это заняло столько же времени. Так может кто-нибудь найти более быстрый способ сделать это, чем код, который я предоставил? Первый номер входного файла - это контрольные примеры в файле. Я предоставил образец входного файла после кода.

import java.util.Scanner;

public class Solution {

    public static void main(String[] args) {

        Scanner scanner = new Scanner(System.in);
        int numCases = scanner.nextInt();
        for (int i = 0; i < numCases; i++) {
            int a = scanner.nextInt();
            int b = scanner.nextInt();
            System.out.println(count(a, b));
        }
    }

    /**
     * Returns the number of ones between a and b inclusive
     */
    public static int count(int a, int b) {
        int count = 0;
        for (int i = a; i <= b; i++) {
            if (i < 0)
                count += (32 - countOnes((-i) - 1, 0));
            else
                count += countOnes(i, 0);
        }

        return count;
    }

    /**
     * Returns the number of ones in a
     */
    public static int countOnes(int a, int count) {
        if (a == 0)
            return count;
        if (a % 2 == 0)
            return countOnes(a / 2, count);
        else
            return countOnes((a - 1) / 2, count + 1);
    }
}

Введите:

3
-2 0
-3 4
-1 4

Output:
63
99
37

1 Ответ

2 голосов
/ 12 марта 2012

Первый шаг - заменить

public static int countOnes(int a, int count) {
    if (a == 0)
        return count;
    if (a % 2 == 0)
        return countOnes(a / 2, count);
    else
        return countOnes((a - 1) / 2, count + 1);
}

, который повторяется до глубины лога 2 a, с более быстрой реализацией, например знаменитым бит-тиддлингом

public static int popCount(int n) {
    // count the set bits in each bit-pair
    // 11 -> 10, 10 -> 01, 0* -> 0*
    n -= (n >>> 1) & 0x55555555;
    // count bits in each nibble
    n = ((n >>> 2) & 0x33333333) + (n & 0x33333333);
    // count bits in each byte
    n = ((n >> 4) & 0x0F0F0F0F) + (n & 0x0F0F0F0F);
    // accumulate the counts in the highest byte and shift
    return (0x01010101 * n) >> 24;
    // Java guarantees wrap-around, so we can use int here,
    // in C, one would need to use unsigned or a 64-bit type
    // to avoid undefined behaviour
}

, в котором используются четыре смены, пять битовых разрядов, одно вычитание, два сложения и одно умножение, что в общей сложности дает тринадцать очень дешевых инструкций.

Но если диапазоны не очень малы, можно сделать намного лучше, чем считать биты каждого отдельного числа.

Давайте сначала рассмотрим неотрицательные числа. Для чисел от 0 до 2 k -1 установлено до k битов. Каждый бит установлен ровно в половине из них, поэтому общее количество битов составляет k*2^(k-1). Теперь пусть 2^k <= a < 2^(k+1). Общее количество битов в числах 0 <= n <= a является суммой битов в числах 0 <= n < 2^k и битов в числах 2^k <= n <= a. Первый счет, как мы видели выше, k*2^(k-1). Во второй части у нас есть a - 2^k + 1 числа, у каждого из них есть набор 2 k -бит, и, игнорируя старший бит, биты такие же, как в числах 0 <= n <= (a - 2^k), так

totalBits(a) = k*2^(k-1) + (a - 2^k + 1) + totalBits(a - 2^k)

Теперь для отрицательных чисел. В дополнение к двум, -(n+1) = ~n, поэтому числа -a <= n <= -1 являются дополнениями к числам 0 <= m <= (a-1), а общее число установленных битов в числах -a <= n <= -1 равно a*32 - totalBits(a-1).

Для общего числа битов в диапазоне a <= n <= b мы должны сложить или вычесть, в зависимости от того, имеют ли оба конца диапазона противоположный или один и тот же знак.

// if n >= 0, return the total of set bits for
// the numbers 0 <= k <= n
// if n < 0, return the total of set bits for
// the numbers n <= k <= -1
public static long totalBits(int n){
    if (n < 0) {
        long a = -(long)n;
        return (a*32 - totalBits((int)(a-1)));
    }
    if (n < 3) return n;
    int lg = 0, mask = n;
    // find the highest set bit in n and its position
    while(mask > 1){
        ++lg;
        mask >>= 1;
    }
    mask = 1 << lg;
    // total bit count for 0 <= k < 2^lg
    long total = 1L << lg-1;
    total *= lg;
    // add number of 2^lg bits
    total += n+1-mask;
    // add number of other bits for 2^lg <= k <= n
    total += totalBits(n-mask);
    return total;
}

// return total set bits for the numbers a <= n <= b
public static long totalBits(int a, int b) {
    if (b < a) throw new IllegalArgumentException("Invalid range");
    if (a == b) return popCount(a);
    if (b == 0) return totalBits(a);
    if (b < 0) return totalBits(a) - totalBits(b+1);
    if (a == 0) return totalBits(b);
    if (a > 0) return totalBits(b) - totalBits(a-1);
    // Now a < 0 < b
    return totalBits(a) + totalBits(b);
}
...