В настоящее время я изучаю мультиклассовую классификацию с использованием регрессии logisti c в Tensoflow 2.0. Я думаю, что построил модель правильно, но во время тренировок я постоянно получаю 0.333 потери и точность 66.67% для каждой эпохи. Я считаю, что потери должны быть уменьшены, а точность увеличивается для каждой эпохи, но она остается неизменной, несмотря ни на что
Может кто-нибудь сказать мне, что происходит с моей моделью? И почему оно не сходится?
def logisticRegression2(x,weight,bias):
lr = tf.add(tf.matmul(x,weight),bias)
#return sigmoid fun
#return tf.nn.signmoid(lr)
return lr
def crossEntropy2(yTrue,yPredicted):
loss = tf.nn.softmax(yPredicted)
# reduce_mean: Computes the mean of elements across dimensions of a tensor.
return tf.reduce_mean(loss)
def getAccuracy2(y_true, y_pred):
y_true = tf.cast(y_true, dtype=tf.int64)
preds = tf.cast(tf.argmax(y_pred, axis=0), dtype=tf.int64)
preds = tf.equal(y_true, preds)
return tf.reduce_mean(tf.cast(preds, dtype=tf.float64))
def gradientDescent2(x,y,weight,bias):
with tf.GradientTape() as tape:
yPredicted = logisticRegression2(x,weight,bias)
lossValue = crossEntropy2(y,yPredicted)
return tape.gradient(lossValue, [weight,bias] )
learningRate = 0.01
batchSize = 128
n_batches = 10000
optimizer2 = tf.optimizers.SGD(learningRate)
dataset2 = tf.data.Dataset.from_tensor_slices((trX, trY))
dataset2 = dataset2.repeat().shuffle(xTrain.shape[0]).batch(batchSize)
weight = tf.Variable(tf.zeros([4,3], dtype = tf.float64))
bias = tf.Variable(tf.zeros([3], dtype = tf.float64))
Вот мое обучение l oop:
predicted = []
for i, (xx2,xy2) in enumerate(dataset2.take(10000), 1):
gradient = gradientDescent2(xx2,xy2,weight,bias)
optimizer2.apply_gradients(zip(gradient,[weight,bias]))
yPredicted = logisticRegression2(xx2,weight,bias)
loss = crossEntropy2(xy2,yPredicted)
accuracy = getAccuracy2(xy2,yPredicted)
print("Batch number: %i, loss: %f, accuracy: %f" % (i, loss, accuracy*100))
Вывод:
Batch number: 1, loss: 0.333333, accuracy: 66.666667
Batch number: 2, loss: 0.333333, accuracy: 66.666667
Batch number: 3, loss: 0.333333, accuracy: 66.666667
Batch number: 9998, loss: 0.333333, accuracy: 66.666667
Batch number: 9999, loss: 0.333333, accuracy: 66.666667
Batch number: 10000, loss: 0.333333, accuracy: 66.666667