CSVSequenceRecordReader создает совместимый набор данных для обучения сети LSTM? - PullRequest
0 голосов
/ 22 января 2019

Я хочу обучить простую сеть LSTM, но я получил исключение

java.lang.IllegalStateException: C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order and not a view. C (result) array: [Rank: 2,Offset: 0 Order: f Shape: [10,1],  stride: [1,10]]

Я обучаю простой NN с одной ячейкой LSTM и одной выходной ячейкой для регрессии.

Я создал обучающий набор данных из 10 выборок с переменной длиной последовательности (от 5 до 10) в CSV-файлах, каждая выборка состоит только из одного значения для ввода и одного значения для вывода.

Я создал SequenceRecordReaderDataSetIterator из CSVSequenceRecordReader. Когда я тренирую свою сеть, код выдает исключение.

Я попробовал сгенерировать случайный набор данных, создав итератор набора данных напрямую в коде с помощью INDArray в порядке 'f', и в этом случае код выполняется без ошибок.

Таким образом, проблема заключается в форме тензоров, созданных CSVSequenceRecordReader.

У кого-нибудь есть такие проблемы?

SingleFileTimeSeriesDataReader.java

package org.mmarini.lstmtest;

import java.io.IOException;

import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

/**
 * Creates a {@link DataSetIterator} that reads sequences from a numbered set of
 * CSV files.
 *
 * <p>An instance only stores its configuration; the reader and the iterator are
 * created on every call to {@link #apply()}.
 */
public class SingleFileTimeSeriesDataReader {

    // Location of the sample files.
    private final String filePattern;
    private final int minFileIdx;
    private final int maxFileIdx;

    // Settings forwarded to the data set iterator.
    private final int numInputs;
    private final int numPossibleLabels;
    private final int miniBatchSize;
    private final boolean regression;

    /**
     * Stores the configuration used later by {@link #apply()}.
     *
     * @param filePattern       file name pattern handed to the numbered file split
     * @param minFileIdx        first file index handed to the numbered file split
     * @param maxFileIdx        last file index handed to the numbered file split
     * @param numInputs         value passed as the fourth argument of the iterator
     *                          constructor
     * @param numPossibleLabels number of possible labels passed to the iterator
     * @param miniBatchSize     mini batch size passed to the iterator
     * @param regression        regression flag passed to the iterator
     */
    public SingleFileTimeSeriesDataReader(final String filePattern, final int minFileIdx, final int maxFileIdx,
            final int numInputs, final int numPossibleLabels, final int miniBatchSize, final boolean regression) {
        this.filePattern = filePattern;
        this.minFileIdx = minFileIdx;
        this.maxFileIdx = maxFileIdx;
        this.numInputs = numInputs;
        this.numPossibleLabels = numPossibleLabels;
        this.miniBatchSize = miniBatchSize;
        this.regression = regression;
    }

    /**
     * Initializes a CSV sequence reader over the configured files and wraps it in
     * a sequence data set iterator.
     *
     * @return a new iterator backed by a freshly initialized reader
     * @throws IOException          if the reader initialization fails with an I/O error
     * @throws InterruptedException if the reader initialization is interrupted
     */
    public DataSetIterator apply() throws IOException, InterruptedException {
        final SequenceRecordReader csvReader = new CSVSequenceRecordReader(0, ",");
        final NumberedFileInputSplit split = new NumberedFileInputSplit(filePattern, minFileIdx, maxFileIdx);
        csvReader.initialize(split);
        // NOTE(review): numInputs is passed in the fourth constructor position, which
        // is the label column index in the DL4J API - confirm this is the intent.
        return new SequenceRecordReaderDataSetIterator(csvReader, miniBatchSize, numPossibleLabels, numInputs,
                regression);
    }
}

TestConfBuilder.java

/**
 *
 */
package org.mmarini.lstmtest;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;

/**
 * Builds a two-layer network configuration: one LSTM layer followed by one
 * recurrent output layer.
 *
 * @author mmarini
 */
public class TestConfBuilder {

    private final int noInputUnits;
    private final int noOutputUnits;
    private final int noLstmUnits;

