Я пытаюсь обучить глубокую сеть LSTM для классификации.
Я подаю в качестве входа вектор характеристик длины 1000 в сеть LSTM.
linput = tf.placeholder(tf.float32,shape=[None,20,1000],name= 'lstm_input')
with tf.device('/gpu:0'):
with tf.variable_scope('deep__lstm__7'):
lstm_cells = [tf.nn.rnn_cell.LSTMCell(hidden_units, state_is_tuple=True) for hidden_units in [1024,1024,1024,1024,1024,1024,512]] # tf.nn.rnn_cell.GRUCell or tf.nn.rnn_cell.BasicRNNCell instead
cells = tf.nn.rnn_cell.MultiRNNCell(lstm_cells, state_is_tuple=True)
init_state = cells.zero_state(5, tf.float32)
rnn_outputs, final_state = tf.nn.dynamic_rnn(cells, linput, initial_state=init_state)
W = tf.get_variable('W', [512,6])
b = tf.get_variable('b', [6],initializer=tf.constant_initializer(0.0))
outputs = tf.reshape(rnn_outputs, [-1, hidden_units])
op1 = tf.reshape(outputs[inds[0]],[1,hidden_units])
op2 = tf.reshape(outputs[inds[1]],[1,hidden_units])
op3 = tf.reshape(outputs[inds[2]],[1,hidden_units])
op4 = tf.reshape(outputs[inds[3]],[1,hidden_units])
op5 = tf.reshape(outputs[inds[4]],[1,hidden_units])
output = tf.concat([op1,op2,op3,op4,op5],axis=0)
logits = tf.matmul(output, W) + b
Я использовал op1, op2, op3, op4, op5, потому что я использовал размер пакета, равный 5, и я использовал 20 временных шагов для каждого в пакете.Итак, я беру выходные данные сети LSTM на последнем шаге по времени для каждого элемента пакета.
Входные данные для моей глубокой сети LSTM являются функциями, извлеченными из кадров видео в KTHнабор данных.
То есть 1-е измерение обозначает размер пакета, второе измерение обозначает количество кадров, т. е. временной интервал, а последнее - длину вектора признаков.
Я выполняю выборку20 кадров случайным образом, чтобы удалить любую избыточность в кадрах.
Теперь проблема, с которой я сталкиваюсь, заключается в том, что ошибка уменьшается, но точность колеблется от 0% до 40% даже после 1000 итераций.
Я попытался увеличить количество слоев с 3 до 7, как теперь показывает код, и попытался перегрузить сеть, но все еще та же проблема.
Может кто-нибудь сказать, что я все еще делаю неправильно?