У меня есть сеть, которая имеет два входа - один - временной ряд (рекуррентный), другой - прямая прямая связь.Следующая часть кода для построителя графиков должна сказать все это:
final ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder()
.backpropType(BackpropType.TruncatedBPTT)
.tBPTTBackwardLength(tbpttSize)
.tBPTTForwardLength(tbpttSize)
.addInputs("recurrentInput", "nonRecurrentInput")
.setInputTypes(
InputType.recurrent(numFeaturesRecurrent),
InputType.feedForward(numFeaturesNonRecurrent))
.addLayer("encoder",
new LSTM.Builder()
.nIn(numFeaturesRecurrent)
.nOut(hiddenRecurrentSize)
.activation(Activation.TANH)
.build(),
"recurrentInput")
.addVertex("thoughtVector",
new LastTimeStepVertex("recurrentInput"), "encoder")
.addVertex("merge",
new MergeVertex(), "thoughtVector", "nonRecurrentInput")
...
Параметры конфигурации TruncatedBPTT применяются ко всему вводу, и я получаю следующую ошибку:
java.lang.IllegalArgumentException: NDArrayIndex is out of range. Beginning index: 50 must be less than its size: 13
at org.nd4j.linalg.indexing.NDArrayIndex.validate(NDArrayIndex.java:459)
at org.nd4j.linalg.indexing.NDArrayIndex.resolve(NDArrayIndex.java:364)
at org.nd4j.linalg.api.ndarray.BaseNDArray.get(BaseNDArray.java:4996)
at org.deeplearning4j.nn.graph.ComputationGraph.getSubsetsForTbptt(ComputationGraph.java:3619)
at org.deeplearning4j.nn.graph.ComputationGraph.doTruncatedBPTT(ComputationGraph.java:3568)
at org.deeplearning4j.nn.graph.ComputationGraph.fitHelper(ComputationGraph.java:1140)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1098)
at org.deeplearning4j.nn.graph.ComputationGraph.fit(ComputationGraph.java:1006)
at org.mypackage.MultivariatePredictorNet.train(MultivariatePredictorNet.java:140)
at org.mypackage.MultivariatePredictorNet.main(MultivariatePredictorNet.java:209)
13 isточное количество функций в однократном вводе.Итак, как я могу сделать так, чтобы конфигурация TruncatedBPTT применялась только к повторяющимся данным?