Мультикласс логистика c регрессия Tensorflow 2.0 - PullRequest
0 голосов
/ 13 апреля 2020

В настоящее время я изучаю мультиклассовую классификацию с использованием регрессии logisti c в Tensoflow 2.0. Я думаю, что построил модель правильно, но во время тренировок я постоянно получаю 0.333 потери и точность 66.67% для каждой эпохи. Я считаю, что потери должны быть уменьшены, а точность увеличивается для каждой эпохи, но она остается неизменной, несмотря ни на что

Может кто-нибудь сказать мне, что происходит с моей моделью? И почему оно не сходится?

def logisticRegression2(x,weight,bias):
    lr = tf.add(tf.matmul(x,weight),bias)
    #return sigmoid fun
    #return tf.nn.signmoid(lr)
    return lr

def crossEntropy2(yTrue,yPredicted):
    loss = tf.nn.softmax(yPredicted)
    # reduce_mean: Computes the mean of elements across dimensions of a tensor.
    return tf.reduce_mean(loss)

def getAccuracy2(y_true, y_pred):
    y_true = tf.cast(y_true, dtype=tf.int64)
    preds = tf.cast(tf.argmax(y_pred, axis=0), dtype=tf.int64)
    preds = tf.equal(y_true, preds)
    return tf.reduce_mean(tf.cast(preds, dtype=tf.float64))


def gradientDescent2(x,y,weight,bias):
    with tf.GradientTape() as tape:
        yPredicted = logisticRegression2(x,weight,bias)
        lossValue = crossEntropy2(y,yPredicted)
        return tape.gradient(lossValue, [weight,bias] )

learningRate = 0.01
batchSize = 128
n_batches = 10000
optimizer2 = tf.optimizers.SGD(learningRate)


dataset2 = tf.data.Dataset.from_tensor_slices((trX, trY))
dataset2 = dataset2.repeat().shuffle(xTrain.shape[0]).batch(batchSize)

weight = tf.Variable(tf.zeros([4,3], dtype = tf.float64))
bias = tf.Variable(tf.zeros([3], dtype = tf.float64))

Вот мое обучение l oop:

predicted = []
for i, (xx2,xy2) in enumerate(dataset2.take(10000), 1):
    gradient = gradientDescent2(xx2,xy2,weight,bias)
    optimizer2.apply_gradients(zip(gradient,[weight,bias]))

    yPredicted = logisticRegression2(xx2,weight,bias)
    loss = crossEntropy2(xy2,yPredicted)
    accuracy = getAccuracy2(xy2,yPredicted)
    print("Batch number: %i, loss: %f, accuracy: %f" % (i, loss, accuracy*100))

Вывод:

Batch number: 1, loss: 0.333333, accuracy: 66.666667
Batch number: 2, loss: 0.333333, accuracy: 66.666667
Batch number: 3, loss: 0.333333, accuracy: 66.666667

Batch number: 9998, loss: 0.333333, accuracy: 66.666667
Batch number: 9999, loss: 0.333333, accuracy: 66.666667
Batch number: 10000, loss: 0.333333, accuracy: 66.666667
...