В настоящее время я использую уровни геномной экспрессии, возраст и интенсивность курения, чтобы предсказать количество дней, которые должны прожить пациенты с раком легких. У меня есть небольшое количество данных; 173 пациента и 20 438 переменных, включая уровни экспрессии генов (что составляет 20 436). Я разделил свои данные на тестирование и обучение, используя соотношение 80:20. В данных отсутствуют пропущенные значения.
Я использую knn () для обучения модели. Вот как выглядит код:
prediction <- knn(train = trainData, test = testData, cl = trainAnswers, k=1)
Ничто не кажется необычным, пока вы не заметите, что k = 1. "Почему к = 1?" Вы можете спросить. Причина k = 1 в том, что когда k = 1, модель является наиболее точной. Это не имеет смысла для меня. Есть довольно много проблем:
- Я использую knn () для прогнозирования непрерывной переменной. Я должен использовать что-то вроде Кокс, может быть.
- Модель слишком точная. Вот несколько примеров тестового ответа и прогнозов модели. Для первого пациента число дней до смерти составляет 274. Модель прогнозирует 268. Для второго пациента тест: 1147, прогноз: 1135. 3-й, тест: 354, прогноз: 370. 4-й, тест: 995, прогноз 995 . Как это возможно? Из всех тестовых данных модель была только в среднем 9.0625 дней! Средняя разница составила 7 дней, а режим - 6 дней. Вот график результатов:
Гистограмма .
Итак, я думаю, что мой главный вопрос в том, что делает knn (), что представляет собой k, и насколько модель настолько точна, когда k = 1? Вот весь мой код (я не могу прикрепить фактические данные):
# install.packages(c('caret', 'skimr', 'RANN', 'randomForest', 'fastAdaboost', 'gbm', 'xgboost', 'caretEnsemble', 'C50', 'earth'))
library(caret)
# Gather the data and store it in variables
LUAD <- read.csv('/Users/username/Documents/ClinicalData.csv')
geneData <- read.csv('/Users/username/Documents/GenomicExpressionLevelData.csv')
geneData <- data.frame(geneData)
row.names(geneData) = geneData$X
geneData <- geneData[2:514]
colNamesGeneData <- gsub(".","-",colnames(geneData),fixed = TRUE)
colnames(geneData) = colNamesGeneData
# Organize the data
# Important columns are 148 (smoking), 123 (OS Month, basically how many days old), and the gene data. And column 2 (barcode).
LUAD = data.frame(LUAD$patient, LUAD$TOBACCO_SMOKING_HISTORY_INDICATOR, LUAD$OS_MONTHS, LUAD$days_to_death)[complete.cases(data.frame(LUAD$patient, LUAD$TOBACCO_SMOKING_HISTORY_INDICATOR, LUAD$OS_MONTHS, LUAD$days_to_death)), ]
rownames(LUAD)=LUAD$LUAD.patient
LUAD <- LUAD[2:4]
# intersect(rownames(LUAD),colnames(geneData))
# ind=which(colnames(geneData)=="TCGA-778-7167-01A-11R-2066-07")
gene_expression=geneData[, rownames(LUAD)]
# Merge the two datasets to use the geneomic expression levels in your model
LUAD <- data.frame(LUAD,t(gene_expression))
LUAD.days_to_death <- LUAD[,3]
LUAD <- LUAD[,c(1:2,4:20438)]
LUAD <- data.frame(LUAD.days_to_death,LUAD)
set.seed(401)
# Number of Rows in the training data (createDataPartition(dataSet, percentForTraining, boolReturnAsList))
trainRowNum <- createDataPartition(LUAD$LUAD.days_to_death, p=0.8, list=FALSE)
# Training/Test Dataset
trainData <- LUAD[trainRowNum, ]
testData <- LUAD[-trainRowNum, ]
x = trainData[, c(2:20438)]
y = trainData$LUAD.days_to_death
v = testData[, c(2:20438)]
w = testData$LUAD.days_to_death
# Imputing missing values into the data
preProcess_missingdata_model <- preProcess(trainData, method='knnImpute')
library(RANN)
if (anyNA(trainData)) {
trainData <- predict(preProcess_missingdata_model, newdata = trainData)
}
anyNA(trainData)
# Normalizing the data
preProcess_range_model <- preProcess(trainData, method='range')
trainData <- predict(preProcess_range_model, newdata = trainData)
trainData$LUAD.days_to_death <- y
apply(trainData[,1:20438], 2, FUN=function(x){c('min'=min(x), 'max'=max(x))})
preProcess_range_model_Test <- preProcess(testData, method='range')
testData <- predict(preProcess_range_model_Test, newdata = testData)
testData$LUAD.days_to_death <- w
apply(testData[,1:20438], 2, FUN=function(v){c('min'=min(v), 'max'=max(v))})
# To uncomment, select the text and press 'command' + 'shift' + 'c'
# set.seed(401)
# options(warn=-1)
# subsets <- c(1:10)
# ctrl <- rfeControl(functions = rfFuncs,
# method = "repeatedcv",
# repeats = 5,
# verbose = TRUE)
# lmProfile <- rfe(x=trainData[1:20437], y=trainAnswers,
# sizes = subsets,
# rfeControl = ctrl)
# lmProfile
trainAnswers <- trainData[,1]
testAnswers <- testData[,1]
library(class)
prediction <- knn(train = trainData, test = testData, cl = trainAnswers, k=1)
#install.packages("plotly")
library(plotly)
Test_Question_Number <- c(1:32)
prediction2 <- data.frame(prediction[1:32])
prediction2 <- as.numeric(as.vector(prediction2[c(1:32),]))
data <- data.frame(Test_Question_Number, prediction2, testAnswers)
names(data) <- c("Test Question Number","Prediction","Answer")
p <- plot_ly(data, x = ~Test_Question_Number, y = ~prediction2, type = 'bar', name = 'Prediction') %>%
add_trace(y = ~testAnswers, name = 'Answer') %>%
layout(yaxis = list(title = 'Days to Death'), barmode = 'group')
p
merge <- data.frame(prediction2,testAnswers)
difference <- abs((merge[,1])-(merge[,2]))
difference <- sort(difference)
meanDifference <- mean(difference)
medianDifference <- median(difference)
modeDifference <- names(table(difference))[table(difference)==max(table(difference))]
cat("Mean difference:", meanDifference, "\n")
cat("Median difference:", medianDifference, "\n")
cat("Mode difference:", modeDifference,"\n")
Наконец, для уточнения, ClinicalData.csv - это данные о возрасте, днях до смерти и интенсивности курения. Другой .csv - это данные геномной экспрессии. Данные над строкой 29 на самом деле не имеют значения, поэтому вы можете просто перейти к той части кода, где написано «set.seed (401)».
Редактировать: Некоторые образцы данных:
days_to_death OS_MONTHS
121 3.98
NACC1 2001.5708 2363.8063 1419.879
NACC2 58.2948 61.8157 43.4386
NADK 706.868 1053.4424 732.1562
NADSYN1 1628.7634 912.1034 638.6471
NAE1 832.8825 793.3014 689.7123
NAF1 140.3264 165.4858 186.355
NAGA 1523.3441 1524.4619 1858.9074
NAGK 983.6809 899.869 1168.2003
NAGLU 621.3457 510.9453 1172.511
NAGPA 346.9762 257.5654 275.5533
NAGS 460.7732 107.2116 321.9763
NAIF1 217.1219 202.5108 132.3054
NAIP 101.2305 87.8942 77.261
NALCN 13.9628 36.7031 48.0809
NAMPT 3245.6584 1257.8849 5465.6387