Здесь есть две разные ошибки:
predicted_y = tf.nn.softmax(logits)
probas=tf.argmax(predicted_y, axis=1)
Первая состоит в том, что, поскольку ваш y
не закодирован как горячий код, вы не должны использовать softmax
, но sigmoid
(что-то вы правильно делаете в своем loss
определении); Итак, первая строка должна быть
predicted_y = tf.nn.sigmoid(logits)
Вторая строка, опять же, поскольку ваш y
не кодируется в одноразовом формате, не выполняет то, что вы думаете, так как ваши предсказания являются одноэлементными массивами argmax
по определению 0, поэтому вы не получите правильное преобразование вероятностей в сложные прогнозы (какие жесткие прогнозы в любом случае не используются для расчета RO C - для этого вам нужны вероятности) .
Вы должны сбросить probas
в целом и изменить свой prediction_function
на:
prediction_function=lambda vector1: predicted_y.eval({input_x:vector1})
Таким образом, а для learning_rate=0.1
AU C перейдет в 1.0 от самая первая итерация:
loss at iter 0:0.0085
train auc: 0.9998902365402557
test auc: 1.0
loss at iter 1:0.0066
train auc: 1.0
test auc: 1.0
loss at iter 2:0.0052
train auc: 1.0
test auc: 1.0
loss at iter 3:0.0042
train auc: 1.0
test auc: 1.0
loss at iter 4:0.0035
train auc: 1.0
test auc: 1.0
и вы получите правильные прогнозы для X_train
:
np.round(prediction_function(X_train)).reshape(1,-1)
# result:
array([[0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1.,
1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.,
1., 1., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0.,
1., 1., 0., 0., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1.,
0., 1., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1., 1., 1., 1.,
0., 0., 1., 1., 0., 1., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1.,
0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1.,
1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 1., 1., 0., 0.,
0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1.,
0., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 0.,
1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 0., 0., 1., 1., 0.,
1., 1., 1., 1., 0., 0., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0.,
0., 0., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 1., 0., 0., 1.,
1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1.,
1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.,
1., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 0., 0., 0., 1., 1.,
0., 1., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1.]],
dtype=float32)