Настройка усеченного размера BPTT для каждого входа для нескольких входов в deeplearning4j - PullRequest
0 голосов
/ 22 декабря 2018

У меня есть сеть, которая имеет два входа - один - временной ряд (рекуррентный), другой - прямая прямая связь.Следующая часть кода для построителя графиков должна сказать все это:

    final ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder()
            .backpropType(BackpropType.TruncatedBPTT)
            .tBPTTBackwardLength(tbpttSize)
            .tBPTTForwardLength(tbpttSize)
            .addInputs("recurrentInput", "nonRecurrentInput")
            .setInputTypes(
                    InputType.recurrent(numFeaturesRecurrent),
                    InputType.feedForward(numFeaturesNonRecurrent))
            .addLayer("encoder",
                    new LSTM.Builder()
                            .nIn(numFeaturesRecurrent)
                            .nOut(hiddenRecurrentSize)
                            .activation(Activation.TANH)
                            .build(),
                    "recurrentInput")
            .addVertex("thoughtVector",
                    new LastTimeStepVertex("recurrentInput"), "encoder")
            .addVertex("merge",
                    new MergeVertex(), "thoughtVector", "nonRecurrentInput")
            ...

Параметры конфигурации TruncatedBPTT применяются ко всему вводу, и я получаю следующую ошибку:

java.lang.IllegalArgumentException: NDArrayIndex is out of range. Beginning index: 50 must be less than its size: 13
    at org.nd4j.linalg.indexing.NDArrayIndex.validate(NDArrayIndex.java:459)
    at org.nd4j.linalg.indexing.NDArrayIndex.resolve(NDArrayIndex.java:364)
    at org.nd4j.linalg.api.ndarray.BaseNDArray.get(BaseNDArray.java:4996)
    at org.deeplearning4j.nn.graph.ComputationGraph.getSubsetsForTbptt(ComputationGraph.java:3619)
    at org.deeplearning4j.nn.graph.ComputationGraph.doTruncatedBPTT(ComputationGraph.java:3568)
    at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1140)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1098)
    at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1006)
    at org.mypackage.MultivariatePredictorNet.train(MultivariatePredictorNet.java:140)
    at org.mypackage.MultivariatePredictorNet.main(MultivariatePredictorNet.java:209)

13 isточное количество функций в однократном вводе.Итак, как я могу сделать так, чтобы конфигурация TruncatedBPTT применялась только к повторяющимся данным?

...