RL4J A3 C DeepLearning Создание выходных данных из сети не является распределением вероятностей - PullRequest
0 голосов
/ 26 марта 2020

Итак, сейчас я испытываю болезненное погружение в изучение глубокого обучения с использованием Deep Learning 4j, в частности RL4j и обучения с подкреплением. Я был относительно неудачен в обучении моего компьютера, как играть в змею, но я продолжаю.

В любом случае, так что у меня возникла проблема, которую я не могу решить, я заставлю свою программу работать, пока я go сплю или нахожусь на работе (да, я работаю в основной отрасли) и Когда я проверяю, он выдавал эту ошибку во всех запущенных потоках, и программа полностью остановилась, учтите, что это обычно происходит примерно через час после начала обучения.

Exception in thread "Thread-8" java.lang.RuntimeException: Output from network is not a probability distribution: [[         ?,         ?,         ?]]
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:82)
at org.deeplearning4j.rl4j.policy.ACPolicy.nextAction(ACPolicy.java:37)
at org.deeplearning4j.rl4j.learning.async.AsyncThreadDiscrete.trainSubEpoch(AsyncThreadDiscrete.java:96)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.handleTraining(AsyncThread.java:144)
at org.deeplearning4j.rl4j.learning.async.AsyncThread.run(AsyncThread.java:121)

Вот как я настраиваю свою сеть

    private static A3CDiscrete.A3CConfiguration CARTPOLE_A3C =
        new A3CDiscrete.A3CConfiguration(
                (new java.util.Random()).nextInt(),            //Random seed
                220,            //Max step By epoch
                500000,         //Max step
                6,              //Number of threads
                50,              //t_max
                75,             //num step noop warmup
                0.1,           //reward scaling
                0.987,           //gamma
                1.0           //td-error clipping
        );


private static final ActorCriticFactorySeparateStdDense.Configuration CARTPOLE_NET_A3C =  ActorCriticFactorySeparateStdDense.Configuration
.builder().updater(new Adam(.005)).l2(.01).numHiddenNodes(32).numLayer(3).build();

Кроме того, вход в мою сеть - это вся сетка для моей игры-змеи 16x16, помещенная в один двойной массив.

В том случае, если она имеет какое-то отношение к моей функции вознаграждения, это то, что

if(!snake.inGame()) {
        return -5.3; //snake dies 
    }
    if(snake.gotApple()) {
        return 5.0+.37*(snake.getLength()); //snake gets apple
    }
    return 0; //survives

Мой вопрос Как предотвратить возникновение этой ошибки? Я действительно понятия не имею, что происходит, и это делает создание моей сети довольно трудным, да, я уже проверил в Интернете ответы, все, что появляется, это как 2 билета GitHub с 2018 года.

Если это так интересно вам не нужно go копать здесь - это функция от ACPolicy, которая выдает ошибку

 public Integer nextAction(INDArray input) {
    INDArray output = actorCritic.outputAll(input)[1];
    if (rnd == null) {
        return Learning.getMaxAction(output);
    }
    float rVal = rnd.nextFloat();
    for (int i = 0; i < output.length(); i++) {
        //System.out.println(i + " " + rVal + " " + output.getFloat(i));
        if (rVal < output.getFloat(i)) {
            return i;
        } else
            rVal -= output.getFloat(i);
    }

    throw new RuntimeException("Output from network is not a probability distribution: " + output);
}

Любая помощь, которую вы можете предложить, высоко ценится

1 Ответ

2 голосов
/ 26 марта 2020

То, что вы видите, это то, что ваша сеть работает на NaN. Вот что означают знаки вопроса в исключении. Есть много причин, почему это может произойти. Вы говорите, что запускаете его в течение достаточно долгого времени, поэтому может случиться так, что в какой-то момент вы будете испытывать недостаток или переполнение. Некоторая регуляризация может помочь или некоторое ограничение градиента.

Тем не менее, сам RL4J перерабатывается начиная с бета6 и должен быть в гораздо лучшем состоянии к следующему выпуску.

Если вы хотите попробовать текущее состояние, есть снимки, которые вы можете использовать, а также рабочий пример A3 C на https://github.com/RobAltena/cartpole/blob/master/src/main/java/A3CCartpole.java

Для некоторых Если вам нужна более полная помощь, вам, вероятно, стоит заглянуть на форум сообщества DL4J по адресу community.konduit.ai. Он больше подходит туда-сюда, что, вероятно, понадобится, чтобы помочь вам построить успешный ИИ для вашей игры змея.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...