Java результаты библиотек линейной алгебры неверны - PullRequest
1 голос
/ 14 июля 2020

У меня есть алгоритм, реализованный в Python, использующий numpy библиотеку для линейной алгебры. Я хочу реализовать это в приложении Java для Android, я пробовал много библиотек, таких как Jama.

Алгоритм проверки того, имеет ли набор 3D-точек тот же образец, что и другой набор.

это реализация в Python, которая отлично работает:

def compute_similarity(original_shape, shape, nearest_neighbors= False, debug= False):    
    reference_shape = np.copy(original_shape)
    # calculate rotation matrix 3 x 3
    U, _, V = svd(reference_shape.dot(shape.T))
    rotation_matrix = U.dot(V)
    rotated_shape = np.dot(rotation_matrix, shape)
    # Perform procrustes analysis
    ref_shape_mean_vals = np.mean(reference_shape, axis=1)[:, np.newaxis]
    shape_mean_vals = np.mean(rotated_shape, axis=1)[:, np.newaxis]

    translation = ref_shape_mean_vals - scale * shape_mean_vals
    # Fit shape
    fitted_shape = np.dot(rotation_matrix, shape) + translation
  
    # Euclidean metric
    score = np.linalg.norm(reference_shape-fitted_shape)    
    return score

если точки имеют одинаковый узор, оценка должна быть небольшим значением меньше 1.

Это мой реализация в Java:

package com.moussa.zomzoom.Player;

import Jama.Matrix;
import Jama.SingularValueDecomposition;

public class Solver {


public static boolean isFit(double[][] _reference_shape,double[][] _shape) {
    boolean result = true;
    double threshold = 50;
    // calculate rotation matrix 3 x 3
    Matrix reference_shape = new Matrix(_reference_shape);
    Matrix shape = new Matrix(_shape);
    Matrix dot = reference_shape.times(shape.transpose());
    SingularValueDecomposition SVD = new SingularValueDecomposition(dot);
    Matrix U = SVD.getU();
    Matrix V = SVD.getV();
    Matrix rotation_matrix = U.times(V);
    Matrix rotated_shape = rotation_matrix.times(shape);

    // Perform procrustes analysis
    double score = Double.MAX_VALUE;
    Matrix reference_shape_mean = mean(reference_shape);
    Matrix shape_mean = mean(rotated_shape);

    // double scale = reference_shape.minus(reference_shape_mean).norm1() / shape.minus(shape_mean).norm1();
    //Matrix translation = reference_shape_mean.minus(shape_mean.times(scale));
    Matrix translation = reference_shape_mean.minus(shape_mean);

    // Fit shape
    //Matrix fitted_shape = rotation_matrix.times(shape).times(scale).plus(translation);
    Matrix fitted_shape = rotated_shape.plus(translation);
    // Euclidean metric
    //double euclidean_dist = fitted_shape. distance2(reference_shape);
    score = reference_shape.minus(fitted_shape).norm1();

    result = score < threshold;
    return result;
}

private static Matrix mean(Matrix matrix){
    double[]mean_values=new double[3];
    for(int i=0;i<matrix.getRowDimension();i++){
        mean_values[0]+=matrix.get(i,0);
        mean_values[1]+=matrix.get(i,1);
        mean_values[2]+=matrix.get(i,2);
    }
    mean_values[0]/=matrix.getRowDimension();
    mean_values[1]/=matrix.getRowDimension();
    mean_values[2]/=matrix.getRowDimension();
    double[][]m=new double[matrix.getRowDimension()][matrix.getColumnDimension()];
    for(int i=0;i<matrix.getRowDimension();i++){
        m[i]=mean_values;
    }
    return new Matrix(m);
}

}

Тот же тест в Java приводит к очень большому количеству очков> 2000. Я заметил, что значения SVD в Jama отличаются от Python.

Я пробовал использовать множество Java библиотек линейной алгебры, например Apache common-maths, и получил тот же результат.

...