Извлекайте и визуализируйте деревья моделей из Sparklyr - PullRequest
0 голосов
/ 02 ноября 2018

Есть ли у кого-нибудь совет о том, как преобразовать информацию дерева из моделей sparklyr ml_decision_tree_classifier, ml_gbt_classifier или ml_random_forest_classifier в.) Формат, который может быть понят другими библиотеками, относящимися к дереву R, и (в конечном итоге) б.) Визуализацией деревьев для нетехнического потребления? Это включает в себя возможность преобразования обратно в реальные имена объектов из значений индексации замещенных строк, которые создаются во время векторного ассемблера.

Следующий код скопирован из блога блога sparklyr с целью предоставления примера:

library(sparklyr)
library(dplyr)

# If needed, install Spark locally via `spark_install()`
sc <- spark_connect(master = "local")
iris_tbl <- copy_to(sc, iris)

# split the data into train and validation sets
iris_data <- iris_tbl %>%
  sdf_partition(train = 2/3, validation = 1/3, seed = 123)


iris_pipeline <- ml_pipeline(sc) %>%
  ft_dplyr_transformer(
    iris_data$train %>%
      mutate(Sepal_Length = log(Sepal_Length),
             Sepal_Width = Sepal_Width ^ 2)
  ) %>%
  ft_string_indexer("Species", "label")

iris_pipeline_model <- iris_pipeline %>%
  ml_fit(iris_data$train)

iris_vector_assembler <- ft_vector_assembler(
  sc, 
  input_cols = setdiff(colnames(iris_data$train), "Species"), 
  output_col = "features"
)
random_forest <- ml_random_forest_classifier(sc,features_col = "features")

# obtain the labels from the fitted StringIndexerModel
iris_labels <- iris_pipeline_model %>%
  ml_stage("string_indexer") %>%
  ml_labels()

# IndexToString will convert the predicted numeric values back to class labels
iris_index_to_string <- ft_index_to_string(sc, "prediction", "predicted_label", 
                                      labels = iris_labels)

# construct a pipeline with these stages
iris_prediction_pipeline <- ml_pipeline(
  iris_pipeline, # pipeline from previous section
  iris_vector_assembler, 
  random_forest,
  iris_index_to_string
)

# fit to data and make some predictions
iris_prediction_model <- iris_prediction_pipeline %>%
  ml_fit(iris_data$train)
iris_predictions <- iris_prediction_model %>%
  ml_transform(iris_data$validation)
iris_predictions %>%
  select(Species, label:predicted_label) %>%
  glimpse()

После проб и ошибок, основываясь на рекомендациях здесь Мне удалось распечатать формулировку базового дерева решений в формате "если / еще", приведенном в виде строки:

model_stage <- iris_prediction_model$stages[[3]]

