Почему моя модель настолько точна при использовании knn (), где k = 1? - PullRequest
0 голосов
/ 03 июля 2018

В настоящее время я использую уровни геномной экспрессии, возраст и интенсивность курения, чтобы предсказать количество дней, которые должны прожить пациенты с раком легких. У меня есть небольшое количество данных; 173 пациента и 20 438 переменных, включая уровни экспрессии генов (что составляет 20 436). Я разделил свои данные на тестирование и обучение, используя соотношение 80:20. В данных отсутствуют пропущенные значения.

Я использую knn () для обучения модели. Вот как выглядит код:

prediction <- knn(train = trainData, test = testData, cl = trainAnswers, k=1)

Ничто не кажется необычным, пока вы не заметите, что k = 1. "Почему к = 1?" Вы можете спросить. Причина k = 1 в том, что когда k = 1, модель является наиболее точной. Это не имеет смысла для меня. Есть довольно много проблем:

  1. Я использую knn () для прогнозирования непрерывной переменной. Я должен использовать что-то вроде Кокс, может быть.
  2. Модель слишком точная. Вот несколько примеров тестового ответа и прогнозов модели. Для первого пациента число дней до смерти составляет 274. Модель прогнозирует 268. Для второго пациента тест: 1147, прогноз: 1135. 3-й, тест: 354, прогноз: 370. 4-й, тест: 995, прогноз 995 . Как это возможно? Из всех тестовых данных модель была только в среднем 9.0625 дней! Средняя разница составила 7 дней, а режим - 6 дней. Вот график результатов: Гистограмма .

Итак, я думаю, что мой главный вопрос в том, что делает knn (), что представляет собой k, и насколько модель настолько точна, когда k = 1? Вот весь мой код (я не могу прикрепить фактические данные):

# install.packages(c('caret', 'skimr', 'RANN', 'randomForest', 'fastAdaboost', 'gbm', 'xgboost', 'caretEnsemble', 'C50', 'earth'))
library(caret)

# Gather the data and store it in variables
LUAD <- read.csv('/Users/username/Documents/ClinicalData.csv')
geneData <- read.csv('/Users/username/Documents/GenomicExpressionLevelData.csv')
geneData <- data.frame(geneData)
row.names(geneData) = geneData$X
geneData <- geneData[2:514]
colNamesGeneData <- gsub(".","-",colnames(geneData),fixed = TRUE)
colnames(geneData) = colNamesGeneData

# Organize the data
# Important columns are 148 (smoking), 123 (OS Month, basically how many days old), and the gene data. And column 2 (barcode).
LUAD = data.frame(LUAD$patient, LUAD$TOBACCO_SMOKING_HISTORY_INDICATOR, LUAD$OS_MONTHS, LUAD$days_to_death)[complete.cases(data.frame(LUAD$patient, LUAD$TOBACCO_SMOKING_HISTORY_INDICATOR, LUAD$OS_MONTHS, LUAD$days_to_death)), ]
rownames(LUAD)=LUAD$LUAD.patient
LUAD <- LUAD[2:4]

# intersect(rownames(LUAD),colnames(geneData))
# ind=which(colnames(geneData)=="TCGA-778-7167-01A-11R-2066-07")
gene_expression=geneData[, rownames(LUAD)]

# Merge the two datasets to use the geneomic expression levels in your model
LUAD <- data.frame(LUAD,t(gene_expression))
LUAD.days_to_death <- LUAD[,3]
LUAD <- LUAD[,c(1:2,4:20438)]
LUAD <- data.frame(LUAD.days_to_death,LUAD)

set.seed(401)

# Number of Rows in the training data (createDataPartition(dataSet, percentForTraining, boolReturnAsList))
trainRowNum <- createDataPartition(LUAD$LUAD.days_to_death, p=0.8, list=FALSE)

# Training/Test Dataset
trainData <- LUAD[trainRowNum, ]
testData <- LUAD[-trainRowNum, ]

