Я получил ошибку NaN при использовании библиотеки Encog 3.0 - PullRequest
1 голос
/ 08 января 2012

У меня проблема при построении рекуррентной нейронной сети elman с использованием библиотеки encog 3.0.Я получил ошибку NaN после нескольких итераций.

Это мой код в MATLAB:

net=newlrn (minmax(usdjpytrain),[4,1], {'logsig', 'purelin'}, 'trainscg', 'learngdm', 'mse'); 
net.trainParam.epochs = 1000; 
net.trainParam.goal = 1e-5;

Я хочу создать рекуррентную нейронную сеть, как описано выше в Java.по умолчанию значение скорости обучения в MATLAB составляет 0,01, а импульс 0,9.но я получаю NaN ошибку, если я использую скорость обучения и импульс, как это.это мой код в Java:

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */

package modifikasiJSR;

import org.encog.Encog;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.CalculateScore;
import org.encog.neural.networks.training.TrainingSetScore;
import org.encog.neural.networks.training.propagation.back.Backpropagation;
import org.encog.neural.pattern.ElmanPattern;
import dataSet.DataDefault;
import database.Parameter;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.train.strategy.HybridStrategy;
import org.encog.neural.networks.structure.AnalyzeNetwork;
import org.encog.neural.networks.training.propagation.scg.ScaledConjugateGradient;

/**
 *
 * @author Akmal
 */
public class MseTrain{
    private String mataUang;
    private int jumlahDataSet;
    Parameter param = new Parameter();
    private double [] bias_hid_out = new double[2];
    private double [] bias_hid_in = new double[param.getHidden_node()];
    private double [] bobot_hid_out = new double[param.getHidden_node()];
    private double [] bobot_hid_in = new double[7*param.getHidden_node()];


    public MseTrain(String mataUang){
        this.mataUang = mataUang;
    }

    static BasicNetwork createElmanNetwork() {
        // construct an Elman type network
        Parameter param = new Parameter();
        ElmanPattern pattern = new ElmanPattern();
        pattern.setActivationFunction(new ActivationSigmoid());
        pattern.setInputNeurons(7);
        pattern.addHiddenLayer(param.getHidden_node());
        pattern.setOutputNeurons(1);
        return (BasicNetwork)pattern.generate();
    }

    public void mse(double[][] data, double [][] target) {
        System.out.println("menjalankan method MSE");
        DataDefault dataSet = new DataDefault(mataUang);
        jumlahDataSet = dataSet.getJumTraining();
        final MLDataSet trainingSet = new BasicMLDataSet(data, target);
        final BasicNetwork elmanNetwork = MseTrain.createElmanNetwork();
        final double elmanError = trainNetwork("Elman", elmanNetwork,trainingSet);
        System.out.println("Best error rate with Elman Network: " + elmanError);
        Encog.getInstance().shutdown();
    }

