How to better develop function for automating Kaplan Meier survival function estimation?

228 views Asked by At

I'm hoping to get some useful feedback on how to improve a quite lengthy function I wrote to make Kaplan Meier survival function plots quick and easy to plot for all features in a dataframe. I've only been coding for about a year and a half now and would like some pointers on how to write functions/classes which are reproducible and useful to others. Any feedback would be awesome, you can find the code for the function below -- cheers!

def kpm_groups_auto_plot(df,column,num_groups,group_values,group_labels=None):
    """
    Plots kaplan meire estimation of the survival function for an input column within a dataframe. If the column is numerical
    this function automatically discretizes the variable into 10-bin quantiles.
    args: Dataframe, column name to plot as a string, number of discrete groups to plot, which groups to plot, labels of the different
    categorical levels of a feature default is none
    Returns: None -- side effects are kaplan meire survival function plots.
    """
    df_temp = df.copy(deep=True)
    targets = ['default_time','status_time']
    if df[column].dtype == 'object':
        if num_groups == 2:
            df_0 =  df[df[column]==group_values[0]]
            df_1 =  df[df[column]==group_values[1]]

            kmf_0 = KaplanMeierFitter()
            kmf_1 = KaplanMeierFitter()

            kmf_0.fit(df_0['time'],df_0['default_time'],label=group_labels[0])
            kmf_1.fit(df_1['time'],df_1['default_time'],label=group_labels[1])

            kmf_0.plot_survival_function()
            kmf_1.plot_survival_function()

            plt.ylabel('Probability of survival (Probability of not defaulting)')
            plt.xlabel('Months')

            add_at_risk_counts(kmf_0,kmf_1)
            plt.title(f'Survival function for groups in {column}')
            plt.tight_layout()
            plt.show()
        elif num_groups == 3:
            df_0 =  df[df[column]==group_values[0]]
            df_1 =  df[df[column]==group_values[1]]
            df_2 =  df[df[column]==group_values[2]]

            kmf_0 = KaplanMeierFitter()
            kmf_1 = KaplanMeierFitter()
            kmf_2 = KaplanMeierFitter()

            kmf_0.fit(df_0['time'],df_0['default_time'],label=group_labels[0])
            kmf_1.fit(df_1['time'],df_1['default_time'],label=group_labels[1])
            kmf_2.fit(df_2['time'],df_2['default_time'],label=group_labels[2])

            kmf_0.plot_survival_function()
            kmf_1.plot_survival_function()
            kmf_2.plot_survival_function()

            plt.ylabel('Probability of survival (Probability of not defaulting)')
            plt.xlabel('Months')

            add_at_risk_counts(kmf_0,kmf_1,kmf_2)
            plt.title(f'Survival function for groups in {column}')
            plt.tight_layout()
            plt.show()
        elif num_groups == 4:
            df_0 =  df[df[column]==group_values[0]]
            df_1 =  df[df[column]==group_values[1]]
            df_2 =  df[df[column]==group_values[2]]
            df_3 =  df[df[column]==group_values[3]]

            kmf_0 = KaplanMeierFitter()
            kmf_1 = KaplanMeierFitter()
            kmf_2 = KaplanMeierFitter()
            kmf_3 = KaplanMeierFitter()

            kmf_0.fit(df_0['time'],df_0['default_time'],label=group_labels[0])
            kmf_1.fit(df_1['time'],df_1['default_time'],label=group_labels[1])
            kmf_2.fit(df_2['time'],df_2['default_time'],label=group_labels[2])
            kmf_3.fit(df_3['time'],df_3['default_time'],label=group_labels[3])

            kmf_0.plot_survival_function()
            kmf_1.plot_survival_function()
            kmf_2.plot_survival_function()
            kmf_3.plot_survival_function()

            plt.ylabel('Probability of survival (Probability of not defaulting)')
            plt.xlabel('Months')

            add_at_risk_counts(kmf_0,kmf_1,kmf_2,kmf_3)
            plt.title(f'Survival function for groups in {column}')
            plt.tight_layout()
            plt.show()

        elif num_groups == 5:
            df_0 =  df[df[column]==group_values[0]]
            df_1 =  df[df[column]==group_values[1]]
            df_2 =  df[df[column]==group_values[2]]
            df_3 =  df[df[column]==group_values[3]]
            df_4 =  df[df[column]==group_values[4]]

            kmf_0 = KaplanMeierFitter()
            kmf_1 = KaplanMeierFitter()
            kmf_2 = KaplanMeierFitter()
            kmf_3 = KaplanMeierFitter()
            kmf_4 = KaplanMeierFitter()

            kmf_0.fit(df_0['time'],df_0['default_time'],label=group_labels[0])
            kmf_1.fit(df_1['time'],df_1['default_time'],label=group_labels[1])
            kmf_2.fit(df_2['time'],df_2['default_time'],label=group_labels[2])
            kmf_3.fit(df_3['time'],df_3['default_time'],label=group_labels[3])
            kmf_4.fit(df_4['time'],df_4['default_time'],label=group_labels[4])

            kmf_0.plot_survival_function()
            kmf_1.plot_survival_function()
            kmf_2.plot_survival_function()
            kmf_3.plot_survival_function()
            kmf_4.plot_survival_function()

            plt.ylabel('Probability of survival (Probability of not defaulting)')
            plt.xlabel('Months')

            add_at_risk_counts(kmf_0,kmf_1,kmf_2,kmf_3,kmf_4)
            plt.title(f'Survival function for groups in {column}')
            plt.tight_layout()
            plt.show()

        else: 
            print('Must provide 2-5 groups')
    else:
        if any(column in var for var in targets):
            pass
        elif mortgage_df[column].isin([0,1]).all() == True:
            df_0 =  df[df[column]==0]
            df_1 =  df[df[column]==1]

            kmf_0 = KaplanMeierFitter()
            kmf_1 = KaplanMeierFitter()

            kmf_0.fit(df_0['time'],df_0['default_time'],label='No')
            kmf_1.fit(df_1['time'],df_1['default_time'],label='Yes')

            kmf_0.plot_survival_function()
            kmf_1.plot_survival_function()

            plt.ylabel('Probability of survival (Probability of not defaulting)')
            plt.xlabel('Months')

            add_at_risk_counts(kmf_0,kmf_1)
            plt.title(f'Survival function for groups in {column}')
            plt.tight_layout()
            plt.show()
        else:
            df_temp[column+'_discretized'] = pd.qcut(df[column],q=10,labels=False, duplicates='drop')
            # Check if col is binary
            
            if num_groups == 2:
                df_0 =  df_temp[df_temp[column+'_discretized']==group_values[0]]
                df_1 =  df_temp[df_temp[column+'_discretized']==group_values[1]]

                kmf_0 = KaplanMeierFitter()
                kmf_1 = KaplanMeierFitter()

                kmf_0.fit(df_0['time'],df_0['default_time'],label=str(group_values[0])+'Quantile')
                kmf_1.fit(df_1['time'],df_1['default_time'],label=str(group_values[1])+'Quantile')

                kmf_0.plot_survival_function()
                kmf_1.plot_survival_function()

                plt.ylabel('Probability of survival (Probability of not defaulting)')
                plt.xlabel('Months')

                add_at_risk_counts(kmf_0,kmf_1)
                plt.title(f'Survival function for groups in {column}_discretized')
                plt.tight_layout()
                plt.show()
                    
            elif num_groups == 3:
                df_0 =  df_temp[df_temp[column+'_discretized']==group_values[0]]
                df_1 =  df_temp[df_temp[column+'_discretized']==group_values[1]]
                df_2 =  df_temp[df_temp[column+'_discretized']==group_values[2]]

                kmf_0 = KaplanMeierFitter()
                kmf_1 = KaplanMeierFitter()
                kmf_2 = KaplanMeierFitter()

                kmf_0.fit(df_0['time'],df_0['default_time'],label=str(group_values[0])+'Quantile')
                kmf_1.fit(df_1['time'],df_1['default_time'],label=str(group_values[1])+'Quantile')
                kmf_2.fit(df_2['time'],df_2['default_time'],label=str(group_values[2])+'Quantile')

                kmf_0.plot_survival_function()
                kmf_1.plot_survival_function()
                kmf_2.plot_survival_function()

                plt.ylabel('Probability of survival (Probability of not defaulting)')
                plt.xlabel('Months')

                add_at_risk_counts(kmf_0,kmf_1,kmf_2)
                plt.title(f'Survival function for groups in {column}_discretized')
                plt.tight_layout()
                plt.show()
            elif num_groups == 4:
                df_0 =  df_temp[df_temp[column+'_discretized']==group_values[0]]
                df_1 =  df_temp[df_temp[column+'_discretized']==group_values[1]]
                df_2 =  df_temp[df_temp[column+'_discretized']==group_values[2]]
                df_3 =  df_temp[df_temp[column+'_discretized']==group_values[3]]

                kmf_0 = KaplanMeierFitter()
                kmf_1 = KaplanMeierFitter()
                kmf_2 = KaplanMeierFitter()
                kmf_3 = KaplanMeierFitter()

                kmf_0.fit(df_0['time'],df_0['default_time'],label=str(group_values[0])+'Quantile')
                kmf_1.fit(df_1['time'],df_1['default_time'],label=str(group_values[1])+'Quantile')
                kmf_2.fit(df_2['time'],df_2['default_time'],label=str(group_values[2])+'Quantile')
                kmf_3.fit(df_3['time'],df_3['default_time'],label=str(group_values[3])+'Quantile')

                kmf_0.plot_survival_function()
                kmf_1.plot_survival_function()
                kmf_2.plot_survival_function()
                kmf_3.plot_survival_function()

                plt.ylabel('Probability of survival (Probability of not defaulting)')
                plt.xlabel('Months')

                add_at_risk_counts(kmf_0,kmf_1,kmf_2,kmf_3)
                plt.title(f'Survival function for groups in {column}_discretized')
                plt.tight_layout()
                plt.show()

            elif num_groups == 5:
                df_0 =  df_temp[df_temp[column+'_discretized']==group_values[0]]
                df_1 =  df_temp[df_temp[column+'_discretized']==group_values[1]]
                df_2 =  df_temp[df_temp[column+'_discretized']==group_values[2]]
                df_3 =  df_temp[df_temp[column+'_discretized']==group_values[3]]
                df_4 =  df_temp[df_temp[column+'_discretized']==group_values[4]]

                kmf_0 = KaplanMeierFitter()
                kmf_1 = KaplanMeierFitter()
                kmf_2 = KaplanMeierFitter()
                kmf_3 = KaplanMeierFitter()
                kmf_4 = KaplanMeierFitter()

                kmf_0.fit(df_0['time'],df_0['default_time'],label=str(group_values[0])+'Quantile')
                kmf_1.fit(df_1['time'],df_1['default_time'],label=str(group_values[1])+'Quantile')
                kmf_2.fit(df_2['time'],df_2['default_time'],label=str(group_values[2])+'Quantile')
                kmf_3.fit(df_3['time'],df_3['default_time'],label=str(group_values[3])+'Quantile')
                kmf_4.fit(df_3['time'],df_3['default_time'],label=str(group_values[4])+'Quantile')

                kmf_0.plot_survival_function()
                kmf_1.plot_survival_function()
                kmf_2.plot_survival_function()
                kmf_3.plot_survival_function()
                kmf_4.plot_survival_function()

                plt.ylabel('Probability of survival (Probability of not defaulting)')
                plt.xlabel('Months')

                add_at_risk_counts(kmf_0,kmf_1,kmf_2,kmf_3,kmf_4)
                plt.title(f'Survival function for groups in {column}_discretized')
                plt.tight_layout()
                plt.show()

            else: 
                print('Must provide 2-5 groups')
        
