Я новичок в Tensorflow.Сейчас я пытаюсь создать простую 4-х слойную полностью подключенную нейронную сеть для классификации набора данных CIFAR-10.Однако на моем тестовом наборе точность нейронной сети на тестовом наборе полностью статична и составляет 11%.
Я знаю, что полностью подключенная нейронная сеть, вероятно, не идеальна для этой задачи, но странно, что сеть вообще не улучшается / не изменяется.Поэтому мне было интересно, если кто-нибудь знает решение моей проблемы.Я скопировал мой код ниже, любая помощь приветствуется!Большое спасибо.
import tensorflow as tf
import numpy as np
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
y_train_one_hot = np.zeros((y_train.shape[0], 10))
for i in range(y_train.shape[0]):
y_train_one_hot[i][y_train[i]] = 1
y_test_one_hot = np.zeros((y_test.shape[0], 10))
for i in range(y_test.shape[0]):
y_test_one_hot[i][y_test[i]] = 1
x = tf.placeholder(dtype=tf.float32, shape=(None, 32, 32, 3), name='X')
y = tf.placeholder(dtype=tf.float32, shape=(None, 10), name='Y')
keep_prob = tf.placeholder(tf.float32)
x_flatten = tf.reshape(x, [-1, 32*32*3])
nn = tf.layers.dense(x_flatten, 1028, activation=tf.nn.relu)
nn = tf.nn.dropout(nn, keep_prob)
nn = tf.layers.dense(nn, 1028, activation=tf.nn.relu)
nn = tf.nn.dropout(nn, keep_prob)
nn = tf.layers.dense(nn, 512, activation=tf.nn.relu)
nn = tf.nn.dropout(nn, keep_prob)
prediction = tf.layers.dense(nn, 10, activation=tf.nn.relu)
cross_entropy = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=prediction)
loss = tf.reduce_mean(cross_entropy)
train_step = tf.train.AdamOptimizer().minimize(loss)
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for iteration in range(1000):
sess.run(train_step, feed_dict={x: x_train[:1000], y: y_train_one_hot[:1000], keep_prob: 0.5})
if iteration % 10 == 0:
acc = sess.run(accuracy, feed_dict={x: x_test[:100], y: y_test_one_hot[:100], keep_prob: 1.0})
print(acc)