Fork Join Matrix Умножение в Java - PullRequest
2 голосов
/ 29 марта 2011

Я провожу некоторые исследования производительности в среде fork / join в Java 7. Чтобы улучшить результаты теста, я хочу использовать различные рекурсивные алгоритмы во время тестов.Одна из них - это умножающие матрицы.

Я скачал следующий пример с веб-сайта Дуга Ли ():

<code>public class MatrixMultiply {

  static final int DEFAULT_GRANULARITY = 16;

  /** The quadrant size at which to stop recursing down
   * and instead directly multiply the matrices.
   * Must be a power of two. Minimum value is 2.
   **/
  static int granularity = DEFAULT_GRANULARITY;

  public static void main(String[] args) {

    final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";

    try {
      int procs;
      int n;
      try {
        procs = Integer.parseInt(args[0]);
        n = Integer.parseInt(args[1]);
        if (args.length > 2) granularity = Integer.parseInt(args[2]);
      }

      catch (Exception e) {
        System.out.println(usage);
        return;
      }

      if ( ((n & (n - 1)) != 0) || 
           ((granularity & (granularity - 1)) != 0) ||
           granularity < 2) {
        System.out.println(usage);
        return;
      }

      float[][] a = new float[n][n];
      float[][] b = new float[n][n];
      float[][] c = new float[n][n];
      init(a, b, n);

      FJTaskRunnerGroup g = new FJTaskRunnerGroup(procs);
      g.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
      g.stats();

      // check(c, n);
    }
    catch (InterruptedException ex) {}
  }


  // To simplify checking, fill with all 1's. Answer should be all n's.
  static void init(float[][] a, float[][] b, int n) {
    for (int i = 0; i < n; ++i) {
      for (int j = 0; j < n; ++j) {
        a[i][j] = 1.0F;
        b[i][j] = 1.0F;
      }
    }
  }

  static void check(float[][] c, int n) {
    for (int i = 0; i < n; i++ ) {
      for (int j = 0; j < n; j++ ) {
        if (c[i][j] != n) {
          throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
        }
      }
    }
  }

  /** 
   * Multiply matrices AxB by dividing into quadrants, using algorithm:
   * <pre>
   *      A      x      B                             
   *
   *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22 
   * |----+----| x |----+----| = |--------+--------| + |---------+-------|
   *  A21 | A22     B21 | B21     A21*B11 | A21*B21     A22*B21 | A22*B22 
   * 
* / статический класс Multiplier extends FJTask {final float [] [] A;// Matrix A final int aRow;// первая строка текущего квадранта A final int aCol;// первый столбец текущего квадранта A final float [] [] B;// Аналогично для B final int bRow;final int bCol;последнее плавание [] [] C;// Аналогично для матрицы результатов C final int cRow;final int cCol;окончательный размер int;// количество элементов в текущем квадранте Multiplier (float [] [] A, int aRow, int aCol, float [] [] B, int bRow, int bCol, float [] [] C, int cRow, int cCol, intразмер) {это. А = А;this.aRow = aRow;this.aCol = aCol;this.B = B;this.bRow = bRow;this.bCol = bCol;this.C = C;this.cRow = cRow;this.cCol = cCol;this.size = размер;} public void run () {if (size

Этот код написан для более старой версии платформы fork / join. Поэтому я должен переписать это. Мой переписанный код реализует мой собственный интерфейс и выглядит так:

public class Java7MatrixMultiply implements Algorithm { 
    private static final int SIZE = 32;
    private static final int THRESHOLD = 8;

    private float[][] a = new float[SIZE][SIZE];
    private float[][] b = new float[SIZE][SIZE];
    private float[][] c = new float[SIZE][SIZE];

    ForkJoinPool forkJoinPool;

    @Override
    public void initialize() {
        init(a, b, SIZE);
    }

    @Override
    public void execute() {
        MatrixMultiplyTask mainTask = new MatrixMultiplyTask(a, 0, 0, b, 0, 0, c, 0, 0, SIZE);
        forkJoinPool = new ForkJoinPool();
        forkJoinPool.invoke(mainTask);

        System.out.println("Terminated!");
    }

    @Override
    public void printResult() { 
        check(c, SIZE);

        for (int i = 0; i < SIZE; i++) {
            for (int j = 0; j < SIZE; j++) {
                System.out.print(c[i][j] + " ");
            }

            System.out.println();
        }
    }

    // To simplify checking, fill with all 1's. Answer should be all n's.
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (c[i][j] != n) {
                    //throw new Error("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
                    System.out.println("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
                }
            }
        }
    }

    private class MatrixMultiplyTask extends RecursiveAction {
        private final float[][] A; // Matrix A
        private final int aRow; // first row of current quadrant of A
        private final int aCol; // first column of current quadrant of A

        private final float[][] B; // Similarly for B
        private final int bRow;
        private final int bCol;

        private final float[][] C; // Similarly for result matrix C
        private final int cRow;
        private final int cCol;

        private final int size;

        MatrixMultiplyTask(float[][] A, int aRow, int aCol, float[][] B,
                int bRow, int bCol, float[][] C, int cRow, int cCol, int size) {
            this.A = A;
            this.aRow = aRow;
            this.aCol = aCol;
            this.B = B;
            this.bRow = bRow;
            this.bCol = bCol;
            this.C = C;
            this.cRow = cRow;
            this.cCol = cCol;
            this.size = size;
        }

        @Override
        protected void compute() {      
            if (size <= THRESHOLD) {
                multiplyStride2();
            } else {

                int h = size / 2;               

                invokeAll(new MatrixMultiplyTask[] {
                        new MatrixMultiplyTask(A, aRow, aCol, // A11
                                B, bRow, bCol, // B11
                                C, cRow, cCol, // C11
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                                B, bRow + h, bCol, // B21
                                C, cRow, cCol, // C11
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol, // A11
                                B, bRow, bCol + h, // B12
                                C, cRow, cCol + h, // C12
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                                B, bRow + h, bCol + h, // B22
                                C, cRow, cCol + h, // C12
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol, // A21
                                B, bRow, bCol, // B11
                                C, cRow + h, cCol, // C21
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                                B, bRow + h, bCol, // B21
                                C, cRow + h, cCol, // C21
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol, // A21
                                B, bRow, bCol + h, // B12
                                C, cRow + h, cCol + h, // C22
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                                B, bRow + h, bCol + h, // B22
                                C, cRow + h, cCol + h, // C22
                                h) });

            }
        }

        /**
         * Version of matrix multiplication that steps 2 rows and columns at a
         * time. Adapted from Cilk demos. Note that the results are added into
         * C, not just set into C. This works well here because Java array
         * elements are created with all zero values.
         **/

        void multiplyStride2() {
            for (int j = 0; j < size; j += 2) {
                for (int i = 0; i < size; i += 2) {

                    float[] a0 = A[aRow + i];
                    float[] a1 = A[aRow + i + 1];

                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k += 2) {

                        float[] b0 = B[bRow + k];

                        s00 += a0[aCol + k] * b0[bCol + j];
                        s10 += a1[aCol + k] * b0[bCol + j];
                        s01 += a0[aCol + k] * b0[bCol + j + 1];
                        s11 += a1[aCol + k] * b0[bCol + j + 1];

                        float[] b1 = B[bRow + k + 1];

                        s00 += a0[aCol + k + 1] * b1[bCol + j];
                        s10 += a1[aCol + k + 1] * b1[bCol + j];
                        s01 += a0[aCol + k + 1] * b1[bCol + j + 1];
                        s11 += a1[aCol + k + 1] * b1[bCol + j + 1];
                    }

                    C[cRow + i][cCol + j] += s00;
                    C[cRow + i][cCol + j + 1] += s01;
                    C[cRow + i + 1][cCol + j] += s10;
                    C[cRow + i + 1][cCol + j + 1] += s11;
                }
            }
        }
    }
}