1

There are 1 answers

0
mitoRibo On

It is difficult to give feedback since your code is so long and there are functions that you haven't shared like KaplanMeierFitter()

However, it looks like the code is so long because you have to handle cases with different numbers of groups and that causes a lot of code duplication. Could you handle these different cases with a loop?

Also there are hard-coded columns like time and default_time which look like they have to be part of the dataframe

In the future for similar questions you should instead use https://codereview.stackexchange.com/

My final suggestion is to try and use existing libraries as that means less code for you to write, maintain, document etc.

I'd use the seaborn plotting library which builds on top of matplotlib and I think will be very helpful in your case. Here's an example of how it handles different numbers of groups. There is definitely a learning curve to seaborn

Feel free to run this code yourself and play with the different parameters and comment if there are parts you need help understanding, or if this doesn't seem like a suitable alternative to your code

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

#Make fake survival data
np.random.seed(10)

num_months = 24
num_groups = 4

df = pd.DataFrame({
    'Group {}'.format(g): sorted(np.random.randint(low=0,high=100,size=num_months),reverse=True)
    for g in range(num_groups)
})
df['Month'] = range(1,num_months+1)


#Convert table to "long form" which seaborn functions work well with
long_df = df.melt(
    id_vars = 'Month',
    var_name = 'Group',
    value_name = 'Individuals alive',
)

#Convert Individuals alive into percent survival for each group
long_df['Percent survival'] = long_df.groupby('Group')['Individuals alive'].transform(lambda v: v/max(v))

#Helpful seaborn plot which can handle any number of groups
sns.lineplot(
    x = 'Month',
    y = 'Percent survival',
    hue = 'Group',
    #style = 'Group', #optional for styling the plot
    #dashes = False,  #optional for styling the plot
    #markers = True,  #optional for styling the plot
    data = long_df,
)
plt.show()
plt.close()

enter image description here