Вопрос внедрения BRNN - PullRequest
       32

Вопрос внедрения BRNN

0 голосов
/ 29 апреля 2018

Я пытаюсь внедрить двунаправленный RNN с нуля, и у меня проблема с этим. Давайте предположим, что мы внедрили ячейку RNN с заданным количеством скрытых единиц, в этом случае прямой проход для BRNN будет следующим (псевдокод):

def brnn_forward(input):
    hiddden_state = RNN.forward(input)
    reversed_input = reverse(input)
    hiddden_state_reversed = RNN.forward(reversed_input)
    output = concatenate(hiddden_state, hiddden_state_reversed)
    return output

Но тогда я не знаю, как реализовать обратный проход. Я получаю производную ошибку dA (shape = (hidden_units, batch_size, times)) из следующего слоя с формой вывода прямого прохода (если, конечно, у нас не было конкатенации выходов, которые удвоили количество скрытых единиц после прямого прохода). Однако стандартная обратная функция ячейки RNN принимает значение dA в форме прямого входа, поэтому я попробовал:

def brnn_backward(dA):
    h = number_of_hidden_units
    d_hiddden_state = RNN.backward(dA[:h,:,:])
    d_hiddden_state_reversed = RNN.backward(dA[h:,:,:])
    dA_for_previous_layer = d_hiddden_state+d_hiddden_state_reversed
    return dA_for_previous_layer

Но это не сработало и дало мне результаты хуже, чем с однонаправленным RNN. Также я не уверен, как найти производную ошибку для предыдущего слоя (в случае, если у нас есть слой внедрения, например). Может ли кто-нибудь помочь с обратным пасом?

1 Ответ

0 голосов
/ 19 июля 2018

На мой взгляд, назад и вперед в BiRNN - это один и тот же процесс, их различие - последовательность ввода.

Вам не нужно реализовывать определенную функцию для прямой или обратной ячейки, просто создайте нормальную RNN cell, задавайте входные состояния и получайте выходные состояния. В прямой цепочке вы можете передать в нее обычную последовательность ввода и передать обратную последовательность ввода в обратной цепочке.

Если у вас есть время, вы можете прочитать код, который TensorFlow реализует BiRNN на основе обычного RNN. https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/python/ops/rnn.py#L315

...