spark_jobj(model_stage) %>% invoke(., "toDebugString") %>% cat()
##print out below##
RandomForestClassificationModel (uid=random_forest_classifier_5c6a1934c8e) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 2 <= 4.95)
      If (feature 3 <= 1.65)
       Predict: 0.0
      Else (feature 3 > 1.65)
       If (feature 0 <= 1.7833559100698644)
        Predict: 0.0
       Else (feature 0 > 1.7833559100698644)
        Predict: 2.0
     Else (feature 2 > 4.95)
      If (feature 2 <= 5.05)
       If (feature 1 <= 6.505000000000001)
        Predict: 2.0
       Else (feature 1 > 6.505000000000001)
        Predict: 0.0
      Else (feature 2 > 5.05)
       Predict: 2.0
  Tree 1 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.75)
      If (feature 1 <= 5.0649999999999995)
       If (feature 3 <= 1.05)
        Predict: 0.0
       Else (feature 3 > 1.05)
        If (feature 0 <= 1.8000241202036602)
         Predict: 2.0
        Else (feature 0 > 1.8000241202036602)
         Predict: 0.0
      Else (feature 1 > 5.0649999999999995)
       If (feature 0 <= 1.8000241202036602)
        Predict: 0.0
       Else (feature 0 > 1.8000241202036602)
        If (feature 2 <= 5.05)
         Predict: 0.0
        Else (feature 2 > 5.05)
         Predict: 2.0
     Else (feature 3 > 1.75)
      Predict: 2.0
  Tree 2 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 0 <= 1.7664051342320237)
      Predict: 0.0
     Else (feature 0 > 1.7664051342320237)
      If (feature 3 <= 1.45)
       If (feature 2 <= 4.85)
        Predict: 0.0
       Else (feature 2 > 4.85)
        Predict: 2.0
      Else (feature 3 > 1.45)
       If (feature 3 <= 1.65)
        If (feature 1 <= 8.125)
         Predict: 2.0
        Else (feature 1 > 8.125)
         Predict: 0.0
       Else (feature 3 > 1.65)
        Predict: 2.0
  Tree 3 (weight 1.0):
    If (feature 0 <= 1.6675287895788053)
     If (feature 2 <= 2.5)
      Predict: 1.0
     Else (feature 2 > 2.5)
      Predict: 0.0
    Else (feature 0 > 1.6675287895788053)
     If (feature 3 <= 1.75)
      If (feature 3 <= 1.55)
       If (feature 1 <= 7.025)
        If (feature 2 <= 4.55)
         Predict: 0.0
        Else (feature 2 > 4.55)
         Predict: 2.0
       Else (feature 1 > 7.025)
        Predict: 0.0
      Else (feature 3 > 1.55)
       If (feature 2 <= 5.05)
        Predict: 0.0
       Else (feature 2 > 5.05)
        Predict: 2.0
     Else (feature 3 > 1.75)
      Predict: 2.0
  Tree 4 (weight 1.0):
    If (feature 2 <= 4.85)
     If (feature 2 <= 2.5)
      Predict: 1.0
     Else (feature 2 > 2.5)
      Predict: 0.0
    Else (feature 2 > 4.85)
     If (feature 2 <= 5.05)
      If (feature 0 <= 1.8484238118815566)
       Predict: 2.0
      Else (feature 0 > 1.8484238118815566)
       Predict: 0.0
     Else (feature 2 > 5.05)
      Predict: 2.0
  Tree 5 (weight 1.0):
    If (feature 2 <= 1.65)
     Predict: 1.0
    Else (feature 2 > 1.65)
     If (feature 3 <= 1.65)
      If (feature 0 <= 1.8325494627242664)
       Predict: 0.0
      Else (feature 0 > 1.8325494627242664)
       If (feature 2 <= 4.95)
        Predict: 0.0
       Else (feature 2 > 4.95)
        Predict: 2.0
     Else (feature 3 > 1.65)
      Predict: 2.0
  Tree 6 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 2 <= 5.05)
      If (feature 3 <= 1.75)
       Predict: 0.0
      Else (feature 3 > 1.75)
       Predict: 2.0
     Else (feature 2 > 5.05)
      Predict: 2.0
  Tree 7 (weight 1.0):
    If (feature 3 <= 0.55)
     Predict: 1.0
    Else (feature 3 > 0.55)
     If (feature 3 <= 1.65)
      If (feature 2 <= 4.75)
       Predict: 0.0
      Else (feature 2 > 4.75)
       Predict: 2.0
     Else (feature 3 > 1.65)
      If (feature 2 <= 4.85)
       If (feature 0 <= 1.7833559100698644)
        Predict: 0.0
       Else (feature 0 > 1.7833559100698644)
        Predict: 2.0
      Else (feature 2 > 4.85)
       Predict: 2.0
  Tree 8 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.85)
      If (feature 2 <= 4.85)
       Predict: 0.0
      Else (feature 2 > 4.85)
       If (feature 0 <= 1.8794359129669855)
        Predict: 2.0
       Else (feature 0 > 1.8794359129669855)
        If (feature 3 <= 1.55)
         Predict: 0.0
        Else (feature 3 > 1.55)
         Predict: 0.0
     Else (feature 3 > 1.85)
      Predict: 2.0
  Tree 9 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 2 <= 4.95)
      Predict: 0.0
     Else (feature 2 > 4.95)
      Predict: 2.0
  Tree 10 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 2 <= 4.95)
      Predict: 0.0
     Else (feature 2 > 4.95)
      If (feature 2 <= 5.05)
       If (feature 3 <= 1.55)
        Predict: 2.0
       Else (feature 3 > 1.55)
        If (feature 3 <= 1.75)
         Predict: 0.0
        Else (feature 3 > 1.75)
         Predict: 2.0
      Else (feature 2 > 5.05)
       Predict: 2.0
  Tree 11 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 2 <= 5.05)
      If (feature 2 <= 4.75)
       Predict: 0.0
      Else (feature 2 > 4.75)
       If (feature 3 <= 1.75)
        Predict: 0.0
       Else (feature 3 > 1.75)
        Predict: 2.0
     Else (feature 2 > 5.05)
      Predict: 2.0
  Tree 12 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.75)
      If (feature 3 <= 1.35)
       Predict: 0.0
      Else (feature 3 > 1.35)
       If (feature 0 <= 1.695573522904327)
        Predict: 0.0
       Else (feature 0 > 1.695573522904327)
        If (feature 1 <= 8.125)
         Predict: 2.0
        Else (feature 1 > 8.125)
         Predict: 0.0
     Else (feature 3 > 1.75)
      If (feature 0 <= 1.7833559100698644)
       Predict: 0.0
      Else (feature 0 > 1.7833559100698644)
       Predict: 2.0
  Tree 13 (weight 1.0):
    If (feature 3 <= 0.55)
     Predict: 1.0
    Else (feature 3 > 0.55)
     If (feature 2 <= 4.95)
      If (feature 2 <= 4.75)
       Predict: 0.0
      Else (feature 2 > 4.75)
       If (feature 0 <= 1.8000241202036602)
        If (feature 1 <= 9.305)
         Predict: 2.0
        Else (feature 1 > 9.305)
         Predict: 0.0
       Else (feature 0 > 1.8000241202036602)
        Predict: 0.0
     Else (feature 2 > 4.95)
      Predict: 2.0
  Tree 14 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 3 <= 1.65)
      If (feature 3 <= 1.45)
       Predict: 0.0
      Else (feature 3 > 1.45)
       If (feature 2 <= 4.95)
        Predict: 0.0
       Else (feature 2 > 4.95)
        Predict: 2.0
     Else (feature 3 > 1.65)
      If (feature 0 <= 1.7833559100698644)
       If (feature 0 <= 1.7664051342320237)
        Predict: 2.0
       Else (feature 0 > 1.7664051342320237)
        Predict: 0.0
      Else (feature 0 > 1.7833559100698644)
       Predict: 2.0
  Tree 15 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 3 <= 1.75)
      If (feature 2 <= 4.95)
       Predict: 0.0
      Else (feature 2 > 4.95)
       If (feature 1 <= 8.125)
        Predict: 2.0
       Else (feature 1 > 8.125)
        If (feature 0 <= 1.9095150692894909)
         Predict: 0.0
        Else (feature 0 > 1.9095150692894909)
         Predict: 2.0
     Else (feature 3 > 1.75)
      Predict: 2.0
  Tree 16 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 0 <= 1.7491620461964392)
      Predict: 0.0
     Else (feature 0 > 1.7491620461964392)
      If (feature 3 <= 1.75)
       If (feature 2 <= 4.75)
        Predict: 0.0
       Else (feature 2 > 4.75)
        If (feature 0 <= 1.8164190316151556)
         Predict: 2.0
        Else (feature 0 > 1.8164190316151556)
         Predict: 0.0
      Else (feature 3 > 1.75)
       Predict: 2.0
  Tree 17 (weight 1.0):
    If (feature 0 <= 1.695573522904327)
     If (feature 2 <= 1.65)
      Predict: 1.0
     Else (feature 2 > 1.65)
      Predict: 0.0
    Else (feature 0 > 1.695573522904327)
     If (feature 2 <= 4.75)
      If (feature 2 <= 2.5)
       Predict: 1.0
      Else (feature 2 > 2.5)
       Predict: 0.0
     Else (feature 2 > 4.75)
      If (feature 3 <= 1.75)
       If (feature 1 <= 5.0649999999999995)
        Predict: 2.0
       Else (feature 1 > 5.0649999999999995)
        If (feature 3 <= 1.65)
         Predict: 0.0
        Else (feature 3 > 1.65)
         Predict: 0.0
      Else (feature 3 > 1.75)
       Predict: 2.0
  Tree 18 (weight 1.0):
    If (feature 3 <= 0.8)
     Predict: 1.0
    Else (feature 3 > 0.8)
     If (feature 3 <= 1.65)
      Predict: 0.0
     Else (feature 3 > 1.65)
      If (feature 0 <= 1.7833559100698644)
       Predict: 0.0
      Else (feature 0 > 1.7833559100698644)
       Predict: 2.0
  Tree 19 (weight 1.0):
    If (feature 2 <= 2.5)
     Predict: 1.0
    Else (feature 2 > 2.5)
     If (feature 2 <= 4.95)
      If (feature 1 <= 8.705)
       Predict: 0.0
      Else (feature 1 > 8.705)
       If (feature 2 <= 4.85)
        Predict: 0.0
       Else (feature 2 > 4.85)
        If (feature 0 <= 1.8164190316151556)
         Predict: 2.0
        Else (feature 0 > 1.8164190316151556)
         Predict: 0.0
     Else (feature 2 > 4.95)
      Predict: 2.0

