Fixing LSTM training with deeplearning4j - PullRequest
27 April 2019

I am trying to build a model that classifies 5 types of human activity from triaxial accelerometer data using an LSTM neural network. I configured my model based on the deeplearning4j examples, but I think the training is not working properly, because the score reported over the iterations does not go down no matter how many epochs I train for.
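Assuming `ScoreIterationListener` reports the loss of the most recent minibatch (10 sequences here), each value in the log below is a single-minibatch loss, so consecutive values are noisy and two individual points say little about the trend. The sketch below averages the logged scores per block (the runs of scores between evaluation printouts) and compares them with ln 5 ≈ 1.609, the cross-entropy of a model that gives all five classes equal probability. `ScoreTrend` and its methods are illustrative names; the numbers are copied from the first and last blocks of the log.

```java
import java.util.Arrays;
import java.util.Locale;

// Illustrative helper: compares block means of logged ScoreIterationListener values.
final class ScoreTrend {

    // Scores logged at iterations 0, 20, ..., 120 (before the first evaluation printout).
    static final double[] FIRST_BLOCK = {
        1.5328358650207519, 1.4521272659301758, 1.0928638458251954, 1.0838352203369142,
        1.5488048553466798, 0.9116960525512695, 1.2259522438049317
    };

    // Scores logged at iterations 900, 920, ..., 1000 (last block in the log).
    static final double[] LAST_BLOCK = {
        1.1810382843017577, 0.8167665481567383, 1.4128643035888673,
        0.7820789813995361, 0.8911567687988281, 0.9900901794433594
    };

    /** Arithmetic mean; NaN for an empty array. */
    static double mean(double[] scores) {
        return Arrays.stream(scores).average().orElse(Double.NaN);
    }

    /** Cross-entropy of predicting 1/k for each of k classes: -ln(1/k) = ln(k). */
    static double uniformBaseline(int numClasses) {
        return Math.log(numClasses);
    }

    public static void main(String[] args) {
        // Locale.ROOT keeps a decimal point regardless of the system locale
        System.out.printf(Locale.ROOT, "uniform baseline = %.3f%n", uniformBaseline(5));
        System.out.printf(Locale.ROOT, "first block mean = %.3f%n", mean(FIRST_BLOCK));
        System.out.printf(Locale.ROOT, "last block mean  = %.3f%n", mean(LAST_BLOCK));
    }
}
```

Running `main` prints block means of about 1.264 and 1.012 against the 1.609 baseline, so the averaged loss is lower in the last block than in the first; with only six or seven values per block this is a rough indication, not a measurement.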

The training log from the console is shown in the following block:

Scanning for projects...

-------------------------< com.mycompany:LSTM >-------------------------
Building LSTM 1.0-SNAPSHOT
--------------------------------[ jar ]---------------------------------

--- exec-maven-plugin:1.2.1:exec (default-cli) @ LSTM ---
13:10:01.334 [main] INFO org.nd4j.linalg.factory.Nd4jBackend - Loaded [JCublasBackend] backend
13:10:04.058 [main] INFO org.nd4j.nativeblas.NativeOpsHolder - Number of threads used for NativeOps: 32
13:10:05.568 [main] INFO org.nd4j.nativeblas.Nd4jBlas - Number of threads used for BLAS: 0
13:10:05.572 [main] INFO org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner - Backend used: [CUDA]; OS: [Windows 10]
13:10:05.572 [main] INFO org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner - Cores: [12]; Memory: [3,5GB];
13:10:05.572 [main] INFO org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner - Blas vendor: [CUBLAS]
13:10:05.572 [main] INFO org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner - Device Name: [GeForce GTX 1050 Ti]; CC: [6.1]; Total/free memory: [4294967296]
13:10:05.711 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 0
13:10:05.720 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 2
13:10:05.725 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 1
13:10:05.725 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 5
13:10:06.160 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 4
13:10:06.163 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 3
13:10:27.140 [main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
13:10:28.832 [main] DEBUG org.deeplearning4j.nn.layers.recurrent.LSTM - CudnnLSTMHelper successfully initialized
13:10:28.855 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:10:32.604 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 0 is 1.5328358650207519
13:10:49.904 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 20 is 1.4521272659301758
13:11:07.154 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 40 is 1.0928638458251954
13:11:22.941 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 60 is 1.0838352203369142
13:11:36.821 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 80 is 1.5488048553466798
13:11:48.534 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 100 is 0.9116960525512695
13:12:08.894 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 120 is 1.2259522438049317
13:12:15.793 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:12:15.926 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:12:37.306 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:12:37.327 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample - 

========================Evaluation Metrics========================
 # of classes:    5
 Accuracy:        0,4583
 Precision:       0,5973    (2 classes excluded from average)
 Recall:          0,3800
 F1 Score:        0,6127    (2 classes excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)

Warning: 2 classes were never predicted by the model and were excluded from average precision
Classes excluded from average precision: [0, 1]

=========================Confusion Matrix=========================
  0  1  2  3  4
----------------
  0  0 15  4  0 | 0 = 0
  0  0 13  6  0 | 1 = 1
  0  0 50 53  0 | 2 = 2
  0  0 49 53  1 | 3 = 3
  0  0  2  0 18 | 4 = 4

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:12:37.327 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:12:53.130 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 140 is 2.0276023864746096
13:13:10.387 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 160 is 1.5625686645507812
13:13:25.929 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 180 is 1.5873197555541991
13:13:40.327 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 200 is 0.9586276054382324
13:13:55.432 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 220 is 0.7576540946960449
13:14:11.708 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 240 is 0.875974464416504
13:14:26.234 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:14:26.241 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:14:47.523 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:14:47.531 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample - 

========================Evaluation Metrics========================
 # of classes:    5
 Accuracy:        0,4280
 Precision:       0,5372    (2 classes excluded from average)
 Recall:          0,3806
 F1 Score:        0,5673    (2 classes excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)

Warning: 2 classes were never predicted by the model and were excluded from average precision
Classes excluded from average precision: [0, 1]

=========================Confusion Matrix=========================
  0  1  2  3  4
----------------
  0  0 14  5  0 | 0 = 0
  0  0  6 13  0 | 1 = 1
  0  0 18 85  0 | 2 = 2
  0  0 26 75  2 | 3 = 3
  0  0  0  0 20 | 4 = 4

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:14:47.531 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:14:58.324 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 260 is 0.8601213455200195
13:15:16.516 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 280 is 0.9153075218200684
13:15:32.987 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 300 is 1.5319385528564453
13:15:47.603 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 320 is 0.8646750450134277
13:16:02.376 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 340 is 1.4645065307617187
13:16:17.640 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 360 is 0.8540505409240723
13:16:36.874 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 380 is 1.950326681137085
13:16:36.874 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:16:36.882 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:16:57.005 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:16:57.012 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample - 

========================Evaluation Metrics========================
 # of classes:    5
 Accuracy:        0,4205
 Precision:       0,4696    (2 classes excluded from average)
 Recall:          0,3767
 F1 Score:        0,5170    (2 classes excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)

Warning: 2 classes were never predicted by the model and were excluded from average precision
Classes excluded from average precision: [0, 1]

=========================Confusion Matrix=========================
  0  1  2  3  4
----------------
  0  0  9  7  3 | 0 = 0
  0  0  6 12  1 | 1 = 1
  0  0 12 91  0 | 2 = 2
  0  0 21 79  3 | 3 = 3
  0  0  0  0 20 | 4 = 4

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:16:57.012 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:17:18.203 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 400 is 1.320208740234375
13:17:35.460 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 420 is 1.6308589935302735
13:17:52.665 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 440 is 1.1166643142700194
13:18:04.471 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 460 is 0.9357609748840332
13:18:19.177 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 480 is 1.330645751953125
13:18:38.274 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 500 is 0.7987275123596191
13:18:47.534 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:18:47.541 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:19:06.557 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:19:06.563 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample - 

========================Evaluation Metrics========================
 # of classes:    5
 Accuracy:        0,4356
 Precision:       0,4292    (1 class excluded from average)
 Recall:          0,3931
 F1 Score:        0,4009    (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)

Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]

=========================Confusion Matrix=========================
  0  1  2  3  4
----------------
  1  0  6  7  5 | 0 = 0
  1  0  5 12  1 | 1 = 1
  0  0  9 94  0 | 2 = 2
  1  0 14 85  3 | 3 = 3
  0  0  0  0 20 | 4 = 4

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:19:06.564 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:19:21.384 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 520 is 1.4552995681762695
13:19:38.909 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 540 is 0.7883223533630371
13:19:54.437 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 560 is 0.7487533569335938
13:20:09.842 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 580 is 0.8777185440063476
13:20:24.370 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 600 is 1.3696569442749023
13:20:40.463 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 620 is 0.7666335105895996
13:20:55.188 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:20:55.195 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:21:13.798 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:21:13.805 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample - 

========================Evaluation Metrics========================
 # of classes:    5
 Accuracy:        0,4545
 Precision:       0,5491    (1 class excluded from average)
 Recall:          0,4285
 F1 Score:        0,4959    (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)

Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]

=========================Confusion Matrix=========================
  0  1  2  3  4
----------------
  4  0  8  7  0 | 0 = 0
  1  0  5 12  1 | 1 = 1
  0  0 11 92  0 | 2 = 2
  2  0 15 85  1 | 3 = 3
  0  0  0  0 20 | 4 = 4

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:21:13.806 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:21:23.939 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 640 is 1.6179420471191406
13:21:41.866 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 660 is 1.214307975769043
13:21:56.973 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 680 is 0.922851276397705
13:22:12.932 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 700 is 0.8825350761413574
13:22:26.289 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 720 is 0.8634878158569336
13:22:42.122 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 740 is 1.2621063232421874
13:23:00.384 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 760 is 0.772999382019043
13:23:01.654 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:23:01.661 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:23:20.198 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:23:20.205 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample - 

========================Evaluation Metrics========================
 # of classes:    5
 Accuracy:        0,4924
 Precision:       0,5578    (1 class excluded from average)
 Recall:          0,4651
 F1 Score:        0,5434    (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)

Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]

=========================Confusion Matrix=========================
  0  1  2  3  4
----------------
  6  0  3  8  2 | 0 = 0
  0  0  7 11  1 | 1 = 1
  0  0 31 72  0 | 2 = 2
  5  0 21 73  4 | 3 = 3
  0  0  0  0 20 | 4 = 4

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:23:20.206 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:23:38.008 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 780 is 1.0697845458984374
13:23:55.367 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 800 is 0.7442119121551514
13:24:14.028 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 820 is 1.108755111694336
13:24:25.903 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 840 is 0.6914281845092773
13:24:39.838 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 860 is 0.6886839866638184
13:24:57.245 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 880 is 0.883371639251709
13:25:08.683 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:25:08.690 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:25:27.483 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:25:27.490 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample - 

========================Evaluation Metrics========================
 # of classes:    5
 Accuracy:        0,4318
 Precision:       0,5647    (1 class excluded from average)
 Recall:          0,4169
 F1 Score:        0,5048    (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)

Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]

=========================Confusion Matrix=========================
  0  1  2  3  4
----------------
  4  0 13  2  0 | 0 = 0
  0  0 10  9  0 | 1 = 1
  0  0 23 80  0 | 2 = 2
  2  0 31 67  3 | 3 = 3
  0  0  0  0 20 | 4 = 4

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:25:27.491 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:25:42.416 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 900 is 1.1810382843017577
13:25:59.881 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 920 is 0.8167665481567383
13:26:16.202 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 940 is 1.4128643035888673
13:26:31.338 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 960 is 0.7820789813995361
13:26:44.886 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 980 is 0.8911567687988281
13:27:01.673 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1000 is 0.9900901794433594
13:27:16.626 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:27:16.633 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:27:35.392 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:27:35.399 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample - 
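The summary lines of each evaluation printout can be recomputed from its confusion matrix (rows are actual classes, columns are predicted classes), which makes them easier to read. The sketch below uses the matrix from the first printout in the log above. Precision is undefined for a class that is never predicted (classes 0 and 1 there), and skipping such classes when averaging precision and F1 reproduces the printed values; this rule is inferred from the numbers and the warning text, not taken from the deeplearning4j source. `ConfusionMetrics` and its methods are illustrative names.

```java
import java.util.Locale;

// Illustrative helper: recomputes summary metrics from a confusion matrix
// (rows = actual class, columns = predicted class).
final class ConfusionMetrics {

    // Confusion matrix of the first evaluation printout in the log.
    static final int[][] FIRST_EVAL = {
        {0, 0, 15,  4,  0},
        {0, 0, 13,  6,  0},
        {0, 0, 50, 53,  0},
        {0, 0, 49, 53,  1},
        {0, 0,  2,  0, 18},
    };

    /** Fraction of examples on the diagonal. */
    static double accuracy(int[][] cm) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < cm.length; i++) {
            for (int j = 0; j < cm[i].length; j++) {
                total += cm[i][j];
                if (i == j) {
                    correct += cm[i][j];
                }
            }
        }
        return (double) correct / total;
    }

    /**
     * Macro averages {precision, recall, F1}. Recall is averaged over all classes;
     * precision and F1 only over classes that were predicted at least once.
     */
    static double[] macroAverages(int[][] cm) {
        int k = cm.length;
        double precisionSum = 0, recallSum = 0, f1Sum = 0;
        int predictedClasses = 0;
        for (int c = 0; c < k; c++) {
            int tp = cm[c][c];
            int actual = 0;    // row sum = TP + FN
            int predicted = 0; // column sum = TP + FP
            for (int j = 0; j < k; j++) {
                actual += cm[c][j];
                predicted += cm[j][c];
            }
            // 0 for a class with no examples (a choice of this sketch)
            double recall = actual == 0 ? 0.0 : (double) tp / actual;
            recallSum += recall;
            if (predicted > 0) {
                double precision = (double) tp / predicted;
                precisionSum += precision;
                f1Sum += (precision + recall) == 0 ? 0.0 : 2 * precision * recall / (precision + recall);
                predictedClasses++;
            }
        }
        return new double[] {precisionSum / predictedClasses, recallSum / k, f1Sum / predictedClasses};
    }

    public static void main(String[] args) {
        double[] m = macroAverages(FIRST_EVAL);
        System.out.printf(Locale.ROOT, "accuracy=%.4f precision=%.4f recall=%.4f f1=%.4f%n",
                accuracy(FIRST_EVAL), m[0], m[1], m[2]);
    }
}
```

`main` prints accuracy 0.4583, precision 0.5973, recall 0.3800 and F1 0.6127, matching the first printout. The same matrix shows that most of the errors are classes 2 and 3 being predicted as each other.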

As input I have 2 files for each sample:

  • the features file holds the acceleration values for the 3 axes in separate columns (3 columns): each row has 3 values (x, y and z axis), and the number of rows in the features file differs from sample to sample. For example:
-1.4757843000000002 0.9027405000000001  8.998032
-1.4912566999999999 0.98605347  8.887344
-1.5733795  0.95510864  8.881393
-1.6447906000000003 0.93130493  8.804031
-1.6876373  0.9931946   8.838547
-1.7602386000000003 0.9515380999999999  8.852829
-1.7102509  1.0074768   8.79332
-1.7804718000000002 0.9920044   8.881393
-1.7590485  1.0110474   8.854019000000001
-1.7126312  1.0503235   8.889725
-1.7221526999999999 1.1217346000000001  8.948044
-1.6543120999999998 1.1276855   8.927811
-1.6590729  1.0872191999999998  9.105148
-1.6626433999999997 1.1622008999999998  9.218216 
  • the label file contains only a single int value on the first line, indicating the expected output activity
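A minimal pure-Java sketch of the per-sample layout just described can be used to check the files before they reach DataVec. `SampleFiles` and its methods are hypothetical (they are not part of DataVec), and the tab-separated row and the label value 3 in `main` are made-up examples. The parser treats any run of spaces or tabs as one separator, because the sample rows above are separated by whitespace of varying width; the last line of `main` shows that splitting on a single space character yields an empty field wherever two spaces occur in a row. Whether that matters for the real reader depends on the value of `csvSplitBy`, which the post does not show.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical format checker for one sample (not part of DataVec).
final class SampleFiles {

    /** Parses a features file: one "x y z" row per time step; runs of spaces/tabs count as one separator. */
    static double[][] parseFeatures(String text) {
        List<double[]> rows = new ArrayList<>();
        for (String line : text.split("\\R")) {
            String trimmed = line.trim();
            if (trimmed.isEmpty()) {
                continue; // ignore blank lines, e.g. after a trailing newline
            }
            String[] cols = trimmed.split("\\s+");
            if (cols.length != 3) {
                throw new IllegalArgumentException("expected 3 columns (x y z) but got " + cols.length + ": " + line);
            }
            rows.add(new double[] {
                Double.parseDouble(cols[0]), Double.parseDouble(cols[1]), Double.parseDouble(cols[2])
            });
        }
        return rows.toArray(new double[0][]);
    }

    /** Parses a label file: a single integer class index on the first line. */
    static int parseLabel(String text, int numClasses) {
        int label = Integer.parseInt(text.trim().split("\\R")[0].trim());
        if (label < 0 || label >= numClasses) {
            throw new IllegalArgumentException("label " + label + " outside 0.." + (numClasses - 1));
        }
        return label;
    }

    public static void main(String[] args) {
        // First rows of the sample above; the second row is tab-separated for illustration
        String features = "-1.4757843000000002 0.9027405000000001  8.998032\n"
                + "-1.4912566999999999\t0.98605347\t8.887344\n"
                + "-1.5733795  0.95510864  8.881393\n";
        double[][] seq = parseFeatures(features);
        System.out.println(seq.length + " time steps, label " + parseLabel("3\n", 5)); // 3 time steps, label 3
        // Splitting on one space character keeps an empty field between doubled spaces
        System.out.println("-1.5733795  0.95510864".split(" ").length); // 3
    }
}
```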

Below is my complete LSTM code.

    import java.io.File;
    import java.io.IOException;

    import org.datavec.api.records.reader.SequenceRecordReader;
    import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
    import org.datavec.api.split.NumberedFileInputSplit;
    import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
    import org.deeplearning4j.nn.conf.GradientNormalization;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.LSTM;
    import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
    import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
    import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
    import org.nd4j.linalg.learning.config.Nesterovs;
    import org.nd4j.linalg.lossfunctions.LossFunctions;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    public class ExternalDatasetVersion {

        private static final Logger log = LoggerFactory.getLogger(ExternalDatasetVersion.class);

        // The directories, sample counts and csvSplitBy are defined elsewhere in the program
        // (not shown in the post), so they are passed in as parameters here.
        // Reader initialization errors are propagated instead of being logged and ignored,
        // so a missing file stops the run immediately instead of failing later in the iterator.
        public static void train(File featuresDirTrain, File labelsDirTrain, int trainCount,
                                 File featuresDirTest, File labelsDirTest, int testCount,
                                 String csvSplitBy) throws IOException, InterruptedException {

            int miniBatchSize = 10;
            int numLabelClasses = 5;

            // ----- Load the training data -----
            // Features: one file per sample (0.csv, 1.csv, ...), x/y/z per row, variable length
            SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(0, csvSplitBy);
            trainFeatures.initialize(new NumberedFileInputSplit(
                    featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, trainCount - 1));
            // Labels: one file per sample with a single class index on the first line
            SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
            trainLabels.initialize(new NumberedFileInputSplit(
                    labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, trainCount - 1));

            // ALIGN_END: the single label is aligned with the last time step of each feature sequence
            DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels,
                    miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

            // Normalize the training data
            DataNormalization normalizer = new NormalizerStandardize();
            normalizer.fit(trainData);              // Collect training data statistics
            trainData.reset();
            // Use the collected statistics to normalize on the fly: every DataSet returned by 'trainData' is normalized
            trainData.setPreProcessor(normalizer);

            // ----- Load the test data -----
            // Same process as for the training data.
            SequenceRecordReader testFeatures = new CSVSequenceRecordReader(0, csvSplitBy);
            testFeatures.initialize(new NumberedFileInputSplit(
                    featuresDirTest.getAbsolutePath() + "/%d.csv", 0, testCount - 1));
            SequenceRecordReader testLabels = new CSVSequenceRecordReader();
            testLabels.initialize(new NumberedFileInputSplit(
                    labelsDirTest.getAbsolutePath() + "/%d.csv", 0, testCount - 1));

            DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels,
                    miniBatchSize, numLabelClasses, false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
            testData.setPreProcessor(normalizer);   // Reuse the statistics fitted on the training data

            // ----- Configure the network -----
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(123)    // Random number generator seed for improved repeatability. Optional.
                    .weightInit(WeightInit.XAVIER)
                    .updater(new Nesterovs(0.005))
                    .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)  // Not always required, but helps with this data set
                    .gradientNormalizationThreshold(0.5)
                    .list()
                    .layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(3).nOut(10).build())
                    .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build())
                    .build();

            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
            net.setListeners(new ScoreIterationListener(20));   // Print the score (loss function value) every 20 iterations

            // ----- Train the network (per-epoch test set evaluation is commented out) -----
            int nEpochs = 200;
            String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
            for (int i = 0; i < nEpochs; i++) {
                net.fit(trainData);

                // Evaluate on the test set:
                //Evaluation evaluation = net.evaluate(testData);
                //log.info(String.format(str, i, evaluation.accuracy(), evaluation.f1()));

                testData.reset();
                trainData.reset();
            }

            // Final evaluation on the test set
            log.info(net.evaluate(testData).stats());
            log.info("----- Example Complete -----");
        }
    }
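With `AlignmentMode.ALIGN_END` and label files holding a single value, each sequence contributes one label attached to its last time step, and shorter sequences in a minibatch are padded and masked. The sketch below is a simplified pure-Java illustration of those masks, assuming feature rows start at time step 0 and padding goes at the end; it mirrors the idea behind the alignment mode, not DataVec's actual code, and `AlignEndSketch` and the sequence lengths in `main` are hypothetical.

```java
import java.util.Arrays;

// Simplified illustration (not DataVec code) of the masks for a minibatch of
// variable-length sequences that each have a single label aligned to the end.
final class AlignEndSketch {

    /**
     * Returns {featureMask, labelMask}, each of shape [batch][maxLen].
     * Assumption: feature rows start at time step 0 and padding goes at the end.
     */
    static int[][][] masks(int[] lengths) {
        int maxLen = Arrays.stream(lengths).max().orElse(0);
        int[][] featureMask = new int[lengths.length][maxLen];
        int[][] labelMask = new int[lengths.length][maxLen];
        for (int b = 0; b < lengths.length; b++) {
            if (lengths[b] < 1) {
                throw new IllegalArgumentException("sequence " + b + " has no time steps");
            }
            Arrays.fill(featureMask[b], 0, lengths[b], 1); // real steps 0..len-1, padding after
            labelMask[b][lengths[b] - 1] = 1;              // the one label sits on the last real step
        }
        return new int[][][] {featureMask, labelMask};
    }

    public static void main(String[] args) {
        int[] lengths = {4, 2, 3}; // hypothetical row counts of three feature files
        int[][][] m = masks(lengths);
        for (int b = 0; b < lengths.length; b++) {
            System.out.println("features " + Arrays.toString(m[0][b]) + "  label " + Arrays.toString(m[1][b]));
        }
        // features [1, 1, 1, 1]  label [0, 0, 0, 1]
        // features [1, 1, 0, 0]  label [0, 1, 0, 0]
        // features [1, 1, 1, 0]  label [0, 0, 1, 0]
    }
}
```

Because only the final real step of each sequence carries a label, the loss for a sequence depends on what the LSTM has carried forward through all of its time steps.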

Can anyone help me with this, please?
