Сбой встраивания Tensorflow из-за неодинакового графа - PullRequest
0 голосов
/ 01 марта 2019

Я создаю модель видео LSTM + с использованием поиска встраивания GloVe, но мне не удается, потому что мои значения заполнителей явно не находятся в том же графике, что и матрица встраивания GloVe.Соответствующий код:

class BaseModel(object):
    def __init__(self, batch_size, train_json_path, test_json_path):
        # just sets some class variables
        # ...

    def build_retrieval_models(self, sentence_embed_train, visual_feature_train, p_time_stamp, sentence_embed_test, visual_feature_test, time_stamp_test):
        with tf.variable_scope('Base_Model'):
            print 'Building training network...............................\n'
            bottom_positive = self.vid_model('v2s_lt', visual_feature_train, p_time_stamp)
            query_text_train = self.text_model('t2s_lt', sentence_embed_train)

            # reusing variables in model for test network
            tf.get_variable_scope().reuse_variables()
            print 'Building test network...............................\n'
            transformed_video_test = vid_model('v2s_lt', visual_feature_test, time_stamp_test)
            transformed_text_test = self.text_model('t2s_lt', sentence_embed_test)

            return bottom_positive, query_text_train, transformed_video_test, transformed_text_test

    def vid_model(self, name, bottom, time_stamp=None):
        if time_stamp is not None:
            bottom = tf.concat([bottom, time_stamp], axis=1)

        with tf.variable_scope(name):
            fc1 = fc_relu('vid_fc1', bottom, output_dim=self.visual_embedding_dim[0])
            fc2 = fc('vid_fc2', fc1, output_dim=self.visual_embedding_dim[1])

        return fc2

    def text_model(self, name, text_seq_batch):
        with tf.variable_scope(name):
            embedding_mat = tf.get_variable('glove_embedding', initializer=self.glove_embedding, trainable=False)
            embedded_seq = tf.nn.embedding_lookup(embedding_mat, text_seq_batch)

            lstm_top = lstm('text_lstm1', embedded_seq, None, output_dim=self.language_embedding_dim[0],
                    num_layers=1, forget_bias=1.0, apply_dropout=False,
                    concat_output=False)[-1]
            fc1 = fc('text_fc1', lstm_top, output_dim=self.language_embedding_dim[1])

        return fc1

    def init_placeholder(self):
        # text_seq_batch is a set of langauge indices that mark index of vocab word in glove vocab list
        text_seq_batch = tf.placeholder(tf.float32, shape=(self.batch_size, self.sentence_length))
        video_seq_batch = tf.placeholder(tf.float32, shape=(self.batch_size, self.visual_feature_dim))
        video_label_batch = tf.placeholder(tf.float32, shape=(self.batch_size, 2))

        text_seq_batch_test = tf.placeholder(tf.float32, shape=(self.test_batch_size, self.sentence_length))
        video_seq_batch_test = tf.placeholder(tf.float32, shape=(self.test_batch_size, self.visual_feature_dim))
        video_label_batch_test = tf.placeholder(tf.float32, shape=(self.test_batch_size, 2))

        return text_seq_batch, video_seq_batch, video_label_batch, text_seq_batch_test, video_seq_batch_test, video_label_batch_test

    def training(self, loss):
        v_dict = self.trainable_variables()
        vs_optimizer = tf.train.AdamOptimizer(self.vs_lr, name='vs_adam')
        vs_train_op = vs_optimizer.minimize(loss, var_list=v_dict)
        return vs_train_op

    def construct_model(self):
        self.text_seq_batch, self.video_seq_batch, self.video_label_batch, self.text_seq_batch_test, self.video_seq_batch_test, self.video_label_batch_test = self.init_placeholder()

        # build network
        bottom_pos_train, query_text_train, bottom_pos_test, query_text_test = self.build_retrieval_models(self.text_seq_batch, self.video_seq_batch, self.video_label_batch, self.text_seq_batch_test, self.video_seq_batch_test, self.video_label_batch_test)
        # compute loss, method not implemented for purposes of this question
        self.loss = self.compute_ranking_loss(bottom_pos_train, query_text_train)
        self.vs_train_op = self.training(self.loss)

        return self.loss, self.vs_train_op, bottom_pos_test, query_text_test

Вспомогательные функции, такие как fc, fc_relu и lstm, являются стандартными реализациями для полностью связанных слоев и LSTM.Я считаю, что проблема связана с тем, как я определяю вес модели при ее построении, но я не совсем понимаю, почему.В основной функции, которая обучает модель, код выглядит следующим образом:

def train_model():
    batch_size = 2
    max_steps = 30000
    train_json_path = 'data/train_data.json'
    test_json_path = 'data/val_data.json'

    model = base_model.BaseModel(batch_size, train_json_path, test_json_path)
    with tf.Graph().as_default():

        loss, vs_train_op, bottom_pos_test, query_text_test = model.construct_model()
        # create a session for running Ops on the Graph
        sess = tf.Session()
        init = tf.initialize_all_variables()
        sess.run(init)

        for step in xrange(max_steps):
            start_time = time.time()
            feed_dict = model.fill_feed_dict_train_reg()
            _, loss_value = sess.run([vs_train_op, loss], feed_dict=feed_dict)
            duration = time.time() - start_time

            print('Step %d: loss = %.3f (%.3f sec)' % (step, loss_value, duration))

def main(_):
    train_model()

if __name__ == '__main__':
    tf.app.run()

Поскольку все создается после создания графа Tensorflow, не должны ли все Tensors лежать в одном графе?Тем не менее, я получаю ошибку:

ValueError: Tensor("Base_Model/Placeholder:0", shape=(2, 2), dtype=float32) must be from the same graph as Tensor("Base_Model/glove_embedding:0", shape=(400001, 50), dtype=float32_ref).

, которая вытекает из строки embedded_seq = tf.nn.embedding_lookup(embedding_mat, text_seq_batch) в функции text_model BaseModel.

Кажется, что Tensorflow на самом деле жалуется на моюvideo_label_batch переменная, поскольку она имеет форму (2,2), но она не участвует во время поиска встраивания.Так почему же эта проблема в коде?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...