x = trainData[, c(2:20438)]
y = trainData$LUAD.days_to_death
v = testData[, c(2:20438)]
w = testData$LUAD.days_to_death

# Imputing missing values into the data
preProcess_missingdata_model <- preProcess(trainData, method='knnImpute')
library(RANN)
if (anyNA(trainData)) {
    trainData <- predict(preProcess_missingdata_model, newdata = trainData)
}
anyNA(trainData)

# Normalizing the data
preProcess_range_model <- preProcess(trainData, method='range')
trainData <- predict(preProcess_range_model, newdata = trainData)
trainData$LUAD.days_to_death <- y
apply(trainData[,1:20438], 2, FUN=function(x){c('min'=min(x), 'max'=max(x))})

preProcess_range_model_Test <- preProcess(testData, method='range')
testData <- predict(preProcess_range_model_Test, newdata = testData)
testData$LUAD.days_to_death <- w
apply(testData[,1:20438], 2, FUN=function(v){c('min'=min(v), 'max'=max(v))})

# To uncomment, select the text and press 'command' + 'shift' + 'c'
# set.seed(401)
# options(warn=-1)
# subsets <- c(1:10)
# ctrl <- rfeControl(functions = rfFuncs,
#                    method = "repeatedcv",
#                    repeats = 5,
#                    verbose = TRUE)
# lmProfile <- rfe(x=trainData[1:20437], y=trainAnswers,
#                  sizes = subsets,
#                  rfeControl = ctrl)
# lmProfile

trainAnswers <- trainData[,1]
testAnswers <- testData[,1]

library(class)
prediction <- knn(train = trainData, test = testData, cl = trainAnswers, k=1)

#install.packages("plotly")
library(plotly)
Test_Question_Number <- c(1:32)
prediction2 <- data.frame(prediction[1:32])
prediction2 <- as.numeric(as.vector(prediction2[c(1:32),]))
data <- data.frame(Test_Question_Number, prediction2, testAnswers)
names(data) <- c("Test Question Number","Prediction","Answer")

p <- plot_ly(data, x = ~Test_Question_Number, y = ~prediction2, type = 'bar', name = 'Prediction') %>%
    add_trace(y = ~testAnswers, name = 'Answer') %>%
    layout(yaxis = list(title = 'Days to Death'), barmode = 'group')
p
merge <- data.frame(prediction2,testAnswers)

difference <- abs((merge[,1])-(merge[,2]))
difference <- sort(difference)
meanDifference <- mean(difference)
medianDifference <- median(difference)
modeDifference <- names(table(difference))[table(difference)==max(table(difference))]
cat("Mean difference:", meanDifference, "\n")
cat("Median difference:", medianDifference, "\n")
cat("Mode difference:", modeDifference,"\n")

Наконец, для уточнения, ClinicalData.csv - это данные о возрасте, днях до смерти и интенсивности курения. Другой .csv - это данные геномной экспрессии. Данные над строкой 29 на самом деле не имеют значения, поэтому вы можете просто перейти к той части кода, где написано «set.seed (401)».

Редактировать: Некоторые образцы данных:

days_to_death    OS_MONTHS
121              3.98

NACC1   2001.5708   2363.8063   1419.879
NACC2   58.2948     61.8157     43.4386
NADK    706.868     1053.4424   732.1562
NADSYN1 1628.7634   912.1034    638.6471
NAE1    832.8825    793.3014    689.7123
NAF1    140.3264    165.4858    186.355
NAGA    1523.3441   1524.4619   1858.9074
NAGK    983.6809    899.869     1168.2003
NAGLU   621.3457    510.9453    1172.511
NAGPA   346.9762    257.5654    275.5533
NAGS    460.7732    107.2116    321.9763
NAIF1   217.1219    202.5108    132.3054
NAIP    101.2305    87.8942     77.261
NALCN   13.9628     36.7031     48.0809
NAMPT   3245.6584   1257.8849   5465.6387
...