Недавно я работаю над проектом ", предсказывающим будущие траектории объектов по их прошлым траекториям, используя LSTM в Tensorflow ".(Здесь траектория означает последовательность двухмерных позиций.)
Вход в LSTM - это, конечно, «прошлые траектории», а вывод - «будущие траектории».
Размер мини-пакет фиксируется при обучении.Однако количество прошедших траекторий в мини-партии может быть разным.Например, пусть размер мини-партии равен 10. Если у меня есть только 4 прошлых траектории для текущей итерации обучения, 6 из 10 в мини-партии дополняются нулевым значением.
При расчете потерь для обратного распространения я допускаю, что потери от 6 равны нулю, так что только 4 способствуют обратному распространению.
Проблема, которую я волную, заключается в следующем.Похоже, что Tensorflow все еще вычисляет градиенты для 6, даже если их потеря равна нулю.В результате скорость обучения становится меньше, когда я увеличиваю размер мини-партии, даже если я использую те же данные обучения.
Я также использовал функцию tf.where при расчете потерь.Однако время тренировки не уменьшается.
Как мне сократить время обучения?
Здесь я приложил свой псевдокод для обучения.
# For each frame in a sequence
for f in range(pred_length):
# For each element in a batch
for b in range(batch_size):
with tf.variable_scope("rnnlm") as scope:
if (f > 0 or b > 0):
scope.reuse_variables()
# for each pedestrian in an element
for p in range(MNP):
# ground-truth position
cur_gt_pose = ...
# loss mask
loss_mask_ped = ... # '1' or '0'
# go through RNN decoder
output_states_dec_list[b][p], zero_states_dec_list[b][p] = cell_dec(cur_embed_frm_dec,
zero_states_dec_list[b][p])
# fully connected layer for output
cur_pred_pose_dec = tf.nn.xw_plus_b(output_states_dec_list[b][p], output_wd, output_bd)
# go through embedding function for the next input
prev_embed_frms_dec_list[b][p] = tf.reshape(tf.nn.relu(tf.nn.xw_plus_b(cur_pred_pose_dec, embedding_wd, embedding_bd)), shape=(1, rnn_size))
# calculate MSE loss
mse_loss = tf.reduce_sum(tf.pow(tf.subtract(cur_pred_pose_dec, cur_gt_pose_dec), 2.0))
# only valid ped's traj contributes to the loss
self.loss += tf.multiply(mse_loss, loss_mask_ped)