Я хочу обучить простую сеть LSTM, но я получил исключение
java.lang.IllegalStateException: C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order and not a view. C (result) array: [Rank: 2,Offset: 0 Order: f Shape: [10,1], stride: [1,10]]
Я обучаю простой NN с одной ячейкой LSTM и одной выходной ячейкой для регрессии.
Я создал обучающий набор данных из 10 выборок с переменной длиной последовательности (от 5 до 10) в CSV-файлах, каждая выборка состоит только из одного значения для ввода и одного значения для вывода.
Я создал SequenceRecordReaderDataSetIterator на основе CSVSequenceRecordReader.
Когда я тренирую свою сеть, код выдает исключение.
Я попробовал сгенерировать случайный набор данных, создав итератор набора данных напрямую из INDArray с порядком 'f', и в этом случае код выполняется без ошибок.
Таким образом, проблема, по-видимому, заключается в форме тензоров, которые создаёт CSVSequenceRecordReader.
У кого-нибудь есть такие проблемы?
SingleFileTimeSeriesDataReader.java
package org.mmarini.lstmtest;
import java.io.IOException;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
/**
 * Factory for a {@link DataSetIterator} that reads time-series samples from a
 * numbered set of CSV files, using one sequence reader for the whole data set.
 *
 * <p>
 * Instances only hold configuration; no file is touched until
 * {@link #apply()} is called.
 */
public class SingleFileTimeSeriesDataReader {
    private final String filePattern;
    private final int minFileIdx;
    private final int maxFileIdx;
    private final int numInputs;
    private final int numPossibleLabels;
    private final int miniBatchSize;
    private final boolean regression;

    /**
     * Stores the settings used later by {@link #apply()}.
     *
     * @param filePattern       file name pattern handed to
     *                          {@code NumberedFileInputSplit}
     * @param minFileIdx        first file index handed to
     *                          {@code NumberedFileInputSplit}
     * @param maxFileIdx        last file index handed to
     *                          {@code NumberedFileInputSplit}
     * @param numInputs         value passed as the fourth argument of the
     *                          {@code SequenceRecordReaderDataSetIterator}
     *                          constructor
     * @param numPossibleLabels value passed as the third argument of the
     *                          {@code SequenceRecordReaderDataSetIterator}
     *                          constructor
     * @param miniBatchSize     mini-batch size passed to the iterator
     * @param regression        regression flag passed to the iterator
     */
    public SingleFileTimeSeriesDataReader(final String filePattern, final int minFileIdx, final int maxFileIdx,
            final int numInputs, final int numPossibleLabels, final int miniBatchSize, final boolean regression) {
        this.filePattern = filePattern;
        this.minFileIdx = minFileIdx;
        this.maxFileIdx = maxFileIdx;
        this.numInputs = numInputs;
        this.numPossibleLabels = numPossibleLabels;
        this.miniBatchSize = miniBatchSize;
        this.regression = regression;
    }

    /**
     * Creates a CSV sequence reader over the configured files and wraps it in
     * a data set iterator.
     *
     * @return a new iterator backed by a freshly initialized reader
     * @throws IOException          if the reader cannot be initialized
     * @throws InterruptedException if reader initialization is interrupted
     */
    public DataSetIterator apply() throws IOException, InterruptedException {
        final NumberedFileInputSplit files = new NumberedFileInputSplit(filePattern, minFileIdx, maxFileIdx);
        // Arguments (0, ","): presumably "skip no leading lines" and "comma
        // delimiter" - TODO confirm against the DataVec API.
        final SequenceRecordReader csvReader = new CSVSequenceRecordReader(0, ",");
        csvReader.initialize(files);
        // NOTE(review): numInputs fills the slot that the DL4J API appears to
        // call labelIndex; the two coincide only when the label column directly
        // follows the input columns - confirm this is intended.
        return new SequenceRecordReaderDataSetIterator(csvReader, miniBatchSize, numPossibleLabels, numInputs,
                regression);
    }
}
TestConfBuilder.java
/**
*
*/
package org.mmarini.lstmtest;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
/**
 * Builds the configuration of a two-layer recurrent network: one LSTM layer
 * followed by one RNN output layer.
 *
 * @author mmarini
 */
