Моя нейронная сеть не обучается должным образом, потому что я не обучил ее на достаточном количестве данных или из-за неправильного кодирования / параметров? - PullRequest
0 голосов
/ 25 января 2020

Я пытаюсь обучить нейронную сеть на наборе данных kingbase (извините, нет ссылки, поскольку она была недавно удалена). Я основываю свою архитектуру на нейронной сети AlphaZero:

https://kstatic.googleusercontent.com/files/2f51b2a749a284c2e2dfa13911da965f4855092a179469aedd15fbe4efe8f8cbf9c515ef83ac03a6515fa990e6f85fd827dcd477845e806f23a17845072dc7bd

Примечание: сеть AlphaZero была в основном основана на сети Alpha Go Zero, подробно описанной здесь : https://www.nature.com/articles/nature24270.epdf?author_access_token=VJXbVjaSHxFoctQQ4p2k4tRgN0jAjWel9jnR3ZoTv0PVW4gB86EEpGqTRDtpIz-2rmo8-KG06gqVobU5NSCFeHILHcVFUeMsbvwS-lxjqQGg98faovwjxeTUgZAUMnRQ

Одним из основных изменений, которые я сделал, является использование только временного шага 1 в качестве входного сигнала, поскольку на моем компьютере недостаточно места для , Чтобы ограничить эту проблему, я также уменьшил размер пакета до 200 вместо 4096. Другой фактор, который я изменил, - это скорость обучения, которую я изменил до 0,02, поскольку она не работала с 0,2. Несколько других факторов, которые я изменил для экономии памяти, включают изменение ввода таким образом, чтобы вместо отдельных входных плоскостей для фигур противника, используемых в сети AlphaZero, они сжимались в половину плоскостей путем кодирования этих фигур как -1 вместо 1. Я также сжал последние дополнительные плоскости в еще одну плоскость 8x8. Все это объединяется в 7x8x8 вход. Вот краткая информация о сети:

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 7, 8, 8)      0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 7, 8, 256)    18688       input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 7, 8, 256)    1024        conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 7, 8, 256)    0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 7, 8, 256)    590080      activation_1[0][0]               
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 7, 8, 256)    1024        conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 7, 8, 256)    590080      batch_normalization_2[0][0]      
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 7, 8, 256)    18688       input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 7, 8, 256)    1024        conv2d_3[0][0]                   
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 7, 8, 256)    1024        conv2d_4[0][0]                   
__________________________________________________________________________________________________
add_1 (Add)                     (None, 7, 8, 256)    0           batch_normalization_3[0][0]      
                                                                 batch_normalization_4[0][0]      
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 7, 8, 256)    0           add_1[0][0]                      
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 7, 8, 256)    590080      activation_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 7, 8, 256)    1024        conv2d_5[0][0]                   
__________________________________________________________________________________________________
...
conv2d_54 (Conv2D)              (None, 7, 8, 256)    590080      batch_normalization_53[0][0]     
__________________________________________________________________________________________________
conv2d_55 (Conv2D)              (None, 7, 8, 256)    18688       input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_54 (BatchNo (None, 7, 8, 256)    1024        conv2d_54[0][0]                  
__________________________________________________________________________________________________
batch_normalization_55 (BatchNo (None, 7, 8, 256)    1024        conv2d_55[0][0]                  
__________________________________________________________________________________________________
add_18 (Add)                    (None, 7, 8, 256)    0           batch_normalization_54[0][0]     
                                                                 batch_normalization_55[0][0]     
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 7, 8, 256)    0           add_18[0][0]                     
__________________________________________________________________________________________________
conv2d_56 (Conv2D)              (None, 7, 8, 256)    590080      activation_19[0][0]              
__________________________________________________________________________________________________
batch_normalization_56 (BatchNo (None, 7, 8, 256)    1024        conv2d_56[0][0]                  
__________________________________________________________________________________________________
conv2d_57 (Conv2D)              (None, 7, 8, 256)    590080      batch_normalization_56[0][0]     
__________________________________________________________________________________________________
conv2d_58 (Conv2D)              (None, 7, 8, 256)    18688       input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_57 (BatchNo (None, 7, 8, 256)    1024        conv2d_57[0][0]                  
__________________________________________________________________________________________________
batch_normalization_58 (BatchNo (None, 7, 8, 256)    1024        conv2d_58[0][0]                  
__________________________________________________________________________________________________
add_19 (Add)                    (None, 7, 8, 256)    0           batch_normalization_57[0][0]     
                                                                 batch_normalization_58[0][0]     
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 7, 8, 256)    0           add_19[0][0]                     
__________________________________________________________________________________________________
conv2d_60 (Conv2D)              (None, 7, 8, 1)      257         activation_20[0][0]              
__________________________________________________________________________________________________
batch_normalization_60 (BatchNo (None, 7, 8, 1)      4           conv2d_60[0][0]                  
__________________________________________________________________________________________________
conv2d_59 (Conv2D)              (None, 7, 8, 2)      514         activation_20[0][0]              
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 7, 8, 1)      0           batch_normalization_60[0][0]     
__________________________________________________________________________________________________
batch_normalization_59 (BatchNo (None, 7, 8, 2)      8           conv2d_59[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 7, 8, 256)    512         activation_22[0][0]              
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 7, 8, 2)      0           batch_normalization_59[0][0]     
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 7, 8, 256)    0           dense_1[0][0]                    
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 112)          0           activation_21[0][0]              
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 14336)        0           activation_23[0][0]              
__________________________________________________________________________________________________
policy_output (Dense)           (None, 4672)         527936      flatten_1[0][0]                  
__________________________________________________________________________________________________
value_output (Dense)            (None, 1)            14337       flatten_2[0][0]                  
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 73, 8, 8)     0           policy_output[0][0]              
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 1)            0           value_output[0][0]               
==================================================================================================
Total params: 23,399,760
Trainable params: 23,370,058
Non-trainable params: 29,702
__________________________________________________________________________________________________

Однако, это не очень хорошо работало при обучении. Comparison between the success of my neural network and an engine randomly guessing

У меня есть несколько идей относительно того, почему это работает не очень хорошо. Во-первых, я просто не обучал его достаточно долго ... Сеть AlphaZero потратила на обучение гораздо больше игр, чем эта, но это обучение под наблюдением, поэтому я ожидал, что оно даст гораздо больший прогресс, скорее, быстрее. Это также может быть усилено низкой скоростью обучения. Моя последняя теория состоит в том, что я допустил некоторую ошибку в кодировании сети, и она каким-то образом неправильно связана. Я также открыт для других идей о том, что я, возможно, сделал неправильно.

Обновление: Я обнаружил, что выходное значение сети сходится к 1 или -1, а затем к сети вряд ли тренируется из-за исчезающего градиента. Любая помощь в том, как предотвратить это, будет очень признательна.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...