Количество десятичных разрядов по краям графика дерева решений с ggparty - PullRequest
1 голос
/ 06 марта 2020

Я хочу построить дерево решений (по оценке пакета partykit), используя мощный пакет ggparty. Все хорошо, за исключением количества десятичных разрядов чисел c разделенных переменных. Как я могу отформатировать breaks_label в geom_edge_label(), например, чтобы изменить > 75.33333 в > 75.3 на графике ниже? round() не работает. Я мог бы использовать обходной путь через общий options(digits = 3), но мне интересно, есть ли более прямой путь.

library("ggparty") 
data("WeatherPlay", package = "partykit")

sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75 + 1/3)
sp_w <- partysplit(4L, index = 1:2)
pn <- partynode(1L, split = sp_o, kids = list(
    partynode(2L, split = sp_h, kids = list(
        partynode(3L, info = "yes"),
        partynode(4L, info = "no"))),
    partynode(5L, info = "yes"),
    partynode(6L, split = sp_w, kids = list(
        partynode(7L, info = "yes"),
        partynode(8L, info = "no")))))
py <- party(pn, WeatherPlay)

ggparty(py) +
    geom_edge() +
    # geom_edge_label() +
    geom_edge_label(mapping = aes(label = paste(breaks_label))) +
    geom_node_splitvar() +
    geom_node_info()

Создано в 2020-03 -05 * Представить пакет (v0.3.0)

1 Ответ

1 голос
/ 09 марта 2020

Спасибо за использование ggparty!

Так что я думаю, это то, для чего действительно нет прямого решения с текущей версией. Но я обязательно сделаю это в будущем!

Обычно, используя geoms только на подмножествах узлов, обычно можно обойти довольно много вещей. Как вы уже заметили, символ breaks_label хранится не как цифра c, а как символ с некоторым анализируемым текстом для знаков неравенства перед ними. Поэтому вам придется использовать что-то вроде substr ().

ggparty(py) +
  geom_edge() +
  geom_edge_label(id = -c(3, 4)) +
    geom_edge_label(mapping = aes(label = paste(substr(breaks_label, start = 1, stop = 15))),
                    id = c(3, 4)) +
  geom_node_splitvar() +
  geom_node_info() 

Я также изменил одну из внутренних функций, включив в нее функцию округления, чтобы вы могли получить ее из github и использовать ее. Но я действительно не проверял это, так что используйте на свой страх и риск;)

library(devtools)
source_url("https://raw.githubusercontent.com/martin-borkovec/ggparty/martin/R/add_splitvar_breaks_index_new.R")

rounded_labels <- add_splitvar_breaks_index_new(party_object = py,
                                                plot_data = ggparty:::get_plot_data(py), 
                                                round_digits = 2)

ggparty(py) +
  geom_edge() +
  geom_edge_label(mapping = aes(label = unlist(rounded_labels)),
                  data = rounded_labels) +
  geom_node_splitvar() +
  geom_node_info()
...