public class TestConfBuilder {
    private final int noInputUnits;
    private final int noOutputUnits;
    private final int noLstmUnits;

    /**
     * Stores the layer sizes used by {@link #build()}.
     *
     * @param noInputUnits  number of inputs of the LSTM layer
     * @param noOutputUnits number of outputs of the output layer
     * @param noLstmUnits   number of units of the LSTM layer, also used as the
     *                      number of inputs of the output layer
     */
    public TestConfBuilder(final int noInputUnits, final int noOutputUnits, final int noLstmUnits) {
        this.noInputUnits = noInputUnits;
        this.noOutputUnits = noOutputUnits;
        this.noLstmUnits = noLstmUnits;
    }

    /**
     * Assembles the network configuration.
     *
     * <p>
     * The LSTM layer uses TANH activation; the output layer uses IDENTITY
     * activation with the mean squared logarithmic error loss. Weights use
     * XAVIER initialization and the optimizer is stochastic gradient descent.
     *
     * @return the multi-layer configuration listing the LSTM layer and then
     *         the output layer
     */
    public MultiLayerConfiguration build() {
        final LSTM recurrent = new LSTM.Builder()
                .units(noLstmUnits)
                .nIn(noInputUnits)
                .activation(Activation.TANH)
                .build();

        final RnnOutputLayer output = new RnnOutputLayer.Builder(LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
                .activation(Activation.IDENTITY)
                .nOut(noOutputUnits)
                .nIn(noLstmUnits)
                .build();

        return new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list(recurrent, output)
                .build();
    }
}
TestTrainingTest.java
package org.mmarini.lstmtest;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
/**
 * Training smoke test: fits a small LSTM network on CSV sequence samples.
 */
class TestTrainingTest {
    private static final String SAMPLES_FILE = "src/test/resources/datatest/sample_%d.csv";
    private static final int MIN_INPUTS_FILE_IDX = 0;
    private static final int MAX_INPUTS_FILE_IDX = 9;
    private static final int NUM_INPUTS_COLUMN = 1;
    private static final int NUM_LABELS = 1;
    private static final int NUM_HIDDEN_UNITS = 1;
    private static final int MINI_BATCH_SIZE = 10;
    private static final boolean REGRESSION = true;

    /**
     * Builds an in-memory iterator from hand-written feature and label
     * arrays, both created with shape {2, 1, 3} in 'c' order. Prints the
     * iterator's input column count to standard output.
     *
     * @return an iterator over a single (features, labels) pair with batch
     *         size 2
     */
    DataSetIterator createData() {
        final double[] flatFeatures = ArrayUtil
                .flattenDoubleArray(new double[][][] { { { 0.5, 0.2, 0.5 } }, { { 0.5, 1.0, 0.0 } } });
        final INDArray features = Nd4j.create(flatFeatures, new int[] { 2, 1, 3 }, 'c');

        // NOTE(review): this literal is nested as 1x2x3 while the declared
        // shape is {2, 1, 3}; presumably the flattened element order is the
        // same either way - confirm.
        final double[] flatLabels = ArrayUtil
                .flattenDoubleArray(new double[][][] { { { 1.0, -1.0, 1.0 }, { 1.0, -1.0, -1.0 } } });
        final INDArray labels = Nd4j.create(flatLabels, new int[] { 2, 1, 3 }, 'c');

        final Pair<INDArray, INDArray> sample = new Pair<INDArray, INDArray>(features, labels);
        final INDArrayDataSetIterator result = new INDArrayDataSetIterator(Arrays.asList(sample), 2);
        System.out.println(result.inputColumns());
        return result;
    }

    /**
     * Resolves a path template against the current working directory.
     *
     * @param template relative path (may contain format placeholders, which
     *                 are left untouched)
     * @return the absolute path of {@code template} under "."
     */
    private String file(String template) {
        final File resolved = new File(".", template);
        return resolved.getAbsolutePath();
    }

