Ваше масштабирование может стать настоящим узким местом, если у вас есть больше столбцов (проверено на матрице экспрессии гена 200 x 22216). Моя версия может показаться не такой уж впечатляющей с набором данных iris
, но на большем наборе данных я получаю 1,3 сек c против 32,8 сек c время выполнения.
Использование tabulate
вместо table
дает дополнительное улучшение, которое затмевается, однако, масштабированием матрицы.
Я использовал здесь функцию пользовательского масштаба, но использование base::scale
в матрице уже было бы значительным улучшением.
Я также затронул вопрос, поднятый М. Папенбергом о том, что «я» не считается ближайшим соседом, установив для них значение NA.
invisible(lapply(c("tidyverse", "matrixStats", "RANN", "microbenchmark", "compiler"),
require, character.only=TRUE))
enableJIT(3)
# faster column scaling (modified from https://www.r-bloggers.com/author/strictlystat/)
colScale <- function(x, center = TRUE, scale = TRUE, rows = NULL, cols = NULL) {
if (!is.null(rows) && !is.null(cols)) {x <- x[rows, cols, drop = FALSE]
} else if (!is.null(rows)) {x <- x[rows, , drop = FALSE]
} else if (!is.null(cols)) x <- x[, cols, drop = FALSE]
cm <- colMeans(x, na.rm = TRUE)
if (scale) csd <- matrixStats::colSds(x, center = cm, na.rm = TRUE) else
csd <- rep(1, length = length(cm))
if (!center) cm <- rep(0, length = length(cm))
x <- t((t(x) - cm) / csd)
return(x)
}
# your posted version (mostly):
oldv <- function(){
iris.scaled <- iris %>%
mutate_if(is.numeric, scale)
iris.nn2 <- nn2(iris.scaled[1:4])
distance.index <- iris.nn2$nn.idx[,-1]
target = iris.scaled$Species
category_neighbours <- matrix(target[distance.index[,]], nrow = nrow(distance.index), ncol = ncol(distance.index))
class <- apply(category_neighbours, 1, function(x) {
x1 <- table(x)
names(x1)[which.max(x1)]})
cbind(iris, class)
}
## my version:
myv <- function(){
iris.scaled <- colScale(data.matrix(iris[, 1:(dim(iris)[2]-1)]))
iris.nn2 <- nn2(iris.scaled)
# set self neighbors to NA
iris.nn2$nn.idx[iris.nn2$nn.idx - seq_len(dim(iris.nn2$nn.idx)[1]) == 0] <- NA
# match up categories
category_neighbours <- matrix(iris$Species[iris.nn2$nn.idx[,]],
nrow = dim(iris.nn2$nn.idx)[1], ncol = dim(iris.nn2$nn.idx)[2])
# turn category_neighbours into numeric for tabulate
cn <- matrix(as.numeric(factor(category_neighbours, exclude=NULL)),
nrow = dim(iris.nn2$nn.idx)[1], ncol = dim(iris.nn2$nn.idx)[2])
cnl <- levels(factor(category_neighbours, exclude = NULL))
# tabulate frequencies and match up with factor levels
class <- apply(cn, 1, function(x) {
cnl[which.max(tabulate(x, nbins=length(cnl))[!is.na(cnl)])]})
cbind(iris, class)
}
microbenchmark(oldv(), myv(), times=100L)
#> Unit: milliseconds
#> expr min lq mean median uq max neval cld
#> oldv() 11.015986 11.679337 12.806252 12.064935 12.745082 33.89201 100 b
#> myv() 2.430544 2.551342 3.020262 2.612714 2.691179 22.41435 100 a