На сегодняшний день (выпуск Spark 2.4.0 уже утвержден и ожидает официального объявления) ваша лучшая ставка * без привлечения сложных сторонних инструментов (например, вы можете посмотреть MLeap), вероятно, сохранить модель и считывание спецификации :
ml_stage(iris_prediction_model, "random_forest") %>%
ml_save("/tmp/model")
rf_spec <- spark_read_parquet(sc, "rf", "/tmp/model/data/")
Результатом будет Spark DataFrame
со следующей схемой:
rf_spec %>%
spark_dataframe() %>%
invoke("schema") %>% invoke("treeString") %>%
cat(sep = "\n")
root
|-- treeID: integer (nullable = true)
|-- nodeData: struct (nullable = true)
| |-- id: integer (nullable = true)
| |-- prediction: double (nullable = true)
| |-- impurity: double (nullable = true)
| |-- impurityStats: array (nullable = true)
| | |-- element: double (containsNull = true)
| |-- gain: double (nullable = true)
| |-- leftChild: integer (nullable = true)
| |-- rightChild: integer (nullable = true)
| |-- split: struct (nullable = true)
| | |-- featureIndex: integer (nullable = true)
| | |-- leftCategoriesOrThreshold: array (nullable = true)
| | | |-- element: double (containsNull = true)
| | |-- numCategories: integer (nullable = true)
предоставление информации обо всех узлах и разбиениях.
Сопоставление объектов можно получить с помощью метаданных столбца:
meta <- iris_predictions %>%
select(features) %>%
spark_dataframe() %>%
invoke("schema") %>% invoke("apply", 0L) %>%
invoke("metadata") %>%
invoke("getMetadata", "ml_attr") %>%
invoke("getMetadata", "attrs") %>%
invoke("json") %>%
jsonlite::fromJSON() %>%
dplyr::bind_rows() %>%
copy_to(sc, .) %>%
rename(featureIndex = idx)
meta
# Source: spark<?> [?? x 2]
featureIndex name
* <int> <chr>
1 0 Sepal_Length
2 1 Sepal_Width
3 2 Petal_Length
4 3 Petal_Width
И отображение меток, которое вы уже получили:
labels <- tibble(prediction = seq_along(iris_labels) - 1, label = iris_labels) %>%
copy_to(sc, .)
Наконец, вы можете объединить все это:
full_rf_spec <- rf_spec %>%
spark_dataframe() %>%
invoke("selectExpr", list("treeID", "nodeData.*", "nodeData.split.*")) %>%
sdf_register() %>%
select(-split, -impurityStats) %>%
left_join(meta, by = "featureIndex") %>%
left_join(labels, by = "prediction")
full_rf_spec
# Source: spark<?> [?? x 12]
treeID id prediction impurity gain leftChild rightChild featureIndex
* <int> <int> <dbl> <dbl> <dbl> <int> <int> <int>
1 0 0 1 0.636 0.379 1 2 2
2 0 1 1 0 -1 -1 -1 -1
3 0 2 0 0.440 0.367 3 8 2
4 0 3 0 0.0555 0.0269 4 5 3
5 0 4 0 0 -1 -1 -1 -1
6 0 5 0 0.5 0.5 6 7 0
7 0 6 0 0 -1 -1 -1 -1
8 0 7 2 0 -1 -1 -1 -1
9 0 8 2 0.111 0.0225 9 12 2
10 0 9 2 0.375 0.375 10 11 1
# ... with more rows, and 4 more variables: leftCategoriesOrThreshold <list>,
# numCategories <int>, name <chr>, label <chr>
, который, собранный и разделенный treeID
, должен дать достаточно информации ** для имитации древовидного объекта (вы можете получить хорошее представление о требуемой структуре, проверив rpart::rpart.object
документацию и / или unclass
с моделью rpart
. tree::tree
потребует меньше работы, но его утилиты для черчения далеко не впечатляют), и создайте приличный сюжет.
Альтернативный путь - экспортировать данные в PMML, используя Sparklyr2PMML и использовать это представление.
Вы также можете проверить Как визуализировать / построить дерево решений в Apache Spark (PySpark 1.4.1)? , который предлагает сторонний пакет Python для решения той же проблемы.
Если вам не нужно ничего необычного, вы можете создать грубый сюжет с помощью igraph
:
library(igraph)
gframe <- full_rf_spec %>%
filter(treeID == 0) %>% # Take the first tree
mutate(
leftCategoriesOrThreshold = ifelse(
size(leftCategoriesOrThreshold) == 1,
# Continuous variable case
concat("<= ", round(concat_ws("", leftCategoriesOrThreshold), 3)),
# Categorical variable case. Decoding variables might be involved
# but can be achieved if needed, using column metadata or indexer labels
concat("in {", concat_ws(",", leftCategoriesOrThreshold), "}")
),
name = coalesce(name, label)) %>%
select(
id, label, impurity, gain,
leftChild, rightChild, leftCategoriesOrThreshold, name) %>%
collect()
vertices <- gframe %>% rename(label = name, name = id)
edges <- gframe %>%
transmute(from = id, to = leftChild, label = leftCategoriesOrThreshold) %>%
union_all(gframe %>% select(from = id, to = rightChild)) %>%
filter(to != -1)
g <- igraph::graph_from_data_frame(edges, vertices = vertices)
plot(
g, layout = layout_as_tree(g, root = c(1)),
vertex.shape = "rectangle", vertex.size = 45)
* Это должно улучшиться в ближайшем будущем, благодаря недавно введенному API записи ML, не зависящему от формата (который уже поддерживает PMML Writer для выбранных моделей. Надеемся, что появятся новые модели и форматы).
** Если вы работаете с категориальными функциями, возможно, вы захотите сопоставить leftCategoriesOrThreshold
с соответствующими индексированными уровнями.
Если вектор признаков содержит катагорические переменные, вывод jsonlite::fromJSON()
будет содержать группу nominal
. Например, если у вас есть индексированный столбец foo
с тремя уровнями, собранный в первой позиции, он будет выглядеть примерно так:
$nominal
vals idx name
1 a, b, c 1 foo
где vals
столбец - список векторов переменной длины.
length(meta$nominal$vals[[1]])
[1] 3
Метки соответствуют индексам этой структуры, поэтому в примере:
a
имеет метку 0.0 (за исключением того, что метки являются числами с плавающей запятой двойной точности, а нумерация начинается с 0.0)
b
имеет метку 1.0
и т. Д., И если вы разделите с leftCategoriesOrThreshold
, равным, скажем, c(0.0, 2.0)
, это означает, что разделение на метках {"a", "c"}
.
Обратите также внимание, что при наличии категориальных данных вам, возможно, придется обрабатывать их перед вызовом copy_to
- на данный момент не похоже, что он поддерживает сложные поля.
В Spark <= 2.3 вам придется использовать R-код для отображения (в локальной структуре некоторые <code>purrr вполне подойдут). В Spark 2.4 (пока не поддерживается в sparklyr
AFAIK) может быть проще считывать метаданные напрямую с помощью читателя JSON Spark и отображать его с функциями более высокого порядка.