В настоящее время я пытаюсь реализовать вариант последовательного вариационного автоэнкодера с активным исполнением.Грубо говоря, я хочу прибегнуть к некоторым «подсказкам» из скрытых переменных при моделировании некоторых последовательных данных.Поскольку модель нейронных сетей не является стандартной, я создал собственную ячейку RNN, а также функцию loop_fn.Затем я передаю эти два в tf.nn.raw_rnn, то есть у меня есть tf.nn.raw_rnn (RNN_cell, loop_fn).
Теперь проблема в том, что я не имею никакой идеи двигаться дальше,ни для прогнозирования, ни для обучения в сети.
Что меня застревает, так это то, что tf.nn.raw_rnn (RNN_cell, loop_fn) выдает числовые значения вместо модели (скажем, как экземпляр класса tf.keras.Model).Итак, что я должен делать с этими числами?Иными словами, как относиться к tf.nn.raw_rnn (RNN_cell, loop_fn) как к модели, которая может потреблять новые входные данные и давать выходные данные?
Я смотрел некоторые учебные блоги о RNN.Тем не менее, ни один из них точно не использует tf.nn.raw_rnn (с нетерпеливым исполнением).У кого-нибудь есть подсказка?
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()
from tensorflow.keras.models import Model
#defining class containing each module of model
class ModuleBox(tf.keras.Model):
def __init__(self, latent_dim, intermediate_dim):
super(ModuleBox, self).__init__()
self.latent_dim = latent_dim
self.inference_net = Model(...)
self.generative_net = Model(...)
class PreSSM(tf.contrib.rnn.RNNCell):
def __init__(self, latent_dim = 4, intermediate_dim = 50):
self.input_dim = latent_dim + 4 #note for toy problem
module = ModuleBox(latent_dim, intermediate_dim)
self.inference_net = module.inference_net
self.generative_net = module.generative_net
@property
def state_size(self):
return latent_dim
@property
def output_size(self):
return 2 #(x,y) coordinate
def __call__(self, inputs, state):
next_state = self.inference_net(inputs)[-1]
output = self.generative_net(next_state)
return output, next_state
#the loop_fn function, needed by tf.nn.raw_rnn
def loop_fn(time, cell_output, cell_state, loop_state):
emit_output = cell_output # ==None for time == 0
if cell_output is None: # when time == 0
next_cell_state = init_state
emit_output = tf.zeros([output_dim])
else :
emit_output = cell_output
next_cell_state = cell_state
elements_finished = (time >= seq_length)
finished = tf.reduce_all(elements_finished)
if finished :
next_input = tf.zeros(shape=(output_dim), dtype=tf.float32)
else :
next_input = tf.concat([inputs_ta.read(time), next_cell_state],-1)
next_loop_state = None
return (elements_finished, next_input, next_cell_state, emit_output,
next_loop_state)
#instatiation of RNN_cell
cell = PreSSM()
#the outputs
outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
outputs = outputs_ta.stack()