Иногда мои вычисления не проходят проверку. Некоторые поля матрицы имеют другое значение, как и ожидалось. Эти несоответствия являются случайными, и не всегда происходят. Я подозреваю, что что-то идет не так в методе compute, потому что мне пришлось переписывать части, где используется класс Seq. Класс Seq выполняет задачи по порядку, в отличие от метода invokeAll (). Класс больше не существует в текущей версии платформы fork / join. Я не очень знаком с алгоритмом умножения матриц, поэтому очень трудно понять, что происходит не так. Есть предложения?

Ответы [ 2 ]

1 голос
/ 29 марта 2011

Вы накапливаете результаты в C[cRow + i][cCol + j] += s00; и т.п.Это не потокобезопасная операция, поэтому необходимо синхронизировать строку или убедиться, что только одна задача когда-либо обновляет ячейку.Без этого вы увидите, что случайные ячейки установлены неправильно.

Я бы проверил, что вы получите правильный ответ с параллелизмом 1.

Кстати: float может быть не лучшим выбором здесь.Он имеет довольно низкое количество цифр точности, а в тяжелых матричных операциях (которые, как я полагаю, вы делаете, или если не будет большого смысла в использовании нескольких потоков), ошибка округления может израсходовать большую часть или всю вашу точность.Я бы предложил вместо double.

, например, float имеет около 7 цифр точности, и одно из практических правил заключается в том, что ошибка пропорциональна количеству вычислений.Таким образом, для матрицы 1K x 1K у вас может остаться 4 цифры точности.Для 10K x 10K у вас может быть только три в лучшем случае.double имеет 16 цифр точности, что означает, что у вас может быть 12 цифр точности после умножения 10K x 10K.

0 голосов
/ 29 марта 2011

Как вы уже заметили, последовательное выполнение подзадач, принадлежащих одному и тому же квадранту, важно для этого алгоритма.Итак, вам нужно реализовать свою собственную функцию seq(), например, следующим образом, и использовать ее как в исходном коде:

public ForkJoinTask<?> seq(final ForkJoinTask<?> a, final ForkJoinTask<?> b) {
    return adapt(new Runnable() {
        public void run() {
            a.invoke();
            b.invoke();
        }
    });
}
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...