How to use a for loop to plot Kaplan Meier with lifelines in python?

I would like to make a Kaplan-Meier plot with multiple groups. The code below draws the curves for two groups in one plot, but it is hard-coded for exactly two groups. I would like to use a for loop that runs over a list of all the groups, where the list has no fixed length and can contain more than two groups. How can I achieve this with lifelines?

import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
from lifelines.datasets import load_waltons

waltons = load_waltons()
ix = waltons['group'] == 'control'

ax = plt.subplot(111)

kmf_control = KaplanMeierFitter()
ax = kmf_control.fit(waltons.loc[ix]['T'], waltons.loc[ix]['E'],label='control').plot_survival_function(ax=ax)

kmf_exp = KaplanMeierFitter()
ax = kmf_exp.fit(waltons.loc[~ix]['T'], waltons.loc[~ix]['E'], label='exp').plot_survival_function(ax=ax)


from lifelines.plotting import add_at_risk_counts
add_at_risk_counts(kmf_exp, kmf_control, ax=ax)
plt.tight_layout()

Thank you in advance.

There is 1 answer

Answer by Peter Hill

The key is to stop giving each curve its own variable name. Instead, append each KaplanMeierFitter to a list and access it through the for loop index.

Note that add_at_risk_counts is called with the unpacked list of fitters ("*list_of_fits"), a pattern taken from this example. That example uses "enumerate" to drive the iteration.
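For readers unfamiliar with the "*" syntax, here is a minimal pure-Python sketch of what unpacking a list into positional arguments does. The function describe_fits and the placeholder strings are hypothetical and are not part of lifelines; they only stand in for a function that accepts any number of fitted models.

```python
def describe_fits(*fits, labels=None):
    """Hypothetical stand-in for a function that accepts any number of fits."""
    # Inside the function, `fits` is a tuple holding every positional argument.
    if labels is None:
        labels = [str(fit) for fit in fits]
    return list(zip(labels, fits))


# Placeholder strings standing in for fitted models (illustration only)
curves = ["fit_a", "fit_b", "fit_c"]

# Unpacking the list with * is equivalent to passing each element by hand
unpacked = describe_fits(*curves)
explicit = describe_fits(curves[0], curves[1], curves[2])
```

Because the list is unpacked at call time, the same call works whether it holds two fits or twenty.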

If the data sets are held in a list of dataframes and you know how many there are:

from lifelines import KaplanMeierFitter
from lifelines.plotting import add_at_risk_counts
from lifelines.datasets import load_waltons

# import pandas as pd
import matplotlib.pyplot as plt

waltons = load_waltons()
ix = waltons['group'] == 'control'

SD = 2 # Number of data sets
CI = False # True or False to show confidence intervals
SC = True # True or False to show censor tick marks

kmf_ = []
data_ = []
data_.append(waltons.loc[ix])
data_.append(waltons.loc[~ix])
Set_ = ['control', 'exp']

ax = plt.subplot(111)

for i in range(SD):  # loop over the data sets by index
    kmf_.append(KaplanMeierFitter())
    # Fit one curve per data set and label it with the matching group name
    kmf_[i].fit(data_[i]['T'], data_[i]['E'], label=Set_[i])
    ax = kmf_[i].plot_survival_function(ax=ax, ci_show=CI, show_censors=SC)

# Unpack the list so that every fitter is passed as a positional argument
add_at_risk_counts(*kmf_, ax=ax, labels=Set_[:SD])
plt.tight_layout()

Or, if the number of data sets is unknown but the data has a grouping key (level), you can use groupby and increment an index on each iteration.

from lifelines import KaplanMeierFitter
from lifelines.plotting import add_at_risk_counts
from lifelines.datasets import load_waltons

# import pandas as pd
import matplotlib.pyplot as plt

waltons = load_waltons()

CI = False # True or False to show confidence intervals
SC = True # True or False to show censor tick marks

# print(waltons)

ax = plt.subplot(111)

kmf_ = []
i = 0 # initialise index
for name, grouped_df in waltons.groupby('group'): 
    kmf_.append(KaplanMeierFitter())
    kmf_[i].fit(grouped_df["T"], grouped_df["E"], label=name)
    kmf_[i].plot_survival_function(ax=ax, ci_show=CI, show_censors=SC)
    print(kmf_[i])
    i = i + 1 # increment index

add_at_risk_counts(*kmf_,  
                   ax=ax,
                   rows_to_show=["At risk"])

plt.tight_layout()