Как получить имена классов в классификации? - PullRequest
0 голосов
/ 29 июня 2019

Я новичок в ML, я построил модель SVM для классификации некоторых входных данных. Я использовал панду, чтобы прочитать мой набор данных. Результаты классификации печатаются в виде индексов, каждый из которых соответствует названию меток (классов) в моем наборе данных. Как я могу преобразовать эти индексы в их имена (строки)?

например, у меня есть три класса: [Question, General, Info], но когда я пытаюсь классифицировать вход, результатом является одно из следующих чисел: [0,1,2] Я хочу преобразовать эти числа в названия классов, которые у меня есть.

вот часть моего кода:

data = pandas.read_csv("classes.csv",encoding='utf-16' )


Train_X, Test_X, Train_Y, Test_Y = sklearn.model_selection.train_test_split(data['input'],data['Class'],test_size=0.3,random_state=None)

Test_Y и Train_Y являются списками чисел (классов), каждое число относится к одному классу. Как узнать, что представляет собой каждое число?

1 Ответ

0 голосов
/ 30 июня 2019

Первое, что вам нужно знать: ваша модель работает как положено.В большинстве случаев будет выводиться вероятность для каждого ярлыка.Таким образом, если ваша модель выдает что-то вроде [0.1, 0.1, 0.8], это означает, что классифицируемый образец имеет 80%, чтобы принадлежать метке в позиции 2. Если вы пропустите все метки в порядке, указанном в вашем вопросе, то есть [question, general, info], это означает, что данный образец относится к классу info.Обратите внимание, что здесь важен порядок, и вам нужно убедиться, что при подаче модели в ваш код.

Поэтому, чтобы вывести строку вместо числа, вам нужно получить число, выводимое модельюи проверьте метку в списке или словаре, содержащем это отношение.Используя в качестве примера список:

labels_str = ['question', 'general', 'info']

# preds is a np.array containing the probabilities
preds = model(some_sample)

# this function returns the position of the max value in the array
pos_pred = preds.argmax() 

print ("The label for this sample is {}".format(labels_str[pos_pred])

Вы поняли идею?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...