Итак, сейчас я испытываю болезненное погружение в изучение глубокого обучения с использованием Deep Learning 4j, в частности RL4j и обучения с подкреплением. Я был относительно неудачен в обучении моего компьютера, как играть в змею, но я продолжаю.
В любом случае, так что у меня возникла проблема, которую я не могу решить, я заставлю свою программу работать, пока я go сплю или нахожусь на работе (да, я работаю в основной отрасли) и Когда я проверяю, он выдавал эту ошибку во всех запущенных потоках, и программа полностью остановилась, учтите, что это обычно происходит примерно через час после начала обучения.
Exception in thread "Thread-8" java.lang.RuntimeException: Output from network is not a probability distribution: [[ ?, ?, ?]]
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:82)
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:37)
at org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete.trainSubEpoch(AsyncThreadDiscrete.java:96)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.handleTraining(AsyncThread.java:144)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.run(AsyncThread.java:121)
Вот как я настраиваю свою сеть
private static A3CDiscrete.A3CConfiguration CARTPOLE_A3C =
new A3CDiscrete.A3CConfiguration(
(new java.util.Random()).nextInt(), //Random seed
220, //Max step By epoch
500000, //Max step
6, //Number of threads
50, //t_max
75, //num step noop warmup
0.1, //reward scaling
0.987, //gamma
1.0 //td-error clipping
);
private static final ActorCriticFactorySeparateStdDense.Configuration CARTPOLE_NET_A3C = ActorCriticFactorySeparateStdDense.Configuration
.builder().updater(new Adam(.005)).l2(.01).numHiddenNodes(32).numLayer(3).build();
Кроме того, вход в мою сеть - это вся сетка для моей игры-змеи 16x16, помещенная в один двойной массив.
В том случае, если она имеет какое-то отношение к моей функции вознаграждения, это то, что
if(!snake.inGame()) {
return -5.3; //snake dies
}
if(snake.gotApple()) {
return 5.0+.37*(snake.getLength()); //snake gets apple
}
return 0; //survives
Мой вопрос Как предотвратить возникновение этой ошибки? Я действительно понятия не имею, что происходит, и это делает создание моей сети довольно трудным, да, я уже проверил в Интернете ответы, все, что появляется, это как 2 билета GitHub с 2018 года.
Если это так интересно вам не нужно go копать здесь - это функция от ACPolicy, которая выдает ошибку
public Integer nextAction(INDArray input) {
INDArray output = actorCritic.outputAll(input)[1];
if (rnd == null) {
return Learning.getMaxAction(output);
}
float rVal = rnd.nextFloat();
for (int i = 0; i < output.length(); i++) {
//System.out.println(i + " " + rVal + " " + output.getFloat(i));
if (rVal < output.getFloat(i)) {
return i;
} else
rVal -= output.getFloat(i);
}
throw new RuntimeException("Output from network is not a probability distribution: " + output);
}
Любая помощь, которую вы можете предложить, высоко ценится