Tensorflow, tf.train.batch, безрезультатно - PullRequest
0 голосов
/ 06 мая 2018

Я новичок в tf.train.batch, поэтому я написал пример для его тестирования. Когда я запустил код, я не получил результата, и процесс все еще работал.

Вы встречали такую ​​же ситуацию раньше? Большое спасибо заранее!

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf


a = [[1,2,3,4],[1,2,3,4],[1,2,3,4],[1,2,3,4]]
b = [1,2,3,4]
input_queue = tf.train.slice_input_producer([a, b],num_epochs=None,shuffle=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(4):


        x,y = tf.train.batch([a,b], batch_size=2)


        x_,y_ =sess.run([x,y])
        print(x_,y_)

    coord.request_stop()
    coord.join(threads)

Кроме того, работает функция tf.train.slice_input_producer. Когда я игнорирую tf.train.batch, код становится:

import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf


a = [[1,2,3,4],[1,2,3,4],[1,2,3,4],[1,2,3,4]]
b = [1,2,3,4]
input_queue = tf.train.slice_input_producer([a, b],num_epochs=None,shuffle=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(4):

     print(sess.run(input_queue))

coord.request_stop()
coord.join(threads)

Результат:

[array([1, 2, 3, 4]), 1]
[array([1, 2, 3, 4]), 2]
[array([1, 2, 3, 4]), 3]
[array([1, 2, 3, 4]), 4]

1 Ответ

0 голосов
/ 07 мая 2018

Я думаю, что главная проблема в том, что вы не указали enqueue_many до True, чтобы он просто повторял всю партию. Вы можете прочитать больше на официальный документ . Вот рабочий образец:

a = [[1,2,3,4],[1,2,3,4],[1,2,3,4],[1,2,3,4]]
b = [1,2,3,4]
a, b = tf.train.batch([a,b], batch_size=1, num_threads=1, capacity=4, enqueue_many=True)
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess, coord)
    for i in range(4):
        print(sess.run([a,b]))
    coord.request_stop()
    coord.join(threads)
...