How does XGBoost use MAE as objective function?


In the past, XGBoost did not allow using absolute error as the objective function, since it is not differentiable at zero and its Hessian is 0 everywhere else. However, it does allow it now (https://xgboost.readthedocs.io/en/stable/parameter.html).

How does it do it and how can I define non-differentiable custom objective functions?

I've tried to implement it simply as:

import numpy as np

def absolute_error(predt, dtrain):
    y_true = dtrain.get_label()
    errors = y_true - predt
    grad = -1.0 * np.sign(errors)  # gradient of |y - pred| w.r.t. pred: negative sign of the error
    hess = np.zeros_like(y_true)   # Hessian: constant 0 (undefined at 0)
    return grad, hess

But this obviously does not work.

1 Answer

Answered by Sandipan Dey:

As described here and here, the absolute error (L1) loss function is not continuously twice differentiable. XGBoost uses the second derivative as a denominator when computing leaf weights, so a Hessian that is constantly 0 does not work.
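To see why the zero Hessian breaks things: XGBoost's optimal leaf value is -G / (H + lambda), where G and H are the sums of gradients and Hessians of the examples in the leaf. With the L1 loss every gradient is ±1 and every Hessian is 0, so the leaf value is set entirely by the regularizer lambda and ignores the magnitude of the residuals. A small sketch (the `leaf_weight` helper and the example numbers are illustrative):

```python
import numpy as np

def leaf_weight(grad, hess, lam=1.0):
    # XGBoost's optimal leaf value: w* = -sum(g) / (sum(h) + lambda)
    return -np.sum(grad) / (np.sum(hess) + lam)

# With the L1 objective, gradients are +/-1 and Hessians are 0:
grad = np.array([-1.0, -1.0, 1.0, -1.0])  # signs of (pred - y) per example
hess = np.zeros_like(grad)
w = leaf_weight(grad, hess)  # -(-2) / (0 + 1) = 2.0, determined by lambda alone
```

With lambda = 0 as well, the denominator would be exactly zero.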

We can instead use a smoothed version of the L1 loss (as shown below), or another smooth loss function such as the Huber loss, which removes the non-smoothness at 0 and in its neighborhood.

[figure: smoothed L1 loss]

def absolute_error_obj(alpha):
    def absolute_error(labels, predt):  # sklearn-API signature: (y_true, y_pred)
        x = predt - labels
        inside = np.abs(x) < alpha
        grad = np.sign(x)                     # L1 gradient outside the smoothing region
        grad[inside] = 2 / alpha * x[inside]  # quadratic region: d/dx (x**2 / alpha)
        hess = np.zeros_like(labels)
        hess[inside] = 2 / alpha              # constant positive Hessian near 0
        return grad, hess
    return absolute_error
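One way to sanity-check such an objective is to compare its analytic gradient against finite differences of the underlying loss. The loss consistent with the gradient above is x**2/alpha for |x| < alpha and |x| outside (this reconstruction of the loss is my reading of the code, checked away from the |x| = alpha boundary):

```python
import numpy as np

alpha = 1.0

def smooth_l1_loss(x, alpha):
    # Quadratic x**2/alpha near 0, linear |x| outside the smoothing region.
    return np.where(np.abs(x) < alpha, x**2 / alpha, np.abs(x))

def smooth_l1_grad(x, alpha):
    g = np.sign(x)
    inside = np.abs(x) < alpha
    g[inside] = 2 / alpha * x[inside]
    return g

# Central finite differences, at points away from |x| == alpha.
x = np.array([-3.0, -0.4, 0.0, 0.6, 2.5])
eps = 1e-6
fd = (smooth_l1_loss(x + eps, alpha) - smooth_l1_loss(x - eps, alpha)) / (2 * eps)
# fd should match the analytic gradient, including a gradient of 0 at x == 0.
```

This also shows the point of the smoothing: the gradient passes continuously through 0 instead of jumping from -1 to +1.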

This works, as demonstrated on the diabetes dataset: the training loss decreases over the boosting rounds and converges.

import numpy as np
import matplotlib.pylab as plt
import xgboost as xgb
from sklearn.datasets import load_diabetes

X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(tree_method="hist", objective=absolute_error_obj(75))
reg.fit(X, y, eval_set=[(X, y)])

[figure: training loss over boosting iterations on the diabetes dataset]

The next figure shows a scatter plot of the actual versus predicted values of the target variable y.

ypred = reg.predict(X)
plt.scatter(y, ypred)
plt.show()
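To put a number on the fit shown in the scatter plot, one can compute the MAE directly; a self-contained sketch with toy values (the arrays here are illustrative, not the diabetes predictions):

```python
import numpy as np
from sklearn.metrics import mean_absolute_error

# mean_absolute_error is simply the mean of |y - ypred|.
y = np.array([3.0, -0.5, 2.0, 7.0])
ypred = np.array([2.5, 0.0, 2.0, 8.0])
mae = mean_absolute_error(y, ypred)  # (0.5 + 0.5 + 0.0 + 1.0) / 4 = 0.5
```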

[figure: actual vs. predicted y]