Использование Neuroph в Java - PullRequest
       13

Использование Neuroph в Java

0 голосов
/ 16 ноября 2018

Я пытаюсь перенести программу из Matlab в Java.Matlab был использован, потому что он имеет очень полную реализацию нейронной сети.Я хочу сейчас перенести проект на Java.Я ищу всеобъемлющую библиотеку на Java и наткнулся на Neuroph.Поэтому для начала мне нужно запустить очень простой пример, чтобы убедиться, что все работает, прежде чем пытаться все портировать.Я наткнулся на этот урок.https://www.baeldung.com/neuroph. Я пытался реализовать это в Eclipse.Реализация не имеет ошибок, в результате чего базовый NN неверен.Я ожидаю одного для этого примера и всегда получаю ноль.

Тестирование: 1, 0 Ожидаемое: 1,0 Результат: 0,0 Тестирование: 0, 1 Ожидаемое: 1,0 Результат: 0,0 Тестирование: 1, 1 Ожидаемое: 0,0 Результат:0.0 Тестирование: 0, 0 Ожидаемое: 0.0 Результат: 0.0

Кто-нибудь может подсказать, почему NN настроен неправильно?Большое спасибо

import org.neuroph.core.*;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.util.*;


public class NeuralNetworkExample {

public static void main(String[] args) {






    Layer inputLayer = new Layer();
    inputLayer.addNeuron(new Neuron());
    inputLayer.addNeuron(new Neuron());


    Layer hiddenLayerOne = new Layer();
    hiddenLayerOne.addNeuron(new Neuron());
    hiddenLayerOne.addNeuron(new Neuron());
    hiddenLayerOne.addNeuron(new Neuron());
    hiddenLayerOne.addNeuron(new Neuron());

    Layer hiddenLayerTwo = new Layer(); 
    hiddenLayerTwo.addNeuron(new Neuron()); 
    hiddenLayerTwo.addNeuron(new Neuron()); 
    hiddenLayerTwo.addNeuron(new Neuron()); 
    hiddenLayerTwo.addNeuron(new Neuron());

    Layer outputLayer = new Layer();
    outputLayer.addNeuron(new Neuron());


    NeuralNetwork<BackPropagation> ann = new NeuralNetwork<BackPropagation>();
    ann.addLayer(0, inputLayer);
    ann.addLayer(1, hiddenLayerOne);
    ConnectionFactory.fullConnect(ann.getLayerAt(0), ann.getLayerAt(1));
    ann.addLayer(2, hiddenLayerTwo);
    ConnectionFactory.fullConnect(ann.getLayerAt(1), ann.getLayerAt(2));
    ann.addLayer(3, outputLayer);
    ConnectionFactory.fullConnect(ann.getLayerAt(2), ann.getLayerAt(3));
    ConnectionFactory.fullConnect(ann.getLayerAt(0), 
      ann.getLayerAt(ann.getLayersCount()-1), false);
    ann.setInputNeurons(inputLayer.getNeurons());
    ann.setOutputNeurons(outputLayer.getNeurons());


    int input=2;
    int output=1;       
    DataSet ds = new DataSet(input,output);

    DataSetRow rOne   = new DataSetRow(new double[] {0, 1}, new double[] {1});
    ds.addRow(rOne);
    DataSetRow rTwo   = new DataSetRow(new double[] {1, 1}, new double[] {0});
    ds.addRow(rTwo);
    DataSetRow rThree = new DataSetRow(new double[] {0, 0}, new double[] {0});
    ds.addRow(rThree);
    DataSetRow rFour  = new DataSetRow(new double[] {1, 0}, new double[] {1});
    ds.addRow(rFour);


    BackPropagation backPropagation = new BackPropagation();
    backPropagation.setMaxIterations(1000);
    ann.learn(ds,backPropagation);






    ann.setInput(1,0);
    ann.calculate();
    double[] out = ann.getOutput();
    System.out.println(out[0]);





}

}
...