Прогноз LSTM с учетом состояния для одной горячей метки возвращает случайные числа с плавающей точкой - PullRequest
0 голосов
/ 12 мая 2019

Особенности (форма и один ряд данных):

(900, 11250, 7)
[[ 0.53544164  0.53544164 -0.00109051 ...  0.53420806  0.53420806
   0.49437675]
 [ 0.53544164  0.53435117 -0.00109051 ...  0.4985942   0.5164011
   0.4878163 ]
 [ 0.53435117  0.5338059   0.         ...  0.5168697   0.5262418
   0.50984067]
 ...
 [ 0.51799345  0.5185387   0.         ...  0.418463    0.4343955
   0.418463  ]
 [ 0.51799345  0.51799345  0.         ...  0.42314902  0.43720713
   0.4212746 ]
 [ 0.51799345  0.5174482  -0.00109051 ...  0.42502344  0.44376758
   0.42502344]]

Метки (форма и один ряд данных):

(900, 8)
[[0 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 1]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]

Модель:

model = Sequential()
model.add(LSTM(neurons, batch_input_shape=(window_size, n_steps, inputs_n), stateful=True))
model.add(Dense(outputs_n, activation='linear'))
model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])

Тренировка (в цикле за 1 день данных):

X,y = split_sequence(features, labels, n_steps)
X = X.reshape(X.shape[0], X.shape[1], inputs_n)
history = model.train_on_batch(X, y)

Прогноз (в цикле за 1 день данных):

X,y = split_sequence(features, labels, n_steps)
X = X.reshape(X.shape[0], X.shape[1], inputs_n)
prediction = model.predict_on_batch(X)

Вывод прогнозов:

(900, 8)
[array([[-0.2294541 , -0.0739788 , -0.00822558, ...,  0.14238606,
        -0.12440741,  0.18252842],
       [-0.23193507, -0.07398052, -0.009443  , ...,  0.13879785,
        -0.12469499,  0.18462572],
       [-0.23936558, -0.07401707, -0.0090025 , ...,  0.13221815,
        -0.12486663,  0.18865344],
       ...,
       [-0.30348817, -0.09799367,  0.03219191, ...,  0.10227684,
        -0.11021828,  0.22304526],
       [-0.30362955, -0.09783267,  0.03222591, ...,  0.1022176 ,
        -0.11022741,  0.22306752],
       [-0.30342066, -0.09787753,  0.03216775, ...,  0.10230445,
        -0.11023697,  0.22295707]], dtype=float32)]

Эти предсказания верны? Я ожидал бы что-то близкое к 00 и 11. есть ли способ предсказать вероятность 0 или 1 для каждой метки вместо одной горячей позиции?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...