Мне нужно обновить среду dl4j по крайней мере до альфа-1.0.0, чтобы получить поддержку CUDA 9 (требуется для RTX 2080).
После переключения зависимостей на альфа-1.0.0 моя сеть не работаетбольше не учусь.
Я обновил свой код, следуя руководству по обновлению в заметках о выпуске dl4j.
Поведение не меняется при использовании любой из более высоких версий.Новый код по-прежнему дает желаемые результаты при переходе к 0.9.1
Я не использую предварительную подготовку, поэтому изменение поведения fit () здесь не должно быть проблемой?!
Обучение с0.9.1: imgur.com/51645Lr.png
Обучение с 1.0.0-alpha: imgur.com/Nkirn7i.png
//Network setup
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.trainingWorkspaceMode(WorkspaceMode.SEPARATE)
.inferenceWorkspaceMode(WorkspaceMode.SEPARATE)
.seed(seed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.weightInit(WeightInit.XAVIER)
.updater(new RmsProp.Builder().rmsDecay(0.95).learningRate(learningRate).build())
.l2(1e-4)
.list()
.layer(0, new LSTM.Builder()
.name("LSTM 1")
.nIn(nIn)
.nOut(lstmLayer1Size)
.activation(Activation.TANH)
.gateActivationFunction(Activation.HARDSIGMOID)
.dropOut(dropoutRatio)
.build())
.layer(1, new LSTM.Builder()
.name("LSTM 2")
.nIn(lstmLayer1Size)
.nOut(lstmLayer2Size)
.activation(Activation.TANH)
.gateActivationFunction(Activation.HARDSIGMOID)
.dropOut(dropoutRatio)
.build())
.layer(2, new DenseLayer.Builder()
.name("Dense")
.nIn(lstmLayer2Size)
.nOut(denseLayerSize)
.activation(Activation.RELU)
.build())
.layer(3, new RnnOutputLayer.Builder()
.nIn(denseLayerSize)
.nOut(nOut)
.activation(Activation.IDENTITY)
.lossFunction(LossFunctions.LossFunction.MSE)
.build())
.backpropType(BackpropType.TruncatedBPTT)
.tBPTTForwardLength(truncatedBPTTLength)
.tBPTTBackwardLength(truncatedBPTTLength)
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
//training loop
log.info("Training...");
for (int i = 0; i < epochs; i++) {
while (iterator.hasNext()){ net.fit(iterator.next());
}
iterator.reset(); // reset iterator
net.rnnClearPreviousState(); // clear previous state
}