Невозможно построить Границу Решения в R с помощью geom_contour () - PullRequest
0 голосов
/ 18 апреля 2020

Я знаю, что есть несколько похожих вопросов по этому поводу, но я все еще пытаюсь понять концепцию, лежащую в основе этого.

Я предсказал классы ("истина" и "ложь") на основе кнн:

hype_knn_prediction = knn(hype_quant_train[,2:3], hype_quant_test[,2:3], cl = hype_quant_train$Selected, k = k)

и я также могу построить данные теста:

ggplot(hype_quant_test, aes(x= Comments, y= Votings, color = Selected, shape = Selected)) + 
  geom_point(size = 3) +
  labs(y="Votes", x = "Comments")+
  ggtitle("Testdata")+
  theme(plot.title = element_text(hjust = 0.5))+
  theme(legend.position="bottom")

График данных теста

Теперь я хотел бы добавить решение границы классов к участку тестовых данных. Поэтому я добавил hype_knn_prediction в качестве столбца к фрейму данных hype_quant_test. Но когда я добавляю geom_contour(data=hype_quant_test, aes(x=Comments, y=Votings, z= as.numberic(Selected)), breaks=c(0,.5)) к графику, я получаю следующее сообщение: Сбой вычисления в stat_contour (): количество координат x должно соответствовать количеству столбцов в матрице плотности.

Как можно Я решил проблему? Я предполагаю, что мне нужно преобразовать некоторые данные, но я не знаю, как

РЕДАКТИРОВАТЬ

данные обучения:

   Selected  Votings          Comments
1      true  0.2348563517    0.162454874
2     false  0.0027691243    0.001805054
3     false  0.0136725511    0.027075812
4     false  0.1128418138    0.077617329
5     false  0.0529595016    0.016245487
6     false  0.0190377293    0.012635379
7     false  0.0231914157    0.001805054
8     false  0.3367947387    0.019855596
9     false  0.0036344756    0.005415162
10    false  0.0051921080    0.005415162
11    false  0.0202492212    0.014440433
12    false  0.0178262375    0.007220217
13    false  0.0029421945    0.010830325
14    false  0.0680166147    0.036101083
15    false  0.0053651783    0.003610108
16    false  0.2397023191    0.034296029
17    false  0.0001730703    0.000000000
18    false  0.0228452752    0.023465704
19    false  0.0129802700    0.000000000
20    false  0.0192107996    0.018050542
21    false  0.0010384216    0.000000000
22    false  0.0129802700    0.005415162
23    false  0.0000000000    0.000000000
24    false  0.0134994808    0.003610108
25    false  0.0742471443    0.039711191
26    false  0.0256143994    0.009025271
27    false  0.0039806161    0.001805054
28     true  0.4110418830    0.050541516
29    false  0.0114226376    0.063176895
30    false  0.0185185185    0.016245487
31    false  0.0051921080    0.003610108
32    false  0.1952232606    0.021660650
33    false  0.1138802354    0.012635379
34    false  0.0048459675    0.016245487
35    false  0.0242298373    0.009025271
36    false  0.0167878159    0.001805054
37    false  0.0039806161    0.001805054
38     true  0.7727587400    0.146209386
39    false  0.0154032537    0.000000000
40    false  0.0057113188    0.007220217
41    false  0.0038075459    0.000000000
42    false  0.0046728972    0.001805054
43    false  0.0152301835    0.003610108
44    false  0.0408445829    0.025270758
45    false  0.0131533403    0.007220217
46    false  0.0578054690    0.037906137
47    false  0.0046728972    0.005415162
48    false  0.0001730703    0.001805054
49    false  0.1169955002    0.122743682
50    false  0.0044998269    0.003610108
51    false  0.0000000000    0.000000000
52    false  0.1439944618    0.036101083
53    false  0.0072689512    0.005415162
54    false  0.0064035999    0.009025271
55    false  0.0614399446    0.027075812
56    false  0.0719972309    0.005415162
57     true  0.3418137764    0.018050542
58    false  0.0117687781    0.012635379
59    false  0.0072689512    0.014440433
60     true  0.0313257182    0.018050542
61    false  0.1021114573    0.019855596
62    false  0.0024229837    0.003610108
63    false  0.0072689512    0.000000000
64    false  0.0169608861    0.003610108
65    false  0.0340948425    0.014440433
66     true  0.7069920388    0.332129964
67     true  0.7377985462    0.175090253
68    false  0.0919003115    0.007220217
69    false  0.0065766701    0.001805054
70    false  0.0401523018    0.027075812
71    false  0.0223260644    0.005415162
72    false  0.0635167878    0.018050542
73    false  0.0013845621    0.000000000
74    false  0.0060574593    0.000000000
75     true  0.6102457598    0.909747292
76    false  0.0022499135    0.001805054
77    false  0.0316718588    0.007220217
78    false  0.0019037729    0.000000000
79     true  1.0000000000    1.000000000
80    false  0.0240567670    0.016245487

