Я хочу классифицировать набор данных, используя Decision Tree (DT) для вычисления точности, для вычисления точности мы сравниваем результат DTree с метками класса 1 или 2, но проблема в том, что функция DTree возвращает плавающую точку числа в порядке величины 1e3. был получен результат классификатора DT:
DT =
1.0e+03 *
1.5311
1.2482
3.0774
1.2482
1.0627
1.5311
2.6613
3.3919
1.3951
1.2482
3.3919
1.2482
по этой ссылке, mathworks Я прикрепил программу и функцию Matlab и набор данных.
Я сравнил DT (i) (результат дерева решений) с помощью ytest (i) (последний столбец тестовых данных, которые являются метками классов), чтобы определить, где реальный результат равен идеальному результату для вычисления точности классификатора с использованием TP, TN, FN , FP.
Например, TP (True Positive), когда мы правильно определили экземпляр, принадлежит классу 1 (это наше наблюдение по результату классификатора), а также метка экземпляра в тестовых данных равна 1 ( идеальный результат), поэтому одна единица добавляется в TP. При вычислении TP, TN, FP, FN мы используем их в формуле точности точности = (TP + TN) / (TP + TN + FN + FP), что его диапазон равен [0 1], но ADT = 0/0 = NaN, поскольку переменные tp_dt, tn_dt, fp_dt и fn_dt вычисляются как ноль.
Как преобразовать выходные данные функции DTree в целочисленные значения 1 или 2?
Функция классификатора DT:
function ppred=DTree (xtest,xtrain,ytrain)
DTreeModel=ClassificationTree.fit(xtrain,ytrain);
ppred=DTreeModel.predict(xtest);
end
Программа:
clc;
clear;
close all;
load colon.mat
data=colon;
[n,m]=size(data);
for a=1:n
if data(a,m)==0
data(a,m)=2;
end
end
S=[ consists of 30 number of columns from the dataset];
data0=data(:,S);
rows=(1:n);
test_count=floor((0.2)*n);
[n,m]=size(data0);
test_rows=randsample(rows,test_count);
train_rows=setdiff(rows,test_rows);
test=data0(test_rows,:);
train=data0(train_rows,:);
xtest=test(:,1:m-1);
ytest=test(:,m);
xtrain=train(:,1:m-1);
ytrain=train(:,m);
DT=DTree(xtest,xtrain,ytrain);
tp_dt=0; tn_dt=0; fp_dt=0; fn_dt=0;
for i=1:test_count
if(DT(i)==1 && ytest(i)==1)
tp_dt=tp_dt+1;
end
if(DT(i)==2 && ytest(i)==2)
tn_dt=tn_dt+1;
end
if(DT(i)==2 && ytest(i)==1)
fp_dt=fp_dt+1;
end
if(DT(i)==1 && ytest(i)==2)
fn_dt=fn_dt+1;
end
end
ADT=(tp_dt+tn_dt)/(tp_dt+tn_dt+fp_dt+fn_dt);
disp('Accuracy');
disp(ADT);