Метрика точности подраздела категорий в Керасе - PullRequest
1 голос
/ 17 мая 2019

У меня проблема классификации с 3 классами. Давайте определим их как классы 0,1 и 2. В моем случае класс 0 не важен, то есть то, что классифицируется как класс 0, не имеет значения. Однако важна точность, точность, отзыв и частота ошибок только для классов 1 и 2. Я хотел бы определить метрику точности, которая рассматривает только подраздел данных, относящийся к 1 и 2, и дает мне меру этого как модель обучения. Я не прошу код на точность или f1 или точность / отзыв - те, которые я нашел и могу реализовать сам. То, что я спрашиваю, для кода, который может помочь выбрать подраздел категорий для выполнения этих метрик. Визуально с матрицей путаницы: Дано:

>  0   1   2
>0 10  3   4
>1 2   5   1
>2 8   5   9

Я бы хотел провести тренировку с точностью только для следующего подмножества:

>  1   2
>1 5   1
>2 5   9

Возможная идея: Объединить разбитые по категориям, аргументы argmaxed y_pred и argmaxed y_true, отбросить все экземпляры, где появляется 0, повторно распутать их обратно в массив one_hot и сделать простую двоичную точность в отношении того, что осталось?

Edit: Я пытался исключить 0-класс через этот код, но это не имеет смысла. 0-категория эффективно включается в 1-категорию (то есть истинные положительные значения как 0, так и 1 в конечном итоге помечаются как 1). Все еще ищете помощи - кто-нибудь может помочь, пожалуйста?

#this solution does not work :(
def my_acc(y_true, y_pred):
#excluding the 0-category
y_true_cust = y_true[:,np.r_[1:3]]
y_pred_cust = y_pred[:,np.r_[1:3]]
#binary accuracy source code, slightly edited
y_pred_cat = Ker.round(y_pred_cust)
eql_cust = Ker.equal(y_true_cust, y_pred_cust)
return Ker.mean(eql_cust, axis = -1)

@ Эшвин Гит Д'Са

correct_guesses_3cat = 10 + 5 + 9
print(correct_guesses_3cat)
24

total_guesses_3cat = 10+3+4+2+5+1+8+5+9
print(total_guesses_3cat)
47

accuracy_3cat = 24/47
print(accuracy_3cat)
51.1 %

correct_guesses_2cat =5 + 9
print(correct_guesses_2cat)
14

total_guesses_2cat = 5+1+5+9
print(total_guesses_2cat)
20

accuracy_2cat = 14/20
print(accuracy_2cat)
70.0 %
...