Тестовые данные после добавления прогноза столбец:

   Selected   Votings          Comments   Prediction
1     false   0.329525787    0.023465704      false
2     false   0.299930772    0.075812274      false
3      true   0.962443752    0.178700361       true
4     false   0.032191070    0.001805054      false
5     false   0.036863967    0.025270758      false
6     false   0.014884043    0.005415162      false
7     false   0.034787124    0.005415162      false
8     false   0.007615092    0.000000000      false
9     false   0.005538249    0.000000000      false
10    false   0.006403600    0.005415162      false
11    false   0.006749740    0.005415162      false
12    false   0.048286604    0.072202166      false
13    false   0.057286258    0.021660650      false
14    false   0.067324334    0.012635379      false
15    false   0.004153686    0.001805054      false
16    false   0.004845967    0.003610108      false
17    false   0.089131187    0.055956679      false
18    false   0.010384216    0.001805054      false
19    false   0.040671513    0.021660650      false
20    false   0.001903773    0.001805054      false

РЕДАКТИРОВАТЬ

Я пытался протестировать построение с данными радужной оболочки по умолчанию, но сообщение остается тем же:

library(datasets)
library(class)
library(ggplot2)
library(caret)

iris_df = as.data.frame(iris)

normalize = function(x) {
  return ((x - min(x)) / (max(x) - min(x)))
}

iris_df$Sepal.Length = normalize(iris_df$Sepal.Length)
iris_df$Sepal.Width = normalize(iris_df$Sepal.Width)
iris_df$Petal.Length = normalize(iris_df$Petal.Length)
iris_df$Petal.Width = normalize(iris_df$Petal.Width)

set.seed(1234)

#sampling
sample = sample(nrow(iris_df), nrow(iris_df)*0.8, replace = FALSE)

#Training data
iris_train=iris_df[sample,]

#Testdata
iris_test=iris_df[-sample,]


ggplot(iris_train, aes(x= Sepal.Length, y= Sepal.Width, color = Species, shape = Species)) + 
  geom_point(size = 3) +
  theme(legend.position="bottom")

ggplot(iris_test, aes(x= Sepal.Length, y= Sepal.Width, color = Species, shape = Species)) + 
  geom_point(size = 3) +
  theme(legend.position="bottom")

k = round(sqrt(nrow(iris_train)))


knn_predict = knn(iris_train[,1:4], iris_test[,1:4], cl = iris_train$Species, k = k)

iris_test$Prediction = knn_predict

confusionMatrix(iris_test$Prediction, iris_test$Species)

ggplot(iris_test, aes(x= Sepal.Length, y= Sepal.Width, color = Species, shape = Species)) + 
  geom_point(size = 3) +
  geom_contour(data = iris_test, aes(x= Sepal.Length, y= Sepal.Width, z = as.numeric(Prediction)),breaks = c(0,.5))

Не удалось вычислить в stat_contour(): количество координат х должно соответствовать количеству столбцов в матрице плотности.

1 Ответ

1 голос
/ 18 апреля 2020

Я думаю, что ключевой частью подхода geom_contour, которого я не вижу в вашем коде радужной оболочки, является выполнение предсказаний на основе матрицы переменных. Вот как вы могли бы сделать один. Я не делал ничего особенного с предварительной обработкой.

library(class)
library(ggplot2)

train <- sample(150, 75)

train_dat <- iris[train, -5]
test_dat <- iris[-train, -5]

vars <- c("Sepal.Width", "Sepal.Length")

# First make a grid
n <- 40
pred.mat <- expand.grid(
  Sepal.Width = with(iris, seq(min(Sepal.Width), max(Sepal.Width), length.out = n)),
  Sepal.Length = with(iris, seq(min(Sepal.Length), max(Sepal.Length), length.out = n))
)

# Then ask for prediction on the grid
pred.mat$pred <- knn(train_dat[, vars], pred.mat, cl = iris$Species[train], k = 3)

# Use grid as input for geom_contour
ggplot(pred.mat, aes(Sepal.Width, Sepal.Length)) +
  geom_point(data = iris, aes(color = Species)) +
  geom_contour(aes(z = as.numeric(pred == "setosa"), 
                   colour = "setosa"), 
               breaks = 0.5) +
  geom_contour(aes(z = as.numeric(pred == "virginica"), 
                   colour = "virginica"), 
               breaks = 0.5) +
  geom_contour(aes(z = as.numeric(pred == "versicolor"), 
                   color = "versicolor"), 
               breaks = 0.5)

Создано в 2020-04-18 пакетом Представлять (v0.3.0)

...