Обучение не прогрессирует вообще с RNN - PullRequest
0 голосов
/ 06 октября 2019

enter image description here Перевод с японского на английский. (Поскольку в Японии мало информации)

Я использую DL4j для прогнозирования цен на акции. Независимо от того, сколько проб и ошибок, обучение не происходит вообще. В чем причина?

Поскольку считываемые данные нормализованы, я думаю, что нет большой разницы для каждой марки.

После прикрепленного изображения, независимо от того, сколько часов вы изучаете, он останется прежним.

[Input] 0 Начальная цена акций
1 Высокая цена акций
2 Низкая цена акций
3 Конечная цена акций
4 Уникальное значение A
5 Уникальное значение B
Unique valueA и Unique valueB is Это результат вычисления по алгоритму, который я разработал. Он представляет характеристики дня и дает четкую связь с повышением и понижением будущих цен на акции.

[Выходные данные] В настоящее время мы ожидаем up``silent``down в течение 10 дней с тремя выходными данными.

Даже если вы попробуете это с двумя выходными данными up``other, обучение не будет прогрессировать.

Предскажите, будет ли цена акций расти или падать более чем на 10% в течение 10 дней.

Как упоминалось ранее, входные данные [4] и входные данные [5] относятся к скачкам цен на акции и могут быть предсказаны.

Я занимаюсь машинным обучением для устранения ложных срабатываний.

【DataSet】 Существуетогромный набор данных. Это пример. Шаг по времени составляет 100.

===========INPUT===================
[[[    1.0357,    1.0513,    1.0147  ...    0.9730    0.9695,    0.9734], 
  [    1.0388,    1.0341,    1.0022  ...    0.9656    0.9734,    0.9734], 
  [    1.0427,    1.0513,    1.0201  ...    0.9734    0.9734,    0.9738], 
  [    1.0318,    1.0287,    0.9855  ...    0.9656    0.9664,    0.9734], 
  [    1.6755,    1.2601,    0.3456  ...         0    0.7736,    6.7037], 
  [    0.3751,    0.3616,    1.1761  ...    0.2864    0.5629,    0.4131]], 

 [[    1.0821,    1.0801,    1.0601  ...    1.3061    1.3001,    1.2961], 
  [    1.1001,    1.0901,    1.0401  ...    1.2981    1.2821,    1.3381], 
  [    1.1001,    1.0901,    1.0601  ...    1.3101    1.3001,    1.3381], 
  [    1.0801,    1.0621,    1.0241  ...    1.2601    1.2821,    1.2961], 
  [    0.0688,    0.9349,    0.0838  ...    1.6641         0,    0.3009], 
  [    1.3479,    0.0673,         0  ...    0.6315    0.9210,   58.8286]], 

 [[    1.0900,    1.0708,    1.0708  ...    1.0900    1.0836,    1.0772], 
  [    1.0772,    1.0580,    1.1157  ...    1.0836    1.0708,    1.0516], 
  [    1.0964,    1.0708,    1.1157  ...    1.0900    1.0836,    1.0772], 
  [    1.0772,    1.0451,    1.0708  ...    1.0772    1.0708,    1.0516], 
  [    0.2267,    0.9888,         0  ...    0.5150    2.3653,         0], 
  [    1.7519,    0.6446,    1.7559  ...    2.0206    0.2649,    9.3929]], 

  ..., 

 [[    0.9925,    0.9925,    0.9953  ...    1.0109    1.0091,    1.0119], 
  [    0.9925,    0.9971,    0.9888  ...    1.0063    1.0119,    1.0082], 
  [    0.9971,    0.9980,    0.9953  ...    1.0119    1.0119,    1.0128], 
  [    0.9925,    0.9879,    0.9852  ...    1.0063    1.0073,    1.0063], 
  [    0.5251,    0.5053,    0.4411  ...    1.9785    0.9391,    7.1644], 
  [    0.9026,         0,    0.8412  ...    0.1732    0.1513,    0.0467]], 

 [[    0.7903,    0.7764,    0.7875  ...    1.0193    1.0165,    0.9216], 
  [    0.7764,    0.8015,    0.7819  ...    1.0221    1.0109,    0.9160], 
  [    0.7959,    0.8015,    0.7959  ...    1.0584    1.0556,    0.9355], 
  [    0.7764,    0.7764,    0.7819  ...    1.0193    1.0054,    0.8881], 
  [    0.7363,    0.3788,    0.3158  ...    0.4995    0.6567,    6.4159], 
  [    0.1783,    0.3246,    0.8313  ...    0.6535    0.3348,    0.1456]], 

 [[    0.8960,    0.9086,    0.9194  ...    0.9844    0.9553,    0.9875], 
  [    0.8935,    0.9169,    0.9068  ...    0.9680    0.9585,    1.0159], 
  [    0.8960,    0.9219,    0.9225  ...    0.9869    0.9831,    1.0229], 
  [    0.8821,    0.9023,    0.9017  ...    0.9642    0.9553,    0.9818], 
  [    4.5974,         0,    0.6352  ...    0.4522    0.6429,    0.7656], 
  [    0.3374,    0.3450,    1.0265  ...    1.5891    2.2172,    7.6396]]]
