Making sense of gbm survival prediction model

I am new to using and interpreting ML methods, and I am currently doing survival analysis with the gbm package in R.

I have trouble understanding some of the output of the survival prediction model. I have read this tutorial and this post, but I still find it hard to make sense of the survival predictions the model produces.

Here is my analysis code, based on example data:

rm(list=ls(all=TRUE))
library(randomForestSRC)
library(gbm)
library(survival)
library(Hmisc)

data(pbc, package="randomForestSRC")
data <- na.omit(pbc)

set.seed(9512)
train <- sample(1:nrow(data), round(nrow(data)*0.7))
data.train <- data[train, ]
data.test <- data[-train, ]

set.seed(9741)
model <- gbm(Surv(days, status)~.,
           data.train,
           interaction.depth=2,
           shrinkage=0.01,
           n.trees=500,
           distribution="coxph",
           cv.folds = 5)

summary(model)

best.iter <- gbm.perf(model, plot.it = TRUE, method = 'cv',
                      overlay = TRUE) #to get the optimal number of Boosting iterations
best.iter

# Use the best number of trees to produce a predicted value for each observation in newdata.
# predict() returns a vector of predictions on the log-hazard scale, i.e. f(x).
# By default, predictions for distribution = "coxph" are on the log-hazard scale.
# The proportional hazards model assumes h(t|x) = lambda(t)*exp(f(x)),
# so these predictions estimate the f(x) component of the hazard function.
pred.train <- predict(object=model, newdata=data.train, n.trees = best.iter)
pred.test <- predict(object=model, newdata=data.test, n.trees = best.iter)


# C-index, training set (predictions are negated: a higher log hazard means shorter survival)
Hmisc::rcorr.cens(-pred.train, Surv(data.train$days, data.train$status))
# C-index, test set
Hmisc::rcorr.cens(-pred.test, Surv(data.test$days, data.test$status))

# Estimate the cumulative baseline hazard function using training data
basehaz.cum <- basehaz.gbm(t=data.train$days,       #The survival times.
                           delta=data.train$status, #The censoring indicator
                           f.x=pred.train,          #The predicted values of the regression model on the log hazard scale.
                           t.eval = data.train$days,  #Values at which the baseline hazard will be evaluated
                           cumulative = TRUE,       #If TRUE the cumulative baseline hazard is returned (not the survival function)
                           smooth = FALSE)          #If TRUE basehaz.gbm will smooth the estimated baseline hazard using Friedman's super smoother supsmu.

basehaz.cum

# Estimated survival probability of each training subject at their own observed time:
surv.rate <- exp(-exp(pred.train)*basehaz.cum)
surv.rate

res_train <- data.train
# predicted outcome for train set
res_train$pred <- pred.train
res_train$survival_rate <- surv.rate
res_train


# Estimate the cumulative baseline hazard from the training data and evaluate it at the test-set times
basehaz.cum <- basehaz.gbm(t=data.train$days,       #The survival times (training set).
                           delta=data.train$status, #The censoring indicator (training set)
                           f.x=pred.train,          #The training predictions on the log hazard scale.
                           t.eval = data.test$days,  #Values at which the baseline hazard will be evaluated (test-set times)
                           cumulative = TRUE,       #If TRUE the cumulative baseline hazard is returned
                           smooth = FALSE)          #If TRUE basehaz.gbm will smooth the estimated baseline hazard using Friedman's super smoother supsmu.

basehaz.cum
# Estimated survival probability of each test subject at their own observed time:
surv.rate <- exp(-exp(pred.test)*basehaz.cum)
surv.rate

res_test <- data.test
# predicted outcome for test set
res_test$pred <- pred.test
res_test$survival_rate <- surv.rate
res_test

#--------------------------------------------------
#Estimate survival rate at time of interest

# Specify time of interest
time.interest <- sort(unique(data.train$days[data.train$status==1]))

# Estimate the cumulative baseline hazard function using training data
basehaz.cum <- basehaz.gbm(t=data.train$days,       #The survival times.
                           delta=data.train$status, #The censoring indicator
                           f.x=pred.train,          #The predicted values of the regression model on the log hazard scale.
                           t.eval = time.interest,  #Values at which the baseline hazard will be evaluated
                           cumulative = TRUE,       #If TRUE the cumulative baseline hazard is returned (not the survival function)
                           smooth = FALSE)          #If TRUE basehaz.gbm will smooth the estimated baseline hazard using Friedman's super smoother supsmu.


# Estimated survival curve of one test subject (here the first) at each time in time.interest:
surf.i <- exp(-exp(pred.test[1])*basehaz.cum) # survival probabilities

# Estimated survival probability of all test subjects at one specified time (the 10th training event time):
specif.time <- time.interest[10]
surv.rate <- exp(-exp(pred.test)*basehaz.cum[10])
cat("Survival Rate of all at time", specif.time, "\n")
print(surv.rate)
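
The per-subject (surf.i) and per-time (surv.rate) calculations above are two slices of one subjects-by-times matrix. A minimal base-R sketch, where surv_matrix is a hypothetical helper (not a gbm function) and the toy inputs are made-up values used only to exercise it:

```r
# Hypothetical helper (not part of gbm): survival probabilities for every
# subject (rows) at every evaluation time (columns), using
# S(t | x) = exp(-H0(t) * exp(f(x))).
#   f.x    : predictions on the log-hazard scale, e.g. pred.test
#   cumhaz : cumulative baseline hazard evaluated at `times`, e.g. basehaz.cum
#   times  : the evaluation times, used only as column names
surv_matrix <- function(f.x, cumhaz, times) {
  S <- exp(-outer(exp(f.x), cumhaz))   # length(f.x) x length(cumhaz) matrix
  dimnames(S) <- list(NULL, times)     # column names become the times
  S
}

# With the objects above: S <- surv_matrix(pred.test, basehaz.cum, time.interest)
# S[1, ] then holds the same values as surf.i, and S[, 10] the last surv.rate.

# Toy check with made-up values:
S.toy <- surv_matrix(f.x = c(0, log(2)), cumhaz = c(0.1, 0.5), times = c(100, 400))
print(round(S.toy, 3))
```

outer() avoids an explicit loop over subjects and keeps each row a full predicted survival curve, which is convenient for plotting individual curves or reading off any column as a time point of interest.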

The output of the predict function is the f(x) component of the hazard function h(t|x) = lambda(t)*exp(f(x)).
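
For reference, these are the standard proportional-hazards identities behind the code above. Writing \Lambda_0(t) for the cumulative baseline hazard (the quantity basehaz.gbm returns with cumulative = TRUE), the survival function is exactly what exp(-exp(pred.test)*basehaz.cum) computes, and the ratio of two subjects' hazards does not depend on t:

```latex
h(t \mid x) = \lambda(t)\, e^{f(x)}, \qquad
\Lambda_0(t) = \int_0^t \lambda(u)\, du,
\\[4pt]
S(t \mid x) = \exp\!\left(-\int_0^t h(u \mid x)\, du\right)
            = \exp\!\left(-\Lambda_0(t)\, e^{f(x)}\right),
\\[4pt]
\mathrm{HR}_{ij} = \frac{h(t \mid x_i)}{h(t \mid x_j)} = e^{f(x_i) - f(x_j)}.
```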

My questions:

• I am a bit confused about whether hazard ratios can be calculated here.

• How can I divide the population into low-risk and high-risk groups? Can I rely on the estimated f(x) component of the hazard function to build a scoring system on the training set? My aim is a scoring system that lets me show KM plots of the low- and high-risk groups for both the training and test sets.

• How can I construct calibration curve plots where I can plot observed survival vs. predicted survival for the training set and test set?

Pei Liu (best answer)

Amer, thanks for reading my tutorial!

As you mentioned, "the output returned from the predict function represents the f(x) component of the hazard function (h(t|x) = lambda(t)*exp(f(x)))", so the first step is to understand the hazard function h(t|x) itself.

Before that, please make sure you have a basic knowledge of survival analysis. If not, I recommend reading this great post; I think it will help you answer these questions.

Back to your questions:

  • Exactly. The predict function returns f(x), the log of each subject's hazard relative to the baseline hazard lambda(t), so exp(f(x)) is that relative hazard, and the hazard ratio between subjects i and j is exp(f(x_i) - f(x_j)). Because f(x) is a sum of trees rather than a linear term, there is no single hazard ratio per covariate as there is in a standard Cox model.
  • Sure! Based on the predicted hazard ratios, we can divide the population into low-risk and high-risk groups, for example using the median as the cutoff (the median of f(x) gives the same split, since exp() is increasing). The cutoff should be derived from the training set and then applied unchanged to the test set. If your model is effective, the KM curves of the low- and high-risk groups will differ significantly (tested statistically with the log-rank test).
  • Calibration plots are typically used to evaluate models that output probabilities in [0.0, 1.0]. We can compute the survival function, pick a time point of interest (e.g., 5 years), and compare the predicted survival probabilities with the actual survival status at that time, much as we would evaluate a binary classifier. Since the status of subjects censored before that time is unknown, the observed survival is usually estimated with the Kaplan-Meier method within groups of similar predicted survival. For more details on obtaining the survival function, see my tutorial; the principles are explained in the post mentioned above.
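
The three points above can be sketched in R as follows. This is a minimal sketch, assuming only the survival package (no gbm) so it runs on its own: the data are simulated and hypothetical, as are the helper names to_group and breslow_H0 and the time point t0 = 5. With your model you would use pred.train / pred.test as f and days / status as the outcome. breslow_H0 re-implements the Breslow estimator of the cumulative baseline hazard, which, as far as I know, is also what basehaz.gbm computes, so basehaz.gbm(..., t.eval = t0) on the training data should give a closely comparable value. Binning subjects by predicted survival and comparing each bin's mean prediction with its Kaplan-Meier estimate at t0 is one common way to draw a calibration plot, not the only one.

```r
library(survival)  # Surv(), survfit(), survdiff(); a recommended package shipped with R

# ---- Hypothetical stand-in data (simulated) so the sketch runs without gbm ----
# With gbm, use f = pred.train / pred.test and the days/status columns instead.
set.seed(1)
n <- 300
x <- rnorm(n)
f <- 0.8 * x                                    # log relative hazard, f(x)
t.event <- rexp(n, rate = 0.1 * exp(f))         # event times under h(t|x) = 0.1*exp(f(x))
t.cens  <- rexp(n, rate = 0.05)                 # independent censoring times
dat <- data.frame(time   = pmin(t.event, t.cens),
                  status = as.integer(t.event <= t.cens),
                  f      = f)
idx   <- sample(n, round(0.7 * n))
train <- dat[idx, ]
test  <- dat[-idx, ]

# ---- 1) Hazard ratios ----
rel.haz <- exp(train$f)                         # hazard relative to the baseline (f = 0)
hr.1vs2 <- exp(train$f[1] - train$f[2])         # hazard ratio of subject 1 vs. subject 2

# ---- 2) Risk groups: cutoff from the TRAINING set, reused on the test set ----
cutoff <- median(train$f)                       # same split as the median of exp(f)
to_group <- function(f) factor(ifelse(f > cutoff, "high", "low"), levels = c("low", "high"))
train$risk <- to_group(train$f)
test$risk  <- to_group(test$f)

km.test <- survfit(Surv(time, status) ~ risk, data = test)
plot(km.test, col = c("blue", "red"), xlab = "Time", ylab = "Survival probability")
legend("topright", legend = levels(test$risk), col = c("blue", "red"), lty = 1)
lr.test   <- survdiff(Surv(time, status) ~ risk, data = test)   # log-rank test
p.logrank <- pchisq(lr.test$chisq, df = 1, lower.tail = FALSE)

# ---- 3) Calibration at a time point t0 ----
# Hypothetical helper: Breslow cumulative baseline hazard Lambda0(t0) from training data.
breslow_H0 <- function(time, status, f, t0) {
  ev <- sort(unique(time[status == 1 & time <= t0]))   # event times up to t0
  sum(vapply(ev, function(u) sum(time == u & status == 1) / sum(exp(f[time >= u])),
             numeric(1)))
}
t0 <- 5                                         # hypothetical time of interest
H0 <- breslow_H0(train$time, train$status, train$f, t0)
test$pred.surv <- exp(-H0 * exp(test$f))        # predicted S(t0 | x)

# Bin by predicted survival (quintiles); observed survival per bin is the KM
# estimate at t0, which accounts for subjects censored before t0.
test$bin <- cut(test$pred.surv, quantile(test$pred.surv, 0:5 / 5), include.lowest = TRUE)
calib <- t(sapply(split(test, test$bin), function(d) {
  km <- survfit(Surv(time, status) ~ 1, data = d)
  c(predicted = mean(d$pred.surv),
    observed  = summary(km, times = t0, extend = TRUE)$surv)
}))
plot(calib[, "predicted"], calib[, "observed"], xlim = c(0, 1), ylim = c(0, 1),
     xlab = "Predicted survival at t0", ylab = "Observed (KM) survival at t0")
abline(0, 1, lty = 2)                           # diagonal = perfect calibration
```

Running steps 2 and 3 with train in place of test gives the training-set KM plot and calibration curve; the cutoff and the baseline hazard should still come from the training set in both cases.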