Как видите, этот формат менее чем оптимален для перехода к одному из многих прекрасных методов визуализации графики дерева решений, которые я видел (например, аналитика революции или statmethods )

1 Ответ

0 голосов
/ 03 ноября 2018

На сегодняшний день (выпуск Spark 2.4.0 уже утвержден и ожидает официального объявления) ваша лучшая ставка * без привлечения сложных сторонних инструментов (например, вы можете посмотреть MLeap), вероятно, сохранить модель и считывание спецификации :

ml_stage(iris_prediction_model, "random_forest") %>% 
  ml_save("/tmp/model")

rf_spec <- spark_read_parquet(sc, "rf", "/tmp/model/data/")

Результатом будет Spark DataFrame со следующей схемой:

rf_spec %>% 
  spark_dataframe() %>% 
  invoke("schema") %>% invoke("treeString") %>% 
  cat(sep = "\n")
root
 |-- treeID: integer (nullable = true)
 |-- nodeData: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- prediction: double (nullable = true)
 |    |-- impurity: double (nullable = true)
 |    |-- impurityStats: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- gain: double (nullable = true)
 |    |-- leftChild: integer (nullable = true)
 |    |-- rightChild: integer (nullable = true)
 |    |-- split: struct (nullable = true)
 |    |    |-- featureIndex: integer (nullable = true)
 |    |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |    |-- element: double (containsNull = true)
 |    |    |-- numCategories: integer (nullable = true)

