Прогнозирование значений с помощью dplyr и augment - PullRequest
0 голосов
/ 04 июля 2018

Я хотел бы подогнать модели к сгруппированному фрейму данных, а затем предсказать одно новое значение для каждой модели (то есть группы).

library(dplyr)
library(broom)

data(iris)
dat <- rbind(iris, iris) 
dat$Group <- rep(c("A", "B"), each = 150)

new.dat <- data.frame(Group = rep(c("A", "B"), each = 3),
                      Species = rep(c("setosa", "versicolor", "virginica"), times = 2),
                      Sepal.Width = 1:6)
> new.dat
  Group    Species val
1     A     setosa   1
2     A versicolor   2
3     A  virginica   3
4     B     setosa   4
5     B versicolor   5
6     B  virginica   6

Однако augment возвращает 36 строк, как будто каждое новое значение соответствует каждой модели. Как я могу сохранить группировку здесь и получить одно подходящее значение на группу?

dat %>%
  group_by(Species, Group) %>%
  do(augment(lm(Sepal.Length ~ Sepal.Width, data = .), newdata = new.dat))

# A tibble: 36 x 5
# Groups:   Species, Group [6]
   Group Species    Sepal.Width .fitted .se.fit
   <fct> <fct>            <int>   <dbl>   <dbl>
 1 A     setosa               1    3.33  0.221 
 2 A     versicolor           2    4.02  0.133 
 3 A     virginica            3    4.71  0.0512
 4 B     setosa               4    5.40  0.0615
 5 B     versicolor           5    6.09  0.145 
 6 B     virginica            6    6.78  0.234 
 7 A     setosa               1    3.33  0.221 
 8 A     versicolor           2    4.02  0.133 
 9 A     virginica            3    4.71  0.0512
10 B     setosa               4    5.40  0.0615
# ... with 26 more rows

(Обратите внимание, что из-за данных примера строки на самом деле являются дубликатами, что, однако, не относится к моим исходным данным).

1 Ответ

0 голосов
/ 04 июля 2018

Вам нужно, чтобы Species и Group из new.dat соответствовали группам, обрабатываемым в данный момент в do. Вы можете сделать это так:

group.cols <- c("Species", "Group")
dat %>%
    group_by(!!! group.cols) %>%
    do(augment(lm(Sepal.Length ~ Sepal.Width, data = .),
               newdata = semi_join(new.dat, ., by = group.cols)))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...