Во время работы с predict.knn3
я столкнулся с интересным вариантом использования data-wrangling-ish. Я не знал, что могу вызвать предикат, используя аргумент type="class"
, чтобы получить прогнозируемые уровни, именно то, что мне нужно. Поэтому я разработал несколько сложное решение, чтобы выбрать из каждой строки результатов predict()
, уровень с максимальной вероятностью. Проблема была в том, что функция names
не работала в «векторизованной» форме с матрицей, а только с векторами.
Чтобы проиллюстрировать сценарий использования до и после выяснения аргумента type="class"
:
rm(list = ls())
library(caret)
library(tidyverse)
library(dslabs)
data("tissue_gene_expression")
x <- tissue_gene_expression$x
y <- tissue_gene_expression$y
set.seed(1)
test_index <- createDataPartition(y, times = 1, p = 0.5, list = FALSE)
test_x <- x[test_index,]
test_y <- y[test_index]
train_x <- x[-test_index,]
train_y <- y[-test_index]
# fit the model, predict without type="class" and use sapply to build the y_hat levels
fit <- knn3(train_x, train_y, k = 1)
pred <- predict(fit, test_x)
y_hat <- sapply(1:nrow(pred), function(i) as.factor(names(pred[i,which.max(pred[i,])])))
# compare it to the solution using predict with type="class"
identical(y_hat, as.factor(predict(fit, test_x, type="class")))
[1] TRUE
Чтобы проиллюстрировать проблему, я могу сделать следующее: посмотрите, что функция names, работающая с вектором именованных числовых элементов, дает желаемый результат, тогда как с матрицей произойдет сбой с выводом NULL:
names(pred[1, which.max(pred[1,])])
[1] "cerebellum"
names(pred[1:2, which.max(pred[1:2,])])
NULL
Предполагая, что вы не знаете об этом удобном type="class"
в функции predict.knn3
;Есть ли более простой способ, используя tidyverse и dplyr, чтобы заменить это sapply? Или какой-нибудь другой более простой способ реализовать этот вариант использования?
y_hat <- sapply(1:nrow(pred), function(i) as.factor(names(pred[i, which.max(pred[i,])])))
Мне нужно что-то вроде следующего, но это не работает:
as_tibble(predict(fit, test_x)) %>% mutate(y_hat=names(which.max(.[row_number(),])))