Невозможно понять tf.nn.raw_rnn - PullRequest
0 голосов
/ 09 мая 2018

В официальной документации из tf.nn.raw_rnn мы имеем структуру emit в качестве третьего вывода loop_fn при первом запуске loop_fn.

Позже emit_structure используется для копирования tf.zeros_like(emit_structure) в записи мини-пакетов, которые заканчиваются на emit = tf.where(finished, tf.zeros_like(emit_structure), emit).

мое непонимание или паршивая документация со стороны Google такова: структура emit None, поэтому tf.where(finished, tf.zeros_like(emit_structure), emit) собирается выбросить ValueError, как это делает tf.zeros_like(None). Может кто-нибудь, пожалуйста, заполните то, что мне здесь не хватает?

1 Ответ

0 голосов
/ 09 мая 2018

Да, документ довольно запутанный в этом месте. Если вы посмотрите на внутренние элементы tf.nn.raw_rnn, то ключевым термином будет "в псевдокоде" , поэтому пример в документе не является точным.

Точный исходный код выглядит следующим образом (может отличаться в зависимости от версии тензорного потока):

if emit_structure is not None:
  flat_emit_structure = nest.flatten(emit_structure)
  flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                    array_ops.shape(emit) for emit in flat_emit_structure]
  flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
  emit_structure = cell.output_size
  flat_emit_size = nest.flatten(emit_structure)
  flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

Таким образом, он обрабатывает случай, когда emit_structure is None, и просто принимает значение cell.output_size. Вот почему на самом деле ничего не ломается.

...