    /**
     * Reads the CSV samples, checks the iterator's reported input and outcome
     * counts, then initializes the network and fits it on the data.
     */
    @Test
    void testBuild() throws IOException, InterruptedException {
        final DataSetIterator data = new SingleFileTimeSeriesDataReader(file(SAMPLES_FILE), MIN_INPUTS_FILE_IDX,
                MAX_INPUTS_FILE_IDX, NUM_INPUTS_COLUMN, NUM_LABELS, MINI_BATCH_SIZE, REGRESSION).apply();
        assertThat(data.inputColumns(), equalTo(NUM_INPUTS_COLUMN));
        assertThat(data.totalOutcomes(), equalTo(NUM_LABELS));

        final MultiLayerConfiguration conf = new TestConfBuilder(NUM_INPUTS_COLUMN, NUM_LABELS, NUM_HIDDEN_UNITS)
                .build();
        final MultiLayerNetwork net = new MultiLayerNetwork(conf);
        assertNotNull(net);
        net.init();
        // No assertion follows: the test passes when fit() returns without throwing.
        net.fit(data);
    }
}
Я не ожидаю исключения, но я получил следующее исключение:
java.lang.IllegalStateException: C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order and not a view. C (result) array: [Rank: 2,Offset: 0 Order: f Shape: [10,1], stride: [1,10]]
at org.nd4j.base.Preconditions.throwStateEx(Preconditions.java:641)
at org.nd4j.base.Preconditions.checkState(Preconditions.java:304)
at org.nd4j.linalg.factory.Nd4j.gemm(Nd4j.java:980)
at org.deeplearning4j.nn.layers.recurrent.LSTMHelpers.backpropGradientHelper(LSTMHelpers.java:696)
at org.deeplearning4j.nn.layers.recurrent.LSTM.backpropGradientHelper(LSTM.java:122)
at org.deeplearning4j.nn.layers.recurrent.LSTM.backpropGradient(LSTM.java:93)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.calcBackpropGradients(MultiLayerNetwork.java:1826)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2644)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2587)
at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:160)
at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:63)
at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fitHelper(MultiLayerNetwork.java:1602)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit(MultiLayerNetwork.java:1521)
at org.mmarini.lstmtest.TestTrainingTest.testBuild(TestTrainingTest.java:77)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at org.junit.platform.commons.util.ReflectionUtils.invokeMethod(ReflectionUtils.java:532)
at org.junit.jupiter.engine.execution.ExecutableInvoker.invoke(ExecutableInvoker.java:115)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.lambda$invokeTestMethod$6(TestMethodTestDescriptor.java:171)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.invokeTestMethod(TestMethodTestDescriptor.java:167)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:114)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:59)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$4(NodeTestTask.java:108)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:98)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:74)
at java.util.ArrayList.forEach(ArrayList.java:1257)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:38)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$4(NodeTestTask.java:112)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:98)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:74)
at java.util.ArrayList.forEach(ArrayList.java:1257)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:38)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$4(NodeTestTask.java:112)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:98)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:74)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.submit(SameThreadHierarchicalTestExecutorService.java:32)
at org.junit.platform.engine.support.hierarchical.HierarchicalTestExecutor.execute(HierarchicalTestExecutor.java:57)
at org.junit.platform.engine.support.hierarchical.HierarchicalTestEngine.execute(HierarchicalTestEngine.java:51)
at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:220)
at org.junit.platform.launcher.core.DefaultLauncher.lambda$execute$6(DefaultLauncher.java:188)
at org.junit.platform.launcher.core.DefaultLauncher.withInterceptedStreams(DefaultLauncher.java:202)
at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:181)
at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:128)
at org.eclipse.jdt.internal.junit5.runner.JUnit5TestReference.run(JUnit5TestReference.java:89)
at org.eclipse.jdt.internal.junit.runner.TestExecution.run(TestExecution.java:41)
at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:541)
at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:763)
at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.run(RemoteTestRunner.java:463)
at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.main(RemoteTestRunner.java:209)