    /**
     * Stores the layer sizes used by {@link #build()}.
     *
     * @param noInputUnits  number of inputs of the LSTM layer
     * @param noOutputUnits number of outputs of the output layer
     * @param noLstmUnits   number of units of the LSTM layer, also used as the
     *                      number of inputs of the output layer
     */
    public TestConfBuilder(final int noInputUnits, final int noOutputUnits, final int noLstmUnits) {
        this.noInputUnits = noInputUnits;
        this.noOutputUnits = noOutputUnits;
        this.noLstmUnits = noLstmUnits;
    }

    /**
     * Assembles the configuration from the stored layer sizes.
     *
     * @return the multi layer configuration with the LSTM layer first and the
     *         output layer second
     */
    public MultiLayerConfiguration build() {
        // Hidden recurrent layer with TANH activation.
        final LSTM recurrentLayer = new LSTM.Builder()
                .units(noLstmUnits)
                .nIn(noInputUnits)
                .activation(Activation.TANH)
                .build();

        // Output layer with identity activation and mean squared logarithmic error loss.
        final RnnOutputLayer outputLayer = new RnnOutputLayer.Builder(LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR)
                .activation(Activation.IDENTITY)
                .nOut(noOutputUnits)
                .nIn(noLstmUnits)
                .build();

        return new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list(recurrentLayer, outputLayer)
                .build();
    }
}

TestTrainingTest.java

package org.mmarini.lstmtest;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;

import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.junit.jupiter.api.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Trains the network produced by {@link TestConfBuilder} on the CSV samples read
 * through {@link SingleFileTimeSeriesDataReader}.
 */
class TestTrainingTest {

    private static final int MINI_BATCH_SIZE = 10;
    private static final int NUM_LABELS = 1;
    private static final boolean REGRESSION = true;
    private static final String SAMPLES_FILE = "src/test/resources/datatest/sample_%d.csv";
    private static final int MIN_INPUTS_FILE_IDX = 0;
    private static final int MAX_INPUTS_FILE_IDX = 9;
    private static final int NUM_INPUTS_COLUMN = 1;
    private static final int NUM_HIDDEN_UNITS = 1;

    /**
     * Builds an in-memory iterator over one hand-written features/labels pair.
     * It is not called by {@link #testBuild()}.
     *
     * @return an iterator over the single pair, created with a batch size of 2
     */
    DataSetIterator createData() {
        final int[] featuresShape = new int[] { 2, 1, 3 };
        final double[][][] featuresAry = new double[][][] { { { 0.5, 0.2, 0.5 } }, { { 0.5, 1.0, 0.0 } } };
        final INDArray features = Nd4j.create(ArrayUtil.flattenDoubleArray(featuresAry), featuresShape, 'c');

        // The literal is nested as 1x2x3; only its flattened values are used and the
        // explicit shape below is { 2, 1, 3 }.
        final int[] labelsShape = new int[] { 2, 1, 3 };
        final double[][][] labelsAry = new double[][][] { { { 1.0, -1.0, 1.0 }, { 1.0, -1.0, -1.0 } } };
        final INDArray labels = Nd4j.create(ArrayUtil.flattenDoubleArray(labelsAry), labelsShape, 'c');

        final Pair<INDArray, INDArray> sample = new Pair<>(features, labels);
        final INDArrayDataSetIterator result = new INDArrayDataSetIterator(Arrays.asList(sample), 2);
        System.out.println(result.inputColumns());
        return result;
    }

    /**
     * Resolves a path template against the current directory.
     *
     * @param template the relative path template
     * @return the absolute path of the template under "."
     */
    private String file(String template) {
        final File resolved = new File(".", template);
        return resolved.getAbsolutePath();
    }

