Ошибка перекрестной энтропии остается неизменной для различных значений - PullRequest
1 голос
/ 28 мая 2019

Я использую Cross Entropy с Softmax в качестве функции потерь для моей нейронной сети. Кросс-энтропийная функция, которую я написал, выглядит следующим образом:

def CrossEntropy(calculated,desired):
    sum=0
    n=len(calculated)
    for i in range(0,n):
        sum+=(desired[i] * math.log(calculated[i])) + ((1-desired[i])* math.log(1-calculated[i]))

    crossentropy=(-1)*sum/n
    return crossentropy

Теперь давайте предположим, что желаемый результат равен [1,0,0,0] , и мы тестируем его для двух рассчитанных выходов, т.е. a = [0,1,0.9,0.1,0.1] и b = [0,1,0,1,0,1,0,9] . Проблема в том, что для обоих этих вычисленных выходных данных функция будет возвращать одно и то же значение для кросс-энтропии. Итак, как нейронная сеть узнает, какой выход является правильным?

1 Ответ

0 голосов
/ 30 мая 2019

Это ожидается, потому что у вас есть симметрия данных в ваших двух calculated случаях.

В вашем примере желаемый результат равен [1, 0, 0, 0].Таким образом, истинный класс - это первый класс.Однако и в a, и в b ваш прогноз для первого класса одинаков (0,1).Также для других классов (истинные негативы - 2-й, 3-й и 4-й класс) у вас есть эта симметрия данных (класс 2 и класс 4 одинаково важны в отношении расчета потерь).

 a -> 0.9,0.1,0.1
       ^
       |       |
               V
 b -> 0.1,0.1,0.9

Таким образом, у вас естьта же самая ожидаемая потеря.

Если вы удалите эту симметрию, вы получите другую перекрестную потерю энтропии.Смотрите примеры ниже:


# The first two are from your examples.
print CrossEntropy(calculated=[0.1,0.9,0.1,0.1], desired=[1, 0, 0, 0])
print CrossEntropy(calculated=[0.1,0.1,0.1,0.9], desired=[1, 0, 0, 0])

# below we have prediction for the last class as 0.75 thus break the data symmetry.
print CrossEntropy(calculated=[0.1,0.1,0.1,0.75], desired=[1, 0, 0, 0])

# below we have prediction for the true class as 0.45.
print CrossEntropy(calculated=[0.45,0.1,0.1,0.9], desired=[1, 0, 0, 0])


result:
1.20397280433
1.20397280433
0.974900121357
0.827953455132
...