Tensorflow зависает при sess.run (summary_merge) - PullRequest
0 голосов
/ 07 июня 2019

Когда я запускаю tenorflow.Session.run (tf.summary.merge_all ()), он там зависает.

Я делаю kaggle и хочу использовать предварительно обученную модель в качестве позвоночника позвоночника нашей модели. Поэтому я импортирую мета-график предварительно обученной модели и добавляю новые загруженные узлы в загруженный график, чтобы соответствовать форме аннотации целевого набора данных. Все идет хорошо, пока я не попытаюсь визуализировать тренировочный процесс с помощью тензорной доски. Он застрял на sess.run (tf.summary.merge_all ()).

Хотя я не пробовал это солютин ! Кажется, это решение не в моем случае.

Если это так, как я могу переписать свой код в соответствии с этим решением? Пожалуйста, дайте мне полный код. Спасибо!

Моя часть определения графа:

g1 = tf.Graph()
with g1.as_default():
    with tf.name_scope("global_step"):
        global_steps = tf.Variable(0, trainable=False)

    with tf.name_scope("x1"):
        x1 = tf.placeholder(dtype=tf.float32, shape=[None, 160, 160, 3], name='x1')

    with tf.name_scope("x2"):
        x2 = tf.placeholder(dtype=tf.float32, shape=[None, 160, 160, 3], name='x2')

    with tf.name_scope("y_label"):
        y = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='y')

    with tf.variable_scope("model") as scope:
        with tf.name_scope("model_for_x1"):
            saver1 = tf.train.import_meta_graph('./model/model.meta')
        scope.reuse_variables()
        with tf.name_scope("model_for_x2"):
            saver2 = tf.train.import_meta_graph('./model/model.meta') 

    # extract the output from pre-trained model
        .
        .
        .
        .
        .(omit some codes)
        .
        .
        .
        .
    # final output
    with tf.name_scope("final_output"):
        final_output = tf.identity(FC2, 'final_output')

    # loss function
    with tf.name_scope("Loss"):
        loss_per = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=y,
            logits=final_output,
            name='loss_per')
        loss = tf.reduce_mean(loss_per, name='loss_average')

    # Optimizer
    with tf.name_scope("Optimizer"):
        optimizer = tf.train.AdamOptimizer(0.00001, name='Adam2')  # already has a adam optimizer, so rename it
        train_step = optimizer.minimize(loss, global_step=global_steps)

    # visualization
    with tf.name_scope("Summary"):
        loss_summary = tf.summary.scalar("loss", loss)
        merged_summary = tf.summary.merge_all()

Моя учебная часть:

with tf.Session(graph=g1) as sess:
    train_writer = tf.summary.FileWriter("./logs" + "/train", sess.graph)
    val_writer = tf.summary.FileWriter("./logs" + "/val")
    tf.global_variables_initializer().run()  # initialize the weights
    saver1.restore(sess, './model/model')  # cover the weights by pre-trained model
    saver2.restore(sess, './model/model')  # cover the weights by pre-trained model
    print("*****************Start Training!!!******************")
    for epochNum in range(epochs):
        valid_los, summary = sess.run([loss, loss_summary],  <<------ hang here, if I threw loss_summary away, it will run fluently.
                                      feed_dict={
                                        x1: valid_batch_data[0],
                                        x2: valid_batch_data[1],
                                        y: valid_annotation})
        val_writer.add_summary(summary, global_steps)
        for iterNum in range(len(train)//batch_size):
            test_batch_data, test_annotation = next(gen(train, train_person_to_images_map, batch_size, (160, 160)))
            print('**********************************')
            train_los, _, summary = sess.run([loss, train_step, merged_summary],
                                             feed_dict={
                                                x1: test_batch_data[0],
                                                x2: test_batch_data[1],
                                                y: np.reshape(test_annotation, (batch_size, 1))})
            train_writer.add_summary(summary, global_steps)
            print('epoch: %d, iteration: %d, train_loss_per_iter: %f, valid_loss_per_epoch: %f' % (epochNum + 1, iterNum + 1, valid_los))
    train_writer.close()
    val_writer.close()

Там нет никаких сообщений об ошибках, а только зависать там!

...