предоставление информации обо всех узлах и разбиениях.

Сопоставление объектов можно получить с помощью метаданных столбца:

meta <- iris_predictions %>% 
    select(features) %>% 
    spark_dataframe() %>% 
    invoke("schema") %>% invoke("apply", 0L) %>% 
    invoke("metadata") %>% 
    invoke("getMetadata", "ml_attr") %>% 
    invoke("getMetadata", "attrs") %>% 
    invoke("json") %>%
    jsonlite::fromJSON() %>% 
    dplyr::bind_rows() %>% 
    copy_to(sc, .) %>%
    rename(featureIndex = idx)

meta
# Source: spark<?> [?? x 2]
  featureIndex name        
*        <int> <chr>       
1            0 Sepal_Length
2            1 Sepal_Width 
3            2 Petal_Length
4            3 Petal_Width 

И отображение меток, которое вы уже получили:

labels <- tibble(prediction = seq_along(iris_labels) - 1, label = iris_labels) %>%
  copy_to(sc, .)

Наконец, вы можете объединить все это:

full_rf_spec <- rf_spec %>% 
  spark_dataframe() %>% 
  invoke("selectExpr", list("treeID", "nodeData.*", "nodeData.split.*")) %>% 
  sdf_register() %>% 
  select(-split, -impurityStats) %>% 
  left_join(meta, by = "featureIndex") %>% 
  left_join(labels, by = "prediction")

full_rf_spec
# Source: spark<?> [?? x 12]
   treeID    id prediction impurity    gain leftChild rightChild featureIndex
 *  <int> <int>      <dbl>    <dbl>   <dbl>     <int>      <int>        <int>
 1      0     0          1   0.636   0.379          1          2            2
 2      0     1          1   0      -1             -1         -1           -1
 3      0     2          0   0.440   0.367          3          8            2
 4      0     3          0   0.0555  0.0269         4          5            3
 5      0     4          0   0      -1             -1         -1           -1
 6      0     5          0   0.5     0.5            6          7            0
 7      0     6          0   0      -1             -1         -1           -1
 8      0     7          2   0      -1             -1         -1           -1
 9      0     8          2   0.111   0.0225         9         12            2
