Если вы внимательно посмотрите на определение этой функции:
def LSTM_cell(hidden_layer_size, batch_size,number_of_layers, dropout=True, dropout_rate=0.8):
# ...
if dropout:
layer = tf.contrib.rnn.DropoutWrapper(layer, output_keep_prob=dropout_rate)
применяет функцию отсева к выходу каждой ячейки LSTM, tf.nn.dropout () и tf.contrib.rnn.DropoutWrapper () произвольно устанавливает некоторый процент от тензора элементов к нулю, вы можете проверить ссылки для более подробной информации. Согласно определению LSTM_cell (), каждый раз, когда вы звоните
o = session.run([model.logits], feed_dict={model.inputs:X_test[0:0+batch_size]})
каждый выходной нейрон каждой ячейки LSTM в вашей модели случайно установлен на ноль с вероятностью 1 - 0,8 = 0,2 = 20%. Следовательно, ваша модель является стохастической, и вы получаете разные результаты даже при работе модели с одинаковыми входными данными.
Выпадение - это метод регуляризации, полезный при обучении нейронных сетей, бесполезно (и, возможно, нелогично) применять его во время режима проверки и тестирования. Я не хочу называть код, который вы упомянули, неверным, но, как правило, выпадающий код реализуется с использованием заполнителя, подобного следующему:
def LSTM_cell(hidden_layer_size, batch_size,number_of_layers, dropout_rate):
# ...
layer = tf.contrib.rnn.BasicLSTMCell(hidden_layer_size)
layer = tf.contrib.rnn.DropoutWrapper(layer, output_keep_prob=dropout_rate)
class StockPredictionRNN(object):
def __init__(...)
# ...
self.dropout_placeholder = tf.placeholder(tf.float32)
cell, init_state = LSTM_cell(hidden_layer_size, batch_size, number_of_layers, self.dropout_placeholder)
Установите показатель отсева, например, 0,8 на этапе обучения:
for i in range(epochs):
# ...
o, c, _ = session.run([model.logits, model.loss, model.opt], feed_dict={model.inputs:X_batch, model.targets:y_batch, model.dropout_placeholder: 0.8})
Отключить отсев, установив показатель отсева на 1,0 во время фазы тестирования:
o = session.run([model.logits], feed_dict={model.inputs:X_test[i:i+batch_size], model.dropout_placeholder: 1.0})
Для получения дополнительной информации об отсеве, пожалуйста, проверьте оригинал бумаги .