Как преобразовать числа с плавающей точкой в ​​целочисленные значения для сравнения друг с другом? - PullRequest
0 голосов
/ 19 марта 2020

Я хочу классифицировать набор данных, используя Decision Tree (DT) для вычисления точности, для вычисления точности мы сравниваем результат DTree с метками класса 1 или 2, но проблема в том, что функция DTree возвращает плавающую точку числа в порядке величины 1e3. был получен результат классификатора DT:

DT =
 1.0e+03 *
 1.5311
 1.2482
 3.0774
 1.2482
 1.0627
 1.5311
 2.6613
 3.3919
 1.3951
 1.2482
 3.3919
 1.2482

по этой ссылке, mathworks Я прикрепил программу и функцию Matlab и набор данных.

Я сравнил DT (i) (результат дерева решений) с помощью ytest (i) (последний столбец тестовых данных, которые являются метками классов), чтобы определить, где реальный результат равен идеальному результату для вычисления точности классификатора с использованием TP, TN, FN , FP.

Например, TP (True Positive), когда мы правильно определили экземпляр, принадлежит классу 1 (это наше наблюдение по результату классификатора), а также метка экземпляра в тестовых данных равна 1 ( идеальный результат), поэтому одна единица добавляется в TP. При вычислении TP, TN, FP, FN мы используем их в формуле точности точности = (TP + TN) / (TP + TN + FN + FP), что его диапазон равен [0 1], но ADT = 0/0 = NaN, поскольку переменные tp_dt, tn_dt, fp_dt и fn_dt вычисляются как ноль.

Как преобразовать выходные данные функции DTree в целочисленные значения 1 или 2?

Функция классификатора DT:

function ppred=DTree (xtest,xtrain,ytrain)
    DTreeModel=ClassificationTree.fit(xtrain,ytrain);
    ppred=DTreeModel.predict(xtest);
end

Программа:

clc;
clear;
close all;
load colon.mat
data=colon;
[n,m]=size(data);
for a=1:n
   if data(a,m)==0
       data(a,m)=2;
   end
end
S=[ consists of 30 number of columns from the dataset];
data0=data(:,S);
rows=(1:n);
test_count=floor((0.2)*n);
[n,m]=size(data0);
test_rows=randsample(rows,test_count);
train_rows=setdiff(rows,test_rows);
test=data0(test_rows,:);
train=data0(train_rows,:);
xtest=test(:,1:m-1);
ytest=test(:,m);
xtrain=train(:,1:m-1);
ytrain=train(:,m);
DT=DTree(xtest,xtrain,ytrain);
tp_dt=0; tn_dt=0; fp_dt=0; fn_dt=0;
for i=1:test_count
    if(DT(i)==1 && ytest(i)==1)
        tp_dt=tp_dt+1;
    end
    if(DT(i)==2 && ytest(i)==2)
        tn_dt=tn_dt+1;
    end
    if(DT(i)==2 && ytest(i)==1)
        fp_dt=fp_dt+1;
    end
    if(DT(i)==1 && ytest(i)==2)
        fn_dt=fn_dt+1;
    end
end
ADT=(tp_dt+tn_dt)/(tp_dt+tn_dt+fp_dt+fn_dt);
disp('Accuracy');
disp(ADT);
...