10      0     9          2   0.375   0.375         10         11            1
# ... with more rows, and 4 more variables: leftCategoriesOrThreshold <list>,
#   numCategories <int>, name <chr>, label <chr>

, который, собранный и разделенный treeID, должен дать достаточно информации ** для имитации древовидного объекта (вы можете получить хорошее представление о требуемой структуре, проверив rpart::rpart.object документацию и / или unclass с моделью rpart. tree::tree потребует меньше работы, но его утилиты для черчения далеко не впечатляют), и создайте приличный сюжет.

Альтернативный путь - экспортировать данные в PMML, используя Sparklyr2PMML и использовать это представление.

Вы также можете проверить Как визуализировать / построить дерево решений в Apache Spark (PySpark 1.4.1)? , который предлагает сторонний пакет Python для решения той же проблемы.

Если вам не нужно ничего необычного, вы можете создать грубый сюжет с помощью igraph:

library(igraph)

gframe <- full_rf_spec %>% 
  filter(treeID == 0) %>%   # Take the first tree
  mutate(
    leftCategoriesOrThreshold = ifelse(
      size(leftCategoriesOrThreshold) == 1,
      # Continuous variable case
      concat("<= ", round(concat_ws("", leftCategoriesOrThreshold), 3)),
      # Categorical variable case. Decoding variables might be involved
      # but can be achieved if needed, using column metadata or indexer labels
      concat("in {", concat_ws(",", leftCategoriesOrThreshold), "}")
    ),
    name = coalesce(name, label)) %>% 
 select(
   id, label, impurity, gain, 
   leftChild, rightChild, leftCategoriesOrThreshold, name) %>%
 collect()

vertices <- gframe %>% rename(label = name, name = id)

edges <- gframe %>%
  transmute(from = id, to = leftChild, label = leftCategoriesOrThreshold) %>% 
  union_all(gframe %>% select(from = id, to = rightChild)) %>% 
  filter(to != -1)

g <- igraph::graph_from_data_frame(edges, vertices = vertices)

plot(
  g, layout = layout_as_tree(g, root = c(1)),
  vertex.shape = "rectangle",  vertex.size = 45)

tree plot


* Это должно улучшиться в ближайшем будущем, благодаря недавно введенному API записи ML, не зависящему от формата (который уже поддерживает PMML Writer для выбранных моделей. Надеемся, что появятся новые модели и форматы).

** Если вы работаете с категориальными функциями, возможно, вы захотите сопоставить leftCategoriesOrThreshold с соответствующими индексированными уровнями.

Если вектор признаков содержит катагорические переменные, вывод jsonlite::fromJSON() будет содержать группу nominal. Например, если у вас есть индексированный столбец foo с тремя уровнями, собранный в первой позиции, он будет выглядеть примерно так:

$nominal
     vals idx      name
1 a, b, c   1       foo

где vals столбец - список векторов переменной длины.

length(meta$nominal$vals[[1]])
[1] 3

Метки соответствуют индексам этой структуры, поэтому в примере:

  • a имеет метку 0.0 (за исключением того, что метки являются числами с плавающей запятой двойной точности, а нумерация начинается с 0.0)
  • b имеет метку 1.0

и т. Д., И если вы разделите с leftCategoriesOrThreshold, равным, скажем, c(0.0, 2.0), это означает, что разделение на метках {"a", "c"}.

Обратите также внимание, что при наличии категориальных данных вам, возможно, придется обрабатывать их перед вызовом copy_to - на данный момент не похоже, что он поддерживает сложные поля.

В Spark <= 2.3 вам придется использовать R-код для отображения (в локальной структуре некоторые <code>purrr вполне подойдут). В Spark 2.4 (пока не поддерживается в sparklyr AFAIK) может быть проще считывать метаданные напрямую с помощью читателя JSON Spark и отображать его с функциями более высокого порядка.

...