У меня есть это задание, чтобы использовать то, что я узнал о форке и соединении. Я слышал некоторые слухи, что можно достичь 60-кратной оптимизации для n = 11
этой проблемы. Я могу думать только о проблеме как о разбиении умножения на 4 меньшие матрицы на основе рекурсивного алгоритма, представленного в Matrix
, поэтому я просто скопировал алгоритм и перевел их в конструкторы MatrixMultiplication
, как видно из метода compute
. Однако мне удалось добиться 20-кратной оптимизации только для n = 11
. Я не знаю, связана ли моя неудача с достижением требуемой оптимизации с алгоритмом, который я использую, или с размещением форка и соединения. Я не ищу улучшения, когда я изменил некоторые из вызываемых конструкторов на compute
вместо fork
и join
. Может ли кто-нибудь дать мне подсказку о том, как я могу добиться оптимизации в 60 раз.
Вот описание проблемы, моя попытка и что я имел в виду под увеличением в 20 раз (согласно выводу из программы).
MatrixMultiplication
import java.util.concurrent.RecursiveTask;
class MatrixMultiplication extends RecursiveTask<Matrix> {
/** The fork threshold. */
private static final int FORK_THRESHOLD = 128;
/** The first matrix to multiply with. */
private Matrix m1;
/** The second matrix to multiply with. */
private Matrix m2;
/** The starting row of m1. */
private int m1Row;
/** The starting col of m1. */
private int m1Col;
/** The starting row of m2. */
private int m2Row;
/** The starting col of m2. */
private int m2Col;
/**
* The dimension of the input (sub)-matrices and the size of the output
* matrix.
*/
private int dimension;
/**
* A constructor for the Matrix Multiplication class.
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @param m1Row The starting row of m1.
* @param m1Col The starting col of m1.
* @param m2Row The starting row of m2.
* @param m2Col The starting col of m2.
* @param dimension The dimension of the input (sub)-matrices and the size
* of the output matrix.
*/
MatrixMultiplication(Matrix m1, Matrix m2, int m1Row, int m1Col, int m2Row,
int m2Col, int dimension) {
this.m1 = m1;
this.m2 = m2;
this.m1Row = m1Row;
this.m1Col = m1Col;
this.m2Row = m2Row;
this.m2Col = m2Col;
this.dimension = dimension;
}
@Override
public Matrix compute() {
/* if (dimension == 1) {
//Matrix result = new Matrix(1);
return Matrix.nonRecursiveMultiply(m1,m2,m1Row,m1Col,m2Row,m2Col,dimension);
//trivial and return the same type...
//count the usually we do one and must be pre-existing
}*/
if (dimension < FORK_THRESHOLD) {
return Matrix.nonRecursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, dimension);
}
int size = dimension / 2;
Matrix result = new Matrix(dimension);
MatrixMultiplication mma11b11 = new MatrixMultiplication(m1,m2,m1Row,m1Col,
m2Row,m2Col, size);
MatrixMultiplication mma12b21 = new MatrixMultiplication(m1, m2, m1Row,
m1Col + size, m2Row + size, m2Col, size);
MatrixMultiplication mma11b12 = new MatrixMultiplication(m1, m2, m1Row, m1Col,
m2Row, m2Col + size, size);
MatrixMultiplication mma12b22 = new MatrixMultiplication(m1, m2, m1Row,
m1Col + size, m2Row + size, m2Col + size, size);
MatrixMultiplication mma21b11 = new MatrixMultiplication(m1, m2, m1Row + size,
m1Col, m2Row, m2Col, size);
MatrixMultiplication mma22b21 = new MatrixMultiplication(m1, m2, m1Row + size,
m1Col + size, m2Row + size, m2Col, size);
MatrixMultiplication mma21b12 = new MatrixMultiplication(m1, m2, m1Row + size,
m1Col, m2Row, m2Col + size, size);
MatrixMultiplication mma22b22 = new MatrixMultiplication(m1, m2, m1Row + size,
m1Col + size, m2Row + size, m2Col + size, size);
mma11b11.fork();
mma12b21.fork();
mma11b12.fork();
mma12b22.fork();
mma21b11.fork();
mma22b21.fork();
mma21b12.fork();
mma22b22.fork();
Matrix a22b22 = mma22b22.join();
Matrix a21b12 = mma21b12.join();
Matrix a22b21 = mma22b21.join();
Matrix a21b11 = mma21b11.join();
Matrix a12b22 = mma12b22.join();
Matrix a11b12 = mma11b12.join();
Matrix a12b21 = mma12b21.join();
Matrix a11b11 = mma11b11.join();
for (int i = 0; i < size; i++) {
double[] m1m = a21b12.m[i];
double[] m2m = a22b22.m[i];
double[] r1m = result.m[i + size];
for (int j = 0; j < size; j++) {
r1m[j + size] = m1m[j] + m2m[j];
}
}
for (int i = 0; i < size; i++) {
double[] m1m = a21b11.m[i];
double[] m2m = a22b21.m[i];
double[] r1m = result.m[i + size];
for (int j = 0; j < size; j++) {
r1m[j] = m1m[j] + m2m[j];
}
}
for (int i = 0; i < size; i++) {
double[] m1m = a11b12.m[i];
double[] m2m = a12b22.m[i];
double[] r1m = result.m[i];
for (int j = 0; j < size; j++) {
r1m[j + size] = m1m[j] + m2m[j];
}
}
for (int i = 0; i < size; i++) {
double[] m1m = a11b11.m[i];
double[] m2m = a12b21.m[i];
double[] r1m = result.m[i];
for (int j = 0; j < size; j++) {
r1m[j] = m1m[j] + m2m[j];
}
}
return result;
}
}
Матрица
import java.util.function.Supplier;
import java.lang.StringBuilder;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinTask;
import java.lang.Runtime;
/**
* Encapsulate a square matrix of double values.
*/
class Matrix {
/**
* 2D square array of double values, storing the matrix.
*/
double[][] m;
/**
* The number of columns and rows in the matrix.
*/
int dimension;
private static final int THRESHOLD = 2;
/**
* Checks if two matrices are equals.
* @param m1 First matrices to check
* @param m2 Second matrices to check against
* @return true if every elements in m1 and m2 are the same; false otherwise.
*/
public static boolean equals(Matrix m1, Matrix m2) {
if (m1.dimension != m2.dimension) {
return false;
}
for (int i = 0; i < m1.dimension; i++) {
for (int j = 0; j < m1.dimension; j++) {
if (Math.abs(m1.m[i][j] - m2.m[i][j]) > 0.000001) {
return false;
}
}
}
return true;
}
/**
* A constructor for the matrix.
* @param dimension The number of rows.
*/
Matrix(int dimension) {
this.dimension = dimension;
this.m = new double[dimension][dimension];
}
/**
* Generate a matrix of d x d according to the given supplier.
* @param dimension The dimension of the matrix
* @param supplier The lambda to generate the matrix with.
* @return The new matrix.
*/
static Matrix generate(int dimension, Supplier<Double> supplier) {
Matrix matrix = new Matrix(dimension);
for (int row = 0; row < dimension; row++) {
for (int col = 0; col < dimension; col++) {
matrix.m[row][col] = supplier.get();
}
}
return matrix;
}
/**
* Return a string representation of the matrix, pretty-printed
* with each row on a single line.
* @return The string representation of this matrix.
*/
public String toString() {
StringBuilder s = new StringBuilder();
for (int row = 0; row < dimension; row++) {
for (int col = 0; col < dimension; col++) {
s.append(String.format("%.4f", m[row][col]) + " ");
}
s.append("\n");
}
return s.toString();
}
/**
* Multiply matrix m with this matrix, return a new result matrix.
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @param m1Row The starting row of m1.
* @param m1Col The starting col of m1.
* @param m2Row The starting row of m2.
* @param m2Col The starting col of m2.
* @param dimension The dimension of the input (sub)-matrices and the size
* of the output matrix.
* @return The new matrix.
*/
public static Matrix nonRecursiveMultiply(Matrix m1, Matrix m2,
int m1Row, int m1Col, int m2Row, int m2Col, int dimension) {
Matrix result = new Matrix(dimension);
for (int row = 0; row < dimension; row++) {
for (int col = 0; col < dimension; col++) {
double sum = 0;
// multiply row to col
for (int i = 0; i < dimension; i++) {
sum += m1.m[row + m1Row][i + m1Col] * m2.m[i + m2Row][col + m2Col];
}
result.m[row][col] = sum;
}
}
return result;
}
/**
* Multiple two matrices non-recursively.
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @return The resulting matrix m1 * m2
*/
public static Matrix nonRecursiveMultiply(Matrix m1, Matrix m2) {
return Matrix.nonRecursiveMultiply(m1, m2, 0, 0, 0, 0, m1.dimension);
}
/**
* Multiply matrix m with this matrix, return a new result matrix.
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @param m1Row The starting row of m1.
* @param m1Col The starting col of m1.
* @param m2Row The starting row of m2.
* @param m2Col The starting col of m2.
* @param dimension The dimension of the input (sub)-matrices and the size
* of the output matrix.
* @return The resulting matrix m1 * m2
*/
public static Matrix recursiveMultiply(Matrix m1, Matrix m2,
int m1Row, int m1Col, int m2Row, int m2Col, int dimension) {
// If the matrix is small enough, just multiple non-recursively.
if (dimension <= THRESHOLD) {
return Matrix.nonRecursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, dimension);
}
// Else, cut the matrix into four blocks of equal size, recursively
// multiply then sum the multiplication result.
int size = dimension / 2;
Matrix result = new Matrix(dimension);
Matrix a11b11 = recursiveMultiply(m1, m2, m1Row, m1Col, m2Row,
m2Col, size);
Matrix a12b21 = recursiveMultiply(m1, m2, m1Row, m1Col + size,
m2Row + size, m2Col, size);
for (int i = 0; i < size; i++) {
double[] m1m = a11b11.m[i];
double[] m2m = a12b21.m[i];
double[] r1m = result.m[i];
for (int j = 0; j < size; j++) {
r1m[j] = m1m[j] + m2m[j];
}
}
Matrix a11b12 = recursiveMultiply(m1, m2, m1Row, m1Col, m2Row,
m2Col + size, size);
Matrix a12b22 = recursiveMultiply(m1, m2, m1Row, m1Col + size,
m2Row + size, m2Col + size, size);
for (int i = 0; i < size; i++) {
double[] m1m = a11b12.m[i];
double[] m2m = a12b22.m[i];
double[] r1m = result.m[i];
for (int j = 0; j < size; j++) {
r1m[j + size] = m1m[j] + m2m[j];
}
}
Matrix a21b11 = recursiveMultiply(m1, m2, m1Row + size, m1Col,
m2Row, m2Col, size);
Matrix a22b21 = recursiveMultiply(m1, m2, m1Row + size, m1Col + size,
m2Row + size, m2Col, size);
for (int i = 0; i < size; i++) {
double[] m1m = a21b11.m[i];
double[] m2m = a22b21.m[i];
double[] r1m = result.m[i + size];
for (int j = 0; j < size; j++) {
r1m[j] = m1m[j] + m2m[j];
}
}
Matrix a21b12 = recursiveMultiply(m1, m2, m1Row + size, m1Col,
m2Row, m2Col + size, size);
Matrix a22b22 = recursiveMultiply(m1, m2, m1Row + size, m1Col + size,
m2Row + size, m2Col + size, size);
for (int i = 0; i < size; i++) {
double[] m1m = a21b12.m[i];
double[] m2m = a22b22.m[i];
double[] r1m = result.m[i + size];
for (int j = 0; j < size; j++) {
r1m[j + size] = m1m[j] + m2m[j];
}
}
return result;
}
/**
* Multiple two matrices recursively but sequentially with
* divide-and-conquer algorithm.
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @return The resulting matrix m1 * m2
*/
public static Matrix recursiveMultiply(Matrix m1, Matrix m2) {
return Matrix.recursiveMultiply(m1, m2, 0, 0, 0, 0, m1.dimension);
}
/**
* Multiple two matrices recursively and parallely with
* divide-and-conquer algorithm.
* @param m1 The matrix to multiply with.
* @param m2 The matrix to multiply with.
* @return The resulting matrix m1 * m2
*/
public static Matrix parallelMultiply(Matrix m1, Matrix m2) {
return new MatrixMultiplication(m1, m2, 0, 0, 0, 0, m1.dimension)
.compute();
}
}
import java.util.function.Supplier;
import java.util.stream.DoubleStream;
import java.util.Random;
import java.util.Scanner;
import java.time.Instant;
import java.time.Duration;
Главная
/**
* Main is the main driver class for testing matrix multiplication.
* Usage: java Main n
* 2^n is the dimension of the square matrixOne
*/
class Main {
public static void main(String[] args) {
int n = (new Scanner(System.in)).nextInt();
Random random = new Random(1);
int dimension = 1 << n;
System.out.println("dimension " + dimension);
Matrix matrixOne = Matrix.generate(dimension, () -> random.nextDouble());
Matrix matrixTwo = Matrix.generate(dimension, () -> random.nextDouble());
Matrix result1 = Matrix.nonRecursiveMultiply(matrixOne, matrixTwo);
Matrix result2 = Matrix.parallelMultiply(matrixOne, matrixTwo);
boolean match = Matrix.equals(result1, result2);
if (!match) {
System.out.println("ERROR: matrix multiplication gives inconsistent " +
"result in sequential and parallel implementations.");
return;
}
double d1 = measureTimeToRun(() -> Matrix.nonRecursiveMultiply(matrixOne, matrixTwo));
double d2 = measureTimeToRun(() -> Matrix.parallelMultiply(matrixOne, matrixTwo));
System.out.printf("Parallel %.3f ms Sequential %.3f ms Speedup %.3f times\n", d2, d1, d1 / d2);
}
/**
* Return the average time needed to run the task over three runs.
* @param task A lambda expression for the task to be run
* @return The average time taken in ms.
*/
private static double measureTimeToRun(Supplier<Matrix> task) {
final int numOfTimes = 3;
double sum = 0;
for (int i = 0; i < numOfTimes; i++) {
Instant start = Instant.now();
Matrix m = task.get();
Instant stop = Instant.now();
sum += Duration.between(start, stop).toMillis();
}
return sum / numOfTimes;
}
}
Редактировать: Добавлено в вывод n = 11
из вывода программы на случай, если это вызовет некоторую путаницу.