Есть ли способ заставить Случайный Лесной Регрессор не соответствовать перехвату? - PullRequest
0 голосов
/ 29 марта 2020

Мне было интересно, есть ли способ установить sklearn Random Forest Regressor таким образом, чтобы вход «0» давал мне прогноз 0. Я знаю, что для линейных моделей я могу просто передать аргумент fit_intercept=False при инициализации, и я хочу повторить это для случайного леса.

Имеет ли смысл для модели на основе дерева достичь того, что я пытаюсь сделать? Если да, то как мне это реализовать?

1 Ответ

3 голосов
/ 29 марта 2020

Короткий ответ: Нет .


Длинный ответ:

Модели на основе деревьев очень отличаются от линейных; понятие перехвата даже не существует в деревьях.

Чтобы понять, почему это так, давайте адаптируем простой пример из документации (одно дерево решений с одной входной функцией) :

import numpy as np
from sklearn.tree import DecisionTreeRegressor, plot_tree
import matplotlib.pyplot as plt

# Create a random dataset
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

# Fit regression model
regr = DecisionTreeRegressor(max_depth=2)
regr.fit(X, y)

# Predict
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_pred = regr.predict(X_test)

# Plot the results
plt.figure()
plt.scatter(X, y, s=20, edgecolor="black",
            c="darkorange", label="data")
plt.plot(X_test, y_pred, color="cornflowerblue",
         label="max_depth=2", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

Вот вывод:

enter image description here

Грубо говоря, деревья решений пытаются аппроксимировать данные локально , следовательно, любая глобальная попытка (такая как линия перехвата) не существует в их юниверсе.

То, что дерево регрессии на самом деле возвращает в качестве вывода, - это среднее значение зависимой переменной y тренировочных образцов, которые заканчиваются в соответствующих терминальных узлах (листьях) во время подгонки. Чтобы увидеть это, давайте нарисуем дерево, которое мы только что поместили выше:

plt.figure()
plot_tree(regr, filled=True)
plt.show()

enter image description here

Обходя дерево в этом очень простом примере с игрушкой, вам следует суметь убедить себя, что прогноз для X=0 равен 0.052 (стрелки влево - это условие True узлов). Давайте проверим это:

regr.predict(np.array([0]).reshape(1,-1))
# array([0.05236068])

Я проиллюстрировал вышеизложенное очень простым деревом решений, чтобы дать вам представление о том, почему понятие перехвата здесь не существует; вывод о том, что это также имеет место с любой моделью, которая фактически основана и состоит из деревьев решений (таких как Случайный Лес), должен быть прямым.

...