=================OUTPUT==================
[[[         0,         0,         0  ...         0         0,         0], 
  [         0,         0,         0  ...         0         0,    1.0000], 
  [         0,         0,         0  ...         0         0,         0]], 

 [[         0,         0,         0  ...         0         0,    1.0000], 
  [         0,         0,         0  ...         0         0,         0], 
  [         0,         0,         0  ...         0         0,         0]], 

 [[         0,         0,         0  ...         0         0,         0], 
  [         0,         0,         0  ...         0         0,         0], 
  [         0,         0,         0  ...         0         0,    1.0000]], 

  ..., 

 [[         0,         0,         0  ...         0         0,         0], 
  [         0,         0,         0  ...         0         0,         0], 
  [         0,         0,         0  ...         0         0,    1.0000]], 

 [[         0,         0,         0  ...         0         0,         0], 
  [         0,         0,         0  ...         0         0,    1.0000], 
  [         0,         0,         0  ...         0         0,         0]], 

 [[         0,         0,         0  ...         0         0,    1.0000], 
  [         0,         0,         0  ...         0         0,         0], 
  [         0,         0,         0  ...         0         0,         0]]]
===========INPUT MASK===================
[[    1.0000,    1.0000,    1.0000  ...    1.0000    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000  ...    1.0000    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000  ...    1.0000    1.0000,    1.0000], 
  ..., 
 [    1.0000,    1.0000,    1.0000  ...    1.0000    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000  ...    1.0000    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000  ...    1.0000    1.0000,    1.0000]]
===========OUTPUT MASK===================
[[         0,         0,         0  ...         0         0,    1.0000], 
 [         0,         0,         0  ...         0         0,    1.0000], 
 [         0,         0,         0  ...         0         0,    1.0000], 
  ..., 
 [         0,         0,         0  ...         0         0,    1.0000], 
 [         0,         0,         0  ...         0         0,    1.0000], 
 [         0,         0,         0  ...         0         0,    1.0000]]

【NeuralNetwork】

NumInputs = 6
NumOutputs = 3
NumLstmLayers = 256
val conf = NeuralNetConfiguration.Builder()
                .seed(19920528)
                .weightInit(WeightInit.XAVIER)
                .miniBatch(true)
                .updater(Adam())
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)

                .list()
                .layer(LSTM.Builder()
                        .nIn(NumInputs)
                        .nOut(NumLstmLayers)
                        .activation(Activation.TANH)
                        .build()
                )
                .layer(LSTM.Builder()
                        .nIn(NumLstmLayers)
                        .nOut(NumLstmLayers)
                        .activation(Activation.TANH)
                        .build()
                )
                .layer(RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(NumLstmLayers)
                        .nOut(NumOutputs)
                        .activation(Activation.SOFTMAX)
                        .build()
                )

//               .backpropType(BackpropType.TruncatedBPTT)
//                .tBPTTForwardLength(5)
//                .tBPTTBackwardLength(5)
                .build()

        val nn = MultiLayerNetwork(conf)
        nn.init()

【другая информация】 минибат = 32 эпохи = 10

...