Так что я использую тензор потока 1.10 и тренируюсь на подмножестве набора данных ретинопатии.Проблема в том, что он всегда предсказывает, какой класс встречается чаще всего.В моем случае это класс 0. Итак, я немного покопался и наткнулся на то, что называется недостаточной выборкой.Я пропустил все 0 (просто чтобы посмотреть, что происходит), он только предсказывает класс 2 после 0. Очевидно, класс 2 имеет самую высокую частоту.Вот код для оптимизации:
def data_pipe_line(data,checkpoint_path,i_data=None,epoch=5):
place_X=tf.placeholder(tf.float32,[None,500,400,3],name='p1')
place_Y=tf.placeholder(tf.int32,[None],name='p2')
infer_data=tf.data.Dataset.from_tensor_slices((place_X,place_Y))
infer_data=infer_data.batch(100)
iterator=tf.data.Iterator.from_structure(data.output_types,data.output_shapes)
next_image,next_label=iterator.get_next()
Y=tf.one_hot(next_label,5)
Y=tf.cast(Y,tf.float32)
logits=model(next_image,0.7)
print(logits)
print(Y)
train_iterator=iterator.make_initializer(data,name='train_op')
inference_iterator_op=iterator.make_initializer(infer_data,name='inference_op')
with tf.name_scope("loss"):
loss=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y,logits=logits),name='cost')
#the learning rate is so low because the batch-size is very small and has a lot of noise
optimizer=tf.train.AdamOptimizer(learning_rate=0.0005).minimize(loss)
#getting the accuracy
prediction=tf.argmax(logits,1,name='pred')
equality=tf.equal(prediction,tf.argmax(Y,1))
accuracy=tf.reduce_mean(tf.cast(equality,tf.float32))
init_op=tf.global_variables_initializer()
tf.summary.scalar("loss",loss)
tf.summary.scalar("accuracy",accuracy)
merged=tf.summary.merge_all()
saver=tf.train.Saver()
j=0
with tf.Session() as sess:
writer=tf.summary.FileWriter("./nn_logs",sess.graph)
sess.run(init_op)
for _ in range(epoch):
sess.run(train_iterator)
while True:
try:
#print(sess.run(logits))
j=j+1
summary = sess.run(merged)
_,acc,l=sess.run([optimizer,accuracy,loss])
if(j%20==0 or j==1):
print("iters: {}, loss: {:.10f}, training accuracy: {:.2f}".format(j, l, acc*100))
writer.add_summary(summary,j)
except tf.errors.OutOfRangeError:
break
saver.save(sess,checkpoint_path)
Модель хорошо тренируется, потери снижаются на некоторое время, а затем просто колеблются там (в диапазоне 5).Точность курса сильно колеблется, так как прогнозируется только 1 класс.