    /**
     * Reads the sample files, checks the iterator column counts, then initializes
     * the network and fits it on the data.
     */
    @Test
    void testBuild() throws IOException, InterruptedException {
        final String samplesPath = file(SAMPLES_FILE);
        final SingleFileTimeSeriesDataReader dataReader = new SingleFileTimeSeriesDataReader(samplesPath,
                MIN_INPUTS_FILE_IDX, MAX_INPUTS_FILE_IDX, NUM_INPUTS_COLUMN, NUM_LABELS, MINI_BATCH_SIZE, REGRESSION);
        final DataSetIterator trainingData = dataReader.apply();

        assertThat(trainingData.inputColumns(), equalTo(NUM_INPUTS_COLUMN));
        assertThat(trainingData.totalOutcomes(), equalTo(NUM_LABELS));

        final MultiLayerConfiguration netConf = new TestConfBuilder(NUM_INPUTS_COLUMN, NUM_LABELS, NUM_HIDDEN_UNITS)
                .build();
        final MultiLayerNetwork network = new MultiLayerNetwork(netConf);
        assertNotNull(network);

        network.init();
        network.fit(trainingData);
    }

}

Я не ожидал исключения, но получил следующее:

java.lang.IllegalStateException: C (result) array is not F order or is a view. Nd4j.gemm requires the result array to be F order and not a view. C (result) array: [Rank: 2,Offset: 0 Order: f Shape: [10,1],  stride: [1,10]]
    at org.nd4j.base.Preconditions.throwStateEx(Preconditions.java:641)
    at org.nd4j.base.Preconditions.checkState(Preconditions.java:304)
    at org.nd4j.linalg.factory.Nd4j.gemm(Nd4j.java:980)
    at org.deeplearning4j.nn.layers.recurrent.LSTMHelpers.backpropGradientHelper(LSTMHelpers.java:696)
    at org.deeplearning4j.nn.layers.recurrent.LSTM.backpropGradientHelper(LSTM.java:122)
    at org.deeplearning4j.nn.layers.recurrent.LSTM.backpropGradient(LSTM.java:93)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.calcBackpropGradients(MultiLayerNetwork.java:1826)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2644)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.computeGradientAndScore(MultiLayerNetwork.java:2587)
    at org.deeplearning4j.optimize.solvers.BaseOptimizer.gradientAndScore(BaseOptimizer.java:160)
    at org.deeplearning4j.optimize.solvers.StochasticGradientDescent.optimize(StochasticGradientDescent.java:63)
    at org.deeplearning4j.optimize.Solver.optimize(Solver.java:52)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fitHelper(MultiLayerNetwork.java:1602)
    at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.fit(MultiLayerNetwork.java:1521)
    at org.mmarini.lstmtest.TestTrainingTest.testBuild(TestTrainingTest.java:77)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.junit.platform.commons.util.ReflectionUtils.invokeMethod(ReflectionUtils.java:532)
    at org.junit.jupiter.engine.execution.ExecutableInvoker.invoke(ExecutableInvoker.java:115)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.lambda$invokeTestMethod$6(TestMethodTestDescriptor.java:171)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.invokeTestMethod(TestMethodTestDescriptor.java:167)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:114)
    at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:59)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$4(NodeTestTask.java:108)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:98)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:74)
    at java.util.ArrayList.forEach(ArrayList.java:1257)
    at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:38)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$4(NodeTestTask.java:112)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:98)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:74)
    at java.util.ArrayList.forEach(ArrayList.java:1257)
    at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:38)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$4(NodeTestTask.java:112)
    at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:72)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:98)
    at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:74)
    at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.submit(SameThreadHierarchicalTestExecutorService.java:32)
    at org.junit.platform.engine.support.hierarchical.HierarchicalTestExecutor.execute(HierarchicalTestExecutor.java:57)
    at org.junit.platform.engine.support.hierarchical.HierarchicalTestEngine.execute(HierarchicalTestEngine.java:51)
    at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:220)
    at org.junit.platform.launcher.core.DefaultLauncher.lambda$execute$6(DefaultLauncher.java:188)
    at org.junit.platform.launcher.core.DefaultLauncher.withInterceptedStreams(DefaultLauncher.java:202)
    at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:181)
    at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:128)
    at org.eclipse.jdt.internal.junit5.runner.JUnit5TestReference.run(JUnit5TestReference.java:89)
    at org.eclipse.jdt.internal.junit.runner.TestExecution.run(TestExecution.java:41)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:541)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:763)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.run(RemoteTestRunner.java:463)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.main(RemoteTestRunner.java:209)

1 Ответ

0 голосов
/ 20 февраля 2019

Пожалуйста, смотрите сообщество DL4J Gitter: https://gitter.im/deeplearning4j/deeplearning4j

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...