Прогнозирование и обучение RNN с использованием tf.nn.raw_rnn с Eager Execution - PullRequest
0 голосов
/ 04 апреля 2019

В настоящее время я пытаюсь реализовать вариант последовательного вариационного автоэнкодера с активным исполнением.Грубо говоря, я хочу прибегнуть к некоторым «подсказкам» из скрытых переменных при моделировании некоторых последовательных данных.Поскольку модель нейронных сетей не является стандартной, я создал собственную ячейку RNN, а также функцию loop_fn.Затем я передаю эти два в tf.nn.raw_rnn, то есть у меня есть tf.nn.raw_rnn (RNN_cell, loop_fn).

Теперь проблема в том, что я не имею никакой идеи двигаться дальше,ни для прогнозирования, ни для обучения в сети.

Что меня застревает, так это то, что tf.nn.raw_rnn (RNN_cell, loop_fn) выдает числовые значения вместо модели (скажем, как экземпляр класса tf.keras.Model).Итак, что я должен делать с этими числами?Иными словами, как относиться к tf.nn.raw_rnn (RNN_cell, loop_fn) как к модели, которая может потреблять новые входные данные и давать выходные данные?

Я смотрел некоторые учебные блоги о RNN.Тем не менее, ни один из них точно не использует tf.nn.raw_rnn (с нетерпеливым исполнением).У кого-нибудь есть подсказка?

import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()

from tensorflow.keras.models import Model

#defining class containing each module of model
class ModuleBox(tf.keras.Model):
    def __init__(self, latent_dim, intermediate_dim):
        super(ModuleBox, self).__init__()
        self.latent_dim = latent_dim

        self.inference_net = Model(...)

        self.generative_net = Model(...)

class PreSSM(tf.contrib.rnn.RNNCell):
    def __init__(self, latent_dim = 4, intermediate_dim = 50):
        self.input_dim = latent_dim + 4 #note for toy problem

        module = ModuleBox(latent_dim, intermediate_dim)

        self.inference_net = module.inference_net

        self.generative_net = module.generative_net

    @property
    def state_size(self):
        return latent_dim

    @property
    def output_size(self):
        return 2 #(x,y) coordinate

    def __call__(self, inputs, state):
        next_state = self.inference_net(inputs)[-1]
        output = self.generative_net(next_state)
        return output, next_state

#the loop_fn function, needed by tf.nn.raw_rnn
def loop_fn(time, cell_output, cell_state, loop_state):
    emit_output = cell_output # ==None for time == 0
    if cell_output is None: # when time == 0
        next_cell_state = init_state
        emit_output = tf.zeros([output_dim])
    else :
        emit_output = cell_output
        next_cell_state = cell_state

    elements_finished = (time >= seq_length)
    finished = tf.reduce_all(elements_finished)

    if finished :
        next_input = tf.zeros(shape=(output_dim), dtype=tf.float32)
    else :
        next_input = tf.concat([inputs_ta.read(time), next_cell_state],-1)

    next_loop_state = None
    return (elements_finished, next_input, next_cell_state, emit_output, 
          next_loop_state)

#instatiation of RNN_cell
cell = PreSSM()

#the outputs
outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()
...