Я пытаюсь обучить нейронную сеть на наборе данных kingbase (извините, нет ссылки, поскольку она была недавно удалена). Я основываю свою архитектуру на нейронной сети AlphaZero:
https://kstatic.googleusercontent.com/files/2f51b2a749a284c2e2dfa13911da965f4855092a179469aedd15fbe4efe8f8cbf9c515ef83ac03a6515fa990e6f85fd827dcd477845e806f23a17845072dc7bd
Примечание: сеть AlphaZero была в основном основана на сети Alpha Go Zero, подробно описанной здесь : https://www.nature.com/articles/nature24270.epdf?author_access_token=VJXbVjaSHxFoctQQ4p2k4tRgN0jAjWel9jnR3ZoTv0PVW4gB86EEpGqTRDtpIz-2rmo8-KG06gqVobU5NSCFeHILHcVFUeMsbvwS-lxjqQGg98faovwjxeTUgZAUMnRQ
Одним из основных изменений, которые я сделал, является использование только временного шага 1 в качестве входного сигнала, поскольку на моем компьютере недостаточно места для , Чтобы ограничить эту проблему, я также уменьшил размер пакета до 200 вместо 4096. Другой фактор, который я изменил, - это скорость обучения, которую я изменил до 0,02, поскольку она не работала с 0,2. Несколько других факторов, которые я изменил для экономии памяти, включают изменение ввода таким образом, чтобы вместо отдельных входных плоскостей для фигур противника, используемых в сети AlphaZero, они сжимались в половину плоскостей путем кодирования этих фигур как -1 вместо 1. Я также сжал последние дополнительные плоскости в еще одну плоскость 8x8. Все это объединяется в 7x8x8 вход. Вот краткая информация о сети:
Model: "model_2"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 7, 8, 8) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 7, 8, 256) 18688 input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 7, 8, 256) 1024 conv2d_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 7, 8, 256) 0 batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 7, 8, 256) 590080 activation_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 7, 8, 256) 1024 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 7, 8, 256) 590080 batch_normalization_2[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 7, 8, 256) 18688 input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 7, 8, 256) 1024 conv2d_3[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 7, 8, 256) 1024 conv2d_4[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 7, 8, 256) 0 batch_normalization_3[0][0]
batch_normalization_4[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 7, 8, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 7, 8, 256) 590080 activation_2[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 7, 8, 256) 1024 conv2d_5[0][0]
__________________________________________________________________________________________________
...
conv2d_54 (Conv2D) (None, 7, 8, 256) 590080 batch_normalization_53[0][0]
__________________________________________________________________________________________________
conv2d_55 (Conv2D) (None, 7, 8, 256) 18688 input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_54 (BatchNo (None, 7, 8, 256) 1024 conv2d_54[0][0]
__________________________________________________________________________________________________
batch_normalization_55 (BatchNo (None, 7, 8, 256) 1024 conv2d_55[0][0]
__________________________________________________________________________________________________
add_18 (Add) (None, 7, 8, 256) 0 batch_normalization_54[0][0]
batch_normalization_55[0][0]
__________________________________________________________________________________________________
activation_19 (Activation) (None, 7, 8, 256) 0 add_18[0][0]
__________________________________________________________________________________________________
conv2d_56 (Conv2D) (None, 7, 8, 256) 590080 activation_19[0][0]
__________________________________________________________________________________________________
batch_normalization_56 (BatchNo (None, 7, 8, 256) 1024 conv2d_56[0][0]
__________________________________________________________________________________________________
conv2d_57 (Conv2D) (None, 7, 8, 256) 590080 batch_normalization_56[0][0]
__________________________________________________________________________________________________
conv2d_58 (Conv2D) (None, 7, 8, 256) 18688 input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_57 (BatchNo (None, 7, 8, 256) 1024 conv2d_57[0][0]
__________________________________________________________________________________________________
batch_normalization_58 (BatchNo (None, 7, 8, 256) 1024 conv2d_58[0][0]
__________________________________________________________________________________________________
add_19 (Add) (None, 7, 8, 256) 0 batch_normalization_57[0][0]
batch_normalization_58[0][0]
__________________________________________________________________________________________________
activation_20 (Activation) (None, 7, 8, 256) 0 add_19[0][0]
__________________________________________________________________________________________________
conv2d_60 (Conv2D) (None, 7, 8, 1) 257 activation_20[0][0]
__________________________________________________________________________________________________
batch_normalization_60 (BatchNo (None, 7, 8, 1) 4 conv2d_60[0][0]
__________________________________________________________________________________________________
conv2d_59 (Conv2D) (None, 7, 8, 2) 514 activation_20[0][0]
__________________________________________________________________________________________________
activation_22 (Activation) (None, 7, 8, 1) 0 batch_normalization_60[0][0]
__________________________________________________________________________________________________
batch_normalization_59 (BatchNo (None, 7, 8, 2) 8 conv2d_59[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 7, 8, 256) 512 activation_22[0][0]
__________________________________________________________________________________________________
activation_21 (Activation) (None, 7, 8, 2) 0 batch_normalization_59[0][0]
__________________________________________________________________________________________________
activation_23 (Activation) (None, 7, 8, 256) 0 dense_1[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten) (None, 112) 0 activation_21[0][0]
__________________________________________________________________________________________________
flatten_2 (Flatten) (None, 14336) 0 activation_23[0][0]
__________________________________________________________________________________________________
policy_output (Dense) (None, 4672) 527936 flatten_1[0][0]
__________________________________________________________________________________________________
value_output (Dense) (None, 1) 14337 flatten_2[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape) (None, 73, 8, 8) 0 policy_output[0][0]
__________________________________________________________________________________________________
activation_24 (Activation) (None, 1) 0 value_output[0][0]
==================================================================================================
Total params: 23,399,760
Trainable params: 23,370,058
Non-trainable params: 29,702
__________________________________________________________________________________________________
Однако, это не очень хорошо работало при обучении.
У меня есть несколько идей относительно того, почему это работает не очень хорошо. Во-первых, я просто не обучал его достаточно долго ... Сеть AlphaZero потратила на обучение гораздо больше игр, чем эта, но это обучение под наблюдением, поэтому я ожидал, что оно даст гораздо больший прогресс, скорее, быстрее. Это также может быть усилено низкой скоростью обучения. Моя последняя теория состоит в том, что я допустил некоторую ошибку в кодировании сети, и она каким-то образом неправильно связана. Я также открыт для других идей о том, что я, возможно, сделал неправильно.
Обновление: Я обнаружил, что выходное значение сети сходится к 1 или -1, а затем к сети вряд ли тренируется из-за исчезающего градиента. Любая помощь в том, как предотвратить это, будет очень признательна.