Inspection of trees in a Quantile Random Forest Regression model

I am interested in training a random forest to learn some conditional quantile on some data {X, y} sampled independently from some distribution.

That is, for some $$\alpha \in (0, 1)$$, I want a mapping $$\hat{q}_{\alpha}(x) \in [0, 1]$$ such that, for each $$x$$, $$\hat{q}_{\alpha}(x) = \inf\{q : P(y < q \mid X = x) > \alpha\}$$.

Is there a straightforward way to build a random forest in Python that yields such a model?

I also have one additional requirement, which may already be supported by existing libraries, though I am unsure: I would like to select a subset of points, A, from my training set and, when making predictions, exclude from the forest every tree that was trained on any point in A.

There is 1 answer

Reid Johnson On

There is a Python-based, scikit-learn-compatible Quantile Regression Forest implementation that can be used to estimate conditional quantiles: https://github.com/zillow/quantile-forest

Your additional requirement of making predictions on training samples by excluding trees that included those samples during training is called out-of-bag (OOB) estimation, and can also be done with the above package.
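
To make the idea concrete, here is a minimal, library-free sketch of the bookkeeping involved. It is not how quantile-forest is implemented; the helper names `bootstrap_indices` and `trees_excluding` and the subset `A` are hypothetical, chosen only for illustration. It applies the set-wise rule from the question (drop every tree that saw any point in A); ordinary per-sample OOB estimation is the special case where A holds a single training point.

```python
import random


def bootstrap_indices(n_samples, n_trees, seed=0):
    """Draw one bootstrap sample (with replacement) of row indices per tree."""
    rng = random.Random(seed)
    return [rng.choices(range(n_samples), k=n_samples) for _ in range(n_trees)]


def trees_excluding(in_bag, subset):
    """Return the ids of trees whose bootstrap sample contains no index in subset."""
    subset = set(subset)
    return [t for t, idx in enumerate(in_bag) if subset.isdisjoint(idx)]


n_samples, n_trees = 50, 200
in_bag = bootstrap_indices(n_samples, n_trees)

A = [3, 17]  # hypothetical subset of training-row indices
usable = trees_excluding(in_bag, A)
print(f"{len(usable)} of {n_trees} trees never saw any point in A")
```

Note that the number of usable trees shrinks quickly as A grows, since each tree's bootstrap sample covers roughly 63% of the training rows.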

Setup should be as easy as:

pip install quantile-forest

Then, here's an example of how to fit a quantile random forest model and use it to predict quantiles with OOB estimation for a subset (here the first 100 rows) of the training data:

import numpy as np
from quantile_forest import RandomForestQuantileRegressor
from sklearn import datasets
from sklearn.model_selection import train_test_split

X, y = datasets.fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)

qrf = RandomForestQuantileRegressor()
qrf.fit(X_train, y_train)

# Predict OOB quantiles for first 100 training samples.
y_pred_oob = qrf.predict(
    X_train[:100, :],
    quantiles=[0.025, 0.5, 0.975],
    oob_score=True,
    indices=np.arange(100),
)
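
As a quick sanity check on predictions like these, you can measure how often the true target falls between the lowest and highest predicted quantile. The sketch below assumes the predictions are a 2-D array with one column per requested quantile, in the order requested; `interval_coverage` is a hypothetical helper, and the arrays are stand-in values rather than model output.

```python
import numpy as np


def interval_coverage(y_true, y_pred):
    """Fraction of targets inside [first column, last column] of y_pred."""
    lower, upper = y_pred[:, 0], y_pred[:, -1]
    return float(np.mean((y_true >= lower) & (y_true <= upper)))


# Stand-in arrays (hypothetical values, not model output).
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([
    [0.5, 1.0, 1.5],
    [1.0, 1.5, 1.8],  # 2.0 lies above the upper bound
    [2.0, 3.0, 4.0],
    [3.5, 4.2, 5.0],
])

coverage = interval_coverage(y_true, y_pred)
print(coverage)  # 0.75
```

With the example above, the equivalent call would be `interval_coverage(y_train[:100], y_pred_oob)`; for the 0.025 and 0.975 quantiles you would hope for a value near 0.95, although 100 samples give only a rough estimate.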