У меня есть последовательность, которая очень предсказуема. Ниже вы можете увидеть его фрагмент:
deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4])
По сути, это переменное, но все еще предсказуемое число 4 с, затем 8 и 28 каждые три 8.
I хочу построить очень простую модель LSTM для онлайн-прогнозирования: каждый раз, когда приходит новое число, оно добавляется справа от окна. Таким образом, LSTM обучается на старой последовательности, состоящей из [0: seq_length] элементов deque, причем обучающей целью является элемент [seq_length] . Затем окно сдвигается и прогнозирование выполняется для элементов [1: seq_length + 1]. В конце самый левый элемент deque отбрасывается. Моя интуиция подсказывает мне, что это должно заставить сеть запомнить последовательность.
Однако моя сеть отвечает только на 4. Через некоторое время, что удивительно, она начинает отвечать только на 8, пропуская почти все время. И затем, (долго), позже, возвращается к ответу только 4.
Моя модель структурирована, как показано. Естественно, я уже экспериментировал с различными значениями seq_length и lstm_cells , ни одно из которых не дало мне успеха. Это были последние прогоны:
seq_length = 64 #Length of the sequence to be inserted into the LSTM
vocab_size = 4 #Size of the final dense layer of the model
lstm_cells = 16 #Size of the LSTM layer
model = Sequential()
model.add(LSTM(lstm_cells, input_shape=(seq_length, 1)))
model.add(Dense(vocab_size))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])
Ниже описано, как данные готовятся, обучаются и прогнозируются на модели. Переменная sequence является декой, показанной в начале этого поста. Я веду список vocab = [4,8,28] , который построен на времени выполнения при появлении новых чисел, поэтому vocab [i] переводит класс i в соответствующий номер последовательности. Затем я создаю словарь legend , чтобы сделать обратное. Это более-менее постоянное онлайн-сообщение * oop:
while True:
# Receives new number and puts it into the deque:
sequence.append(generateNextNumber())
# At this point, please note that the length of the deque is seq_length + 1.
# Dictionary to convert numbers to classes:
legend = dict([(v, k) for k, v in enumerate(vocab)])
# Converts the deque into a list:
seq_list = list(sequence)
# Each iteration is comprised of 1 training and 1 prediction. These are the training sequence and target:
train_seq = [ [legend[i]] for i in seq_list[:seq_length] ]
train_target = legend[ seq_list[seq_length] ]
# And the prediction sequence just shifts the window by 1:
pred_seq = [ [legend[i]] for i in seq_list[1:] ]
# Batches data into a batch of size 1:
x = np.zeros((1, seq_length, 1))
y = np.zeros((1, vocab_size))
x[0,:] = train_seq
y[0,:] = to_categorical( train_target, num_classes=vocab_size )
# Online training:
model.fit(x=x, y=y, batch_size=1, epochs=1, verbose=0)
# Now that one training step is done, make a prediction:
x_pred = np.zeros((1, seq_length, 1))
x[0,:] = pred_seq
predicted_onehot = model.predict(x_pred)
# Avoids "index out of range" erros when the LSTM vocab is still being built:
predicted_index = min(np.argmax(predicted_onehot), len(vocab)-1)
predicted_number = vocab[ predicted_index ]
# Reverts deque length to seq_length:
sequence.popleft()
И, наконец, это пример выходных данных:
HIT! Current hit rate: 34.753665869071725 (predicted: 4, sequence was: deque([4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))
HIT! Current hit rate: 34.75566735175926 (predicted: 4, sequence was: deque([4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))
Predicted 4 but it was 8
HIT! Current hit rate: 34.75660255820374 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4]))
HIT! Current hit rate: 34.758603766640086 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4]))
HIT! Current hit rate: 34.7606048523142 (predicted: 4, sequence was: deque([4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4]))
HIT! Current hit rate: 34.76260581523739 (predicted: 4, sequence was: deque([4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4]))
HIT! Current hit rate: 34.76460665542095 (predicted: 4, sequence was: deque([4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4]))
HIT! Current hit rate: 34.76660737287616 (predicted: 4, sequence was: deque([4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4]))
Predicted 4 but it was 28
HIT! Current hit rate: 34.767541707556425 (predicted: 4, sequence was: deque([4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 8, 4, 4, 4, 4, 4, 4, 28, 4]))
Что происходит не так?
Большое спасибо заранее.