I am trying to build a model that classifies 5 kinds of human activity from three-axis accelerometer data using an LSTM neural network. I configured my model based on the deeplearning4j examples, but I think training is not working properly, because the score reported across iterations does not steadily decrease no matter how many epochs I run.
The training logs from the console are shown in the following block:
Scanning for projects...
-------------------------< com.mycompany:LSTM >-------------------------
Building LSTM 1.0-SNAPSHOT
--------------------------------[ jar ]---------------------------------
--- exec-maven-plugin:1.2.1:exec (default-cli) @ LSTM ---
13:10:01.334 [main] INFO org.nd4j.linalg.factory.Nd4jBackend - Loaded [JCublasBackend] backend
13:10:04.058 [main] INFO org.nd4j.nativeblas.NativeOpsHolder - Number of threads used for NativeOps: 32
13:10:05.568 [main] INFO org.nd4j.nativeblas.Nd4jBlas - Number of threads used for BLAS: 0
13:10:05.572 [main] INFO org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner - Backend used: [CUDA]; OS: [Windows 10]
13:10:05.572 [main] INFO org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner - Cores: [12]; Memory: [3,5GB];
13:10:05.572 [main] INFO org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner - Blas vendor: [CUBLAS]
13:10:05.572 [main] INFO org.nd4j.linalg.jcublas.ops.executioner.CudaExecutioner - Device Name: [GeForce GTX 1050 Ti]; CC: [6.1]; Total/free memory: [4294967296]
13:10:05.711 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 0
13:10:05.720 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 2
13:10:05.725 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 1
13:10:05.725 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 5
13:10:06.160 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 4
13:10:06.163 [main] DEBUG org.nd4j.jita.handler.impl.CudaZeroHandler - Creating bucketID: 3
13:10:27.140 [main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
13:10:28.832 [main] DEBUG org.deeplearning4j.nn.layers.recurrent.LSTM - CudnnLSTMHelper successfully initialized
13:10:28.855 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:10:32.604 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 0 is 1.5328358650207519
13:10:49.904 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 20 is 1.4521272659301758
13:11:07.154 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 40 is 1.0928638458251954
13:11:22.941 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 60 is 1.0838352203369142
13:11:36.821 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 80 is 1.5488048553466798
13:11:48.534 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 100 is 0.9116960525512695
13:12:08.894 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 120 is 1.2259522438049317
13:12:15.793 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:12:15.926 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:12:37.306 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:12:37.327 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample -
========================Evaluation Metrics========================
# of classes: 5
Accuracy: 0,4583
Precision: 0,5973 (2 classes excluded from average)
Recall: 0,3800
F1 Score: 0,6127 (2 classes excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
Warning: 2 classes were never predicted by the model and were excluded from average precision
Classes excluded from average precision: [0, 1]
=========================Confusion Matrix=========================
0 1 2 3 4
----------------
0 0 15 4 0 | 0 = 0
0 0 13 6 0 | 1 = 1
0 0 50 53 0 | 2 = 2
0 0 49 53 1 | 3 = 3
0 0 2 0 18 | 4 = 4
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:12:37.327 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:12:53.130 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 140 is 2.0276023864746096
13:13:10.387 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 160 is 1.5625686645507812
13:13:25.929 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 180 is 1.5873197555541991
13:13:40.327 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 200 is 0.9586276054382324
13:13:55.432 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 220 is 0.7576540946960449
13:14:11.708 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 240 is 0.875974464416504
13:14:26.234 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:14:26.241 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:14:47.523 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:14:47.531 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample -
========================Evaluation Metrics========================
# of classes: 5
Accuracy: 0,4280
Precision: 0,5372 (2 classes excluded from average)
Recall: 0,3806
F1 Score: 0,5673 (2 classes excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
Warning: 2 classes were never predicted by the model and were excluded from average precision
Classes excluded from average precision: [0, 1]
=========================Confusion Matrix=========================
0 1 2 3 4
----------------
0 0 14 5 0 | 0 = 0
0 0 6 13 0 | 1 = 1
0 0 18 85 0 | 2 = 2
0 0 26 75 2 | 3 = 3
0 0 0 0 20 | 4 = 4
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:14:47.531 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:14:58.324 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 260 is 0.8601213455200195
13:15:16.516 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 280 is 0.9153075218200684
13:15:32.987 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 300 is 1.5319385528564453
13:15:47.603 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 320 is 0.8646750450134277
13:16:02.376 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 340 is 1.4645065307617187
13:16:17.640 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 360 is 0.8540505409240723
13:16:36.874 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 380 is 1.950326681137085
13:16:36.874 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:16:36.882 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:16:57.005 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:16:57.012 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample -
========================Evaluation Metrics========================
# of classes: 5
Accuracy: 0,4205
Precision: 0,4696 (2 classes excluded from average)
Recall: 0,3767
F1 Score: 0,5170 (2 classes excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
Warning: 2 classes were never predicted by the model and were excluded from average precision
Classes excluded from average precision: [0, 1]
=========================Confusion Matrix=========================
0 1 2 3 4
----------------
0 0 9 7 3 | 0 = 0
0 0 6 12 1 | 1 = 1
0 0 12 91 0 | 2 = 2
0 0 21 79 3 | 3 = 3
0 0 0 0 20 | 4 = 4
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:16:57.012 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:17:18.203 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 400 is 1.320208740234375
13:17:35.460 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 420 is 1.6308589935302735
13:17:52.665 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 440 is 1.1166643142700194
13:18:04.471 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 460 is 0.9357609748840332
13:18:19.177 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 480 is 1.330645751953125
13:18:38.274 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 500 is 0.7987275123596191
13:18:47.534 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:18:47.541 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:19:06.557 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:19:06.563 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample -
========================Evaluation Metrics========================
# of classes: 5
Accuracy: 0,4356
Precision: 0,4292 (1 class excluded from average)
Recall: 0,3931
F1 Score: 0,4009 (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]
=========================Confusion Matrix=========================
0 1 2 3 4
----------------
1 0 6 7 5 | 0 = 0
1 0 5 12 1 | 1 = 1
0 0 9 94 0 | 2 = 2
1 0 14 85 3 | 3 = 3
0 0 0 0 20 | 4 = 4
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:19:06.564 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:19:21.384 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 520 is 1.4552995681762695
13:19:38.909 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 540 is 0.7883223533630371
13:19:54.437 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 560 is 0.7487533569335938
13:20:09.842 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 580 is 0.8777185440063476
13:20:24.370 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 600 is 1.3696569442749023
13:20:40.463 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 620 is 0.7666335105895996
13:20:55.188 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:20:55.195 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:21:13.798 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:21:13.805 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample -
========================Evaluation Metrics========================
# of classes: 5
Accuracy: 0,4545
Precision: 0,5491 (1 class excluded from average)
Recall: 0,4285
F1 Score: 0,4959 (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]
=========================Confusion Matrix=========================
0 1 2 3 4
----------------
4 0 8 7 0 | 0 = 0
1 0 5 12 1 | 1 = 1
0 0 11 92 0 | 2 = 2
2 0 15 85 1 | 3 = 3
0 0 0 0 20 | 4 = 4
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:21:13.806 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:21:23.939 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 640 is 1.6179420471191406
13:21:41.866 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 660 is 1.214307975769043
13:21:56.973 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 680 is 0.922851276397705
13:22:12.932 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 700 is 0.8825350761413574
13:22:26.289 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 720 is 0.8634878158569336
13:22:42.122 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 740 is 1.2621063232421874
13:23:00.384 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 760 is 0.772999382019043
13:23:01.654 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:23:01.661 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:23:20.198 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:23:20.205 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample -
========================Evaluation Metrics========================
# of classes: 5
Accuracy: 0,4924
Precision: 0,5578 (1 class excluded from average)
Recall: 0,4651
F1 Score: 0,5434 (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]
=========================Confusion Matrix=========================
0 1 2 3 4
----------------
6 0 3 8 2 | 0 = 0
0 0 7 11 1 | 1 = 1
0 0 31 72 0 | 2 = 2
5 0 21 73 4 | 3 = 3
0 0 0 0 20 | 4 = 4
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:23:20.206 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:23:38.008 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 780 is 1.0697845458984374
13:23:55.367 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 800 is 0.7442119121551514
13:24:14.028 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 820 is 1.108755111694336
13:24:25.903 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 840 is 0.6914281845092773
13:24:39.838 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 860 is 0.6886839866638184
13:24:57.245 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 880 is 0.883371639251709
13:25:08.683 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:25:08.690 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:25:27.483 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:25:27.490 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample -
========================Evaluation Metrics========================
# of classes: 5
Accuracy: 0,4318
Precision: 0,5647 (1 class excluded from average)
Recall: 0,4169
F1 Score: 0,5048 (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [1]
=========================Confusion Matrix=========================
0 1 2 3 4
----------------
4 0 13 2 0 | 0 = 0
0 0 10 9 0 | 1 = 1
0 0 23 80 0 | 2 = 2
2 0 31 67 3 | 3 = 3
0 0 0 0 20 | 4 = 4
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
13:25:27.491 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:25:42.416 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 900 is 1.1810382843017577
13:25:59.881 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 920 is 0.8167665481567383
13:26:16.202 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 940 is 1.4128643035888673
13:26:31.338 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 960 is 0.7820789813995361
13:26:44.886 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 980 is 0.8911567687988281
13:27:01.673 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1000 is 0.9900901794433594
13:27:16.626 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:27:16.633 [ADSI prefetch thread] DEBUG org.nd4j.linalg.memory.abstracts.Nd4jWorkspace - Steps: 5
13:27:35.392 [main] DEBUG org.deeplearning4j.datasets.iterator.AsyncDataSetIterator - Manually destroying ADSI workspace
13:27:35.399 [main] INFO com.mycompany.lstm.UCISequenceClassificationExample -
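As a sanity check on the reported metrics, the accuracy and macro-averaged recall can be recomputed by hand from a confusion matrix. Below is a minimal sketch in plain Java that does this for the first matrix in the log. The class and method names (`ConfusionMatrixMetrics`, `accuracy`, `recall`, `macroRecall`) are made up for illustration and are not part of deeplearning4j.

```java
import java.util.Locale;

// Hypothetical helper, not a deeplearning4j class.
class ConfusionMatrixMetrics {

    // Fraction of all samples that lie on the diagonal (correct predictions).
    static double accuracy(int[][] m) {
        long correct = 0;
        long total = 0;
        for (int r = 0; r < m.length; r++) {
            for (int c = 0; c < m[r].length; c++) {
                total += m[r][c];
                if (r == c) {
                    correct += m[r][c];
                }
            }
        }
        return (double) correct / total;
    }

    // Recall of one class: diagonal entry divided by the row total (actual count).
    static double recall(int[][] m, int cls) {
        long rowTotal = 0;
        for (int v : m[cls]) {
            rowTotal += v;
        }
        return rowTotal == 0 ? 0.0 : (double) m[cls][cls] / rowTotal;
    }

    // Unweighted mean of the per-class recalls.
    static double macroRecall(int[][] m) {
        double sum = 0.0;
        for (int cls = 0; cls < m.length; cls++) {
            sum += recall(m, cls);
        }
        return sum / m.length;
    }

    public static void main(String[] args) {
        // First confusion matrix from the log: rows are actual classes, columns are predictions.
        int[][] firstEvaluation = {
            {0, 0, 15,  4,  0},
            {0, 0, 13,  6,  0},
            {0, 0, 50, 53,  0},
            {0, 0, 49, 53,  1},
            {0, 0,  2,  0, 18}
        };
        System.out.printf(Locale.ROOT, "accuracy     = %.4f%n", accuracy(firstEvaluation));    // 0.4583
        System.out.printf(Locale.ROOT, "macro recall = %.4f%n", macroRecall(firstEvaluation)); // 0.3800
    }
}
```

The row totals also show how the test set is distributed: classes 2 and 3 have 103 samples each, classes 0 and 1 have 19 each, and class 4 has 20.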
As input I have 2 files for each sample:
- in the features file I have the acceleration values for the 3 axes as 3 separated columns, so each row holds 3 values (x, y, z axis). The number of rows in the features file differs from sample to sample. For example:
-1.4757843000000002 0.9027405000000001 8.998032
-1.4912566999999999 0.98605347 8.887344
-1.5733795 0.95510864 8.881393
-1.6447906000000003 0.93130493 8.804031
-1.6876373 0.9931946 8.838547
-1.7602386000000003 0.9515380999999999 8.852829
-1.7102509 1.0074768 8.79332
-1.7804718000000002 0.9920044 8.881393
-1.7590485 1.0110474 8.854019000000001
-1.7126312 1.0503235 8.889725
-1.7221526999999999 1.1217346000000001 8.948044
-1.6543120999999998 1.1276855 8.927811
-1.6590729 1.0872191999999998 9.105148
-1.6626433999999997 1.1622008999999998 9.218216
- in the label file I have only a single int value on the first line, indicating the expected output activity
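To make the two file layouts concrete, here is a minimal sketch in plain Java that parses one features file and one label file as described above. It assumes the columns are separated by whitespace, as in the sample rows; the actual delimiter is whatever `csvSplitBy` is set to. The class and method names (`AccelerometerSample`, `parseFeatures`, `parseLabel`) and the label value `2` are hypothetical and used only for illustration.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper for illustration, not part of deeplearning4j or DataVec.
class AccelerometerSample {

    // Parses each line into {x, y, z}. Assumes whitespace-separated columns.
    static double[][] parseFeatures(List<String> lines) {
        double[][] rows = new double[lines.size()][];
        for (int i = 0; i < lines.size(); i++) {
            String[] parts = lines.get(i).trim().split("\\s+");
            if (parts.length != 3) {
                throw new IllegalArgumentException(
                    "Line " + (i + 1) + " has " + parts.length + " values, expected 3");
            }
            rows[i] = new double[] {
                Double.parseDouble(parts[0]),
                Double.parseDouble(parts[1]),
                Double.parseDouble(parts[2])
            };
        }
        return rows;
    }

    // The label file holds a single int on its first line.
    static int parseLabel(List<String> lines) {
        return Integer.parseInt(lines.get(0).trim());
    }

    public static void main(String[] args) {
        // In practice the lines could come from java.nio.file.Files.readAllLines(path).
        List<String> featureLines = Arrays.asList(
            "-1.4757843000000002 0.9027405000000001 8.998032",
            "-1.4912566999999999 0.98605347 8.887344",
            "-1.5733795 0.95510864 8.881393");
        List<String> labelLines = Arrays.asList("2"); // made-up label value

        double[][] features = parseFeatures(featureLines);
        int label = parseLabel(labelLines);
        System.out.println("time steps = " + features.length
            + ", values per step = " + features[0].length
            + ", label = " + label);
    }
}
```

A check like this over every file can rule out malformed rows or out-of-range labels before the data reaches the record readers.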
Below is all of my LSTM code.
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(0, csvSplitBy);
SequenceRecordReader trainLabels = new CSVSequenceRecordReader();
try {
    trainFeatures.initialize(new NumberedFileInputSplit(featuresDirTrain.getAbsolutePath() + "/%d.csv", 0, trainCount-1));
    trainLabels.initialize(new NumberedFileInputSplit(labelsDirTrain.getAbsolutePath() + "/%d.csv", 0, trainCount-1));
} catch (IOException ex) {
    java.util.logging.Logger.getLogger(ExternalDatasetVersion.class.getName()).log(Level.SEVERE, null, ex);
} catch (InterruptedException ex) {
    java.util.logging.Logger.getLogger(ExternalDatasetVersion.class.getName()).log(Level.SEVERE, null, ex);
}
int miniBatchSize = 10;
int numLabelClasses = 5;
DataSetIterator trainData = new SequenceRecordReaderDataSetIterator(trainFeatures, trainLabels, miniBatchSize, numLabelClasses,
        false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
//Normalize the training data
DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(trainData); //Collect training data statistics
trainData.reset();
//Use previously collected statistics to normalize on-the-fly. Each DataSet returned by 'trainData' iterator will be normalized
trainData.setPreProcessor(normalizer);
// ----- Load the test data -----
//Same process as for the training data.
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(0, csvSplitBy);
SequenceRecordReader testLabels = new CSVSequenceRecordReader();
try {
    testFeatures.initialize(new NumberedFileInputSplit(featuresDirTest.getAbsolutePath() + "/%d.csv", 0, testCount-1));
    testLabels.initialize(new NumberedFileInputSplit(labelsDirTest.getAbsolutePath() + "/%d.csv", 0, testCount-1));
} catch (IOException ex) {
    java.util.logging.Logger.getLogger(ExternalDatasetVersion.class.getName()).log(Level.SEVERE, null, ex);
} catch (InterruptedException ex) {
    java.util.logging.Logger.getLogger(ExternalDatasetVersion.class.getName()).log(Level.SEVERE, null, ex);
}
DataSetIterator testData = new SequenceRecordReaderDataSetIterator(testFeatures, testLabels, miniBatchSize, numLabelClasses,
        false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
testData.setPreProcessor(normalizer);
// ----- Configure the network -----
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(123) //Random number generator seed for improved repeatability. Optional.
        .weightInit(WeightInit.XAVIER)
        .updater(new Nesterovs(0.005))
        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) //Not always required, but helps with this data set
        .gradientNormalizationThreshold(0.5)
        .list()
        .layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(3).nOut(10).build())
        .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX).nIn(10).nOut(numLabelClasses).build())
        .build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
net.setListeners(new ScoreIterationListener(20)); //Print the score (loss function value) every 20 iterations
// ----- Train the network, evaluating the test set performance at each epoch -----
int nEpochs = 200;
String str = "Test set evaluation at epoch %d: Accuracy = %.2f, F1 = %.2f";
for (int i = 0; i < nEpochs; i++) {
    net.fit(trainData);
    //Evaluate on the test set:
    //Evaluation evaluation = net.evaluate(testData);
    //log.info(String.format(str, i, evaluation.accuracy(), evaluation.f1()));
    testData.reset();
    trainData.reset();
}
Evaluation evaluation = net.evaluate(testData);
log.info(evaluation.stats());
log.info("----- Example Complete -----");
Can anyone help me with this, please?