    private double trainNetwork(final String what,final BasicNetwork network, final MLDataSet trainingSet) {
        network.reset();
        // train the neural network
        System.out.println(network.getActivation(0));
        System.out.println(network.getActivation(1));
        System.out.println(network.getActivation(2));

        CalculateScore score = new TrainingSetScore(trainingSet);

        final MLTrain trainMain = new Backpropagation(network, trainingSet, 0.01, 0.9);
        final MLTrain trainAlt = new ScaledConjugateGradient(network, trainingSet);
        trainMain.addStrategy(new HybridStrategy(trainAlt));

        //set Bobot hid In
        network.setWeight(0, 0, 0, 3.0439);
        network.setWeight(0, 0, 1, 2.8421);
        network.setWeight(0, 0, 2, -2.8529);
        network.setWeight(0, 0, 3, 2.7013);

        network.setWeight(0, 1, 0, 1.2741);
        network.setWeight(0, 1, 1, -2.8055);
        network.setWeight(0, 1, 2, -1.6862);
        network.setWeight(0, 1, 3, 0.3049);

        network.setWeight(0, 2, 0, 4.5054);
        network.setWeight(0, 2, 1, 3.3152);
        network.setWeight(0, 2, 2, -2.6663);
        network.setWeight(0, 2, 3, 3.1311);

        network.setWeight(0, 3, 0, 4.4113);
        network.setWeight(0, 3, 1, -0.1022);
        network.setWeight(0, 3, 2, 2.2913);
        network.setWeight(0, 3, 3, -2.3347);

        network.setWeight(0, 4, 0, -1.2521);
        network.setWeight(0, 4, 1, 4.8178);
        network.setWeight(0, 4, 2, 3.6979);
        network.setWeight(0, 4, 3, 4.9682);

        network.setWeight(0, 5, 0, 1.4512);
        network.setWeight(0, 5, 1, -3.1329);
        network.setWeight(0, 5, 2, 2.5726);
        network.setWeight(0, 5, 3, 2.7323);

        network.setWeight(0, 6, 0, 2.3088);
        network.setWeight(0, 6, 1, 2.4109);
        network.setWeight(0, 6, 2, 2.4835);
        network.setWeight(0, 6, 3, -0.9406);

        //set Bias hid In
        network.setWeight(0, 7, 0, -9.3723);
        network.setWeight(0, 7, 1, -2.1356);
        network.setWeight(0, 7, 2, -5.0966);
        network.setWeight(0, 7, 3, -1.5829);

        //set Bobot Hid out
        network.setWeight(1, 0, 0, -0.6263);
        network.setWeight(1, 1, 0, -0.0205);
        network.setWeight(1, 2, 0, -0.1088);
        network.setWeight(1, 3, 0, 0.2926);

        //set Bias Hid out
        network.setWeight(1, 4, 0, 0.4187);

        int epoch = 0;
        do{
            trainMain.iteration();
            System.out.println("Training " + what + ", Epoch #" + epoch + " Error:" + trainMain.getError());
            System.out.println("Error:" + trainMain.getError());
            epoch++;
        }while(trainMain.getError()>0.00001&&epoch<=1);

        System.out.println("Neural Network Results:");
        for(MLDataPair pair: trainingSet ) {
            final MLData output = network.compute(pair.getInput());
            System.out.println(
            pair.getInput().getData(0)
            + ","
            + pair.getInput().getData(1)
            + ","
            + pair.getInput().getData(2)
            + ","
            + pair.getInput().getData(3)
            + ","
            + pair.getInput().getData(4)
            + ","
            + pair.getInput().getData(5)
            + ","
            + pair.getInput().getData(6)
            + ", actual="
            + output.getData(0)
            + ",ideal="
            + pair.getIdeal().getData(0));
        }

        System.out.println("bobot : "+network.dumpWeights());

        AnalyzeNetwork bobotBias = new AnalyzeNetwork(network);
        int jumBobotIn = param.getHidden_node()*7;
        System.out.println("jumBobotIn : "+jumBobotIn);
        int indexBobotIn = 0;
        int indexBobotOut = 0;
        for(int i=0; i<bobotBias.getWeightValues().length; i++){
            System.out.println("bobot : "+bobotBias.getWeightValues()[i]);
            if(i<jumBobotIn){
                bobot_hid_in[indexBobotIn] = bobotBias.getWeightValues()[i];
                indexBobotIn++;
            }
            else{
                bobot_hid_out[indexBobotOut] = bobotBias.getWeightValues()[i];
                indexBobotOut++;
            }
        }
        int indexBiasIn = 0;
        int indexBiasOut = 0;
        int jumBiasIn = param.getHidden_node();
        for(int i=0; i<bobotBias.getBiasValues().length; i++){
            System.out.println("bias : "+bobotBias.getBiasValues()[i]);
            if(i<jumBiasIn){
                bias_hid_in[indexBiasIn] = bobotBias.getBiasValues()[i];
                indexBiasIn++;
            }
            else{
                bias_hid_out[indexBiasOut] = bobotBias.getBiasValues()[i];
                indexBiasOut++;
            }
        }
        return trainMain.getError();
    }

    /**
     * @return the bias_hid_out
     */
    public double[] getBias_hid_out() {
        return bias_hid_out;
    }

    /**
     * @return the bias_hid_in
     */
    public double[] getBias_hid_in() {
        return bias_hid_in;
    }

    /**
     * @return the bobot_hid_out
     */
    public double[] getBobot_hid_out() {
        return bobot_hid_out;
    }

    /**
     * @return the bobot_hid_in
     */
    public double[] getBobot_hid_in() {
        return bobot_hid_in;
    }
}

всем, пожалуйста, помогите мне, я застрял здесь на несколько дней.TT

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...