Я пытаюсь прочитать набор данных cifar10 и использовать его для обучения модели, поэтому я пытаюсь читать партии и запускать сеанс, как показано ниже:
# Optimizer
opt = tf.train.AdamOptimizer(0.0001)
global_step = tf.get_variable('global_step', initializer=tf.constant(0), trainable=False)
train_op = opt.apply_gradients(zip(grads, var_list), global_step=global_step)
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
image_batch, label_batch = tf.train.batch([x_train, y_train], batch_size=batch_size)
#image_batch_uint8 = tf.cast(image_batch, tf.uint8)
# Train
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
sess.run(tf.global_variables_initializer())
for i in range(10000000):
_loss_value, _reward_value, _ = sess.run([loss, reward, train_op], feed_dict={
images_ph: image_batch,
labels_ph: label_batch
})
if i % 100 == 0:
print('iter: ', i, '\tloss: ', _loss_value, '\treward: ', _reward_value)
Однако я получаю эту ошибку:
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1086, in _run
'feed with key ' + str(feed) + '.')
The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.For reference, the tensor object was Tensor("batch:0", shape=(32, 50000, 32, 32, 3), dtype=uint8) which was passed to the feed with key Tensor("Placeholder:0", shape=(?, 1024), dtype=float32).
Что я делаю не так? Как я могу убедиться, что все наборы данных будут заполнены как эпохи, есть ли более простой способ подачи данных в набор данных ??