I'm hoping to get some useful feedback on how to improve a quite lengthy function I wrote to make Kaplan Meier survival function plots quick and easy to plot for all features in a dataframe. I've only been coding for about a year and a half now and would like some pointers on how to write functions/classes which are reproducible and useful to others. Any feedback would be awesome, you can find the code for the function below -- cheers!
def kpm_groups_auto_plot(df,column,num_groups,group_values,group_labels=None):
"""
Plots kaplan meire estimation of the survival function for an input column within a dataframe. If the column is numerical
this function automatically discretizes the variable into 10-bin quantiles.
args: Dataframe, column name to plot as a string, number of discrete groups to plot, which groups to plot, labels of the different
categorical levels of a feature default is none
Returns: None -- side effects are kaplan meire survival function plots.
"""
df_temp = df.copy(deep=True)
targets = ['default_time','status_time']
if df[column].dtype == 'object':
if num_groups == 2:
df_0 = df[df[column]==group_values[0]]
df_1 = df[df[column]==group_values[1]]
kmf_0 = KaplanMeierFitter()
kmf_1 = KaplanMeierFitter()
kmf_0.fit(df_0['time'],df_0['default_time'],label=group_labels[0])
kmf_1.fit(df_1['time'],df_1['default_time'],label=group_labels[1])
kmf_0.plot_survival_function()
kmf_1.plot_survival_function()
plt.ylabel('Probability of survival (Probability of not defaulting)')
plt.xlabel('Months')
add_at_risk_counts(kmf_0,kmf_1)
plt.title(f'Survival function for groups in {column}')
plt.tight_layout()
plt.show()
elif num_groups == 3:
df_0 = df[df[column]==group_values[0]]
df_1 = df[df[column]==group_values[1]]
df_2 = df[df[column]==group_values[2]]
kmf_0 = KaplanMeierFitter()
kmf_1 = KaplanMeierFitter()
kmf_2 = KaplanMeierFitter()
kmf_0.fit(df_0['time'],df_0['default_time'],label=group_labels[0])
kmf_1.fit(df_1['time'],df_1['default_time'],label=group_labels[1])
kmf_2.fit(df_2['time'],df_2['default_time'],label=group_labels[2])
kmf_0.plot_survival_function()
kmf_1.plot_survival_function()
kmf_2.plot_survival_function()
plt.ylabel('Probability of survival (Probability of not defaulting)')
plt.xlabel('Months')
add_at_risk_counts(kmf_0,kmf_1,kmf_2)
plt.title(f'Survival function for groups in {column}')
plt.tight_layout()
plt.show()
elif num_groups == 4:
df_0 = df[df[column]==group_values[0]]
df_1 = df[df[column]==group_values[1]]
df_2 = df[df[column]==group_values[2]]
df_3 = df[df[column]==group_values[3]]
kmf_0 = KaplanMeierFitter()
kmf_1 = KaplanMeierFitter()
kmf_2 = KaplanMeierFitter()
kmf_3 = KaplanMeierFitter()
kmf_0.fit(df_0['time'],df_0['default_time'],label=group_labels[0])
kmf_1.fit(df_1['time'],df_1['default_time'],label=group_labels[1])
kmf_2.fit(df_2['time'],df_2['default_time'],label=group_labels[2])
kmf_3.fit(df_3['time'],df_3['default_time'],label=group_labels[3])
kmf_0.plot_survival_function()
kmf_1.plot_survival_function()
kmf_2.plot_survival_function()
kmf_3.plot_survival_function()
plt.ylabel('Probability of survival (Probability of not defaulting)')
plt.xlabel('Months')
add_at_risk_counts(kmf_0,kmf_1,kmf_2,kmf_3)
plt.title(f'Survival function for groups in {column}')
plt.tight_layout()
plt.show()
elif num_groups == 5:
df_0 = df[df[column]==group_values[0]]
df_1 = df[df[column]==group_values[1]]
df_2 = df[df[column]==group_values[2]]
df_3 = df[df[column]==group_values[3]]
df_4 = df[df[column]==group_values[4]]
kmf_0 = KaplanMeierFitter()
kmf_1 = KaplanMeierFitter()
kmf_2 = KaplanMeierFitter()
kmf_3 = KaplanMeierFitter()
kmf_4 = KaplanMeierFitter()
kmf_0.fit(df_0['time'],df_0['default_time'],label=group_labels[0])
kmf_1.fit(df_1['time'],df_1['default_time'],label=group_labels[1])
kmf_2.fit(df_2['time'],df_2['default_time'],label=group_labels[2])
kmf_3.fit(df_3['time'],df_3['default_time'],label=group_labels[3])
kmf_4.fit(df_4['time'],df_4['default_time'],label=group_labels[4])
kmf_0.plot_survival_function()
kmf_1.plot_survival_function()
kmf_2.plot_survival_function()
kmf_3.plot_survival_function()
kmf_4.plot_survival_function()
plt.ylabel('Probability of survival (Probability of not defaulting)')
plt.xlabel('Months')
add_at_risk_counts(kmf_0,kmf_1,kmf_2,kmf_3,kmf_4)
plt.title(f'Survival function for groups in {column}')
plt.tight_layout()
plt.show()
else:
print('Must provide 2-5 groups')
else:
if any(column in var for var in targets):
pass
elif mortgage_df[column].isin([0,1]).all() == True:
df_0 = df[df[column]==0]
df_1 = df[df[column]==1]
kmf_0 = KaplanMeierFitter()
kmf_1 = KaplanMeierFitter()
kmf_0.fit(df_0['time'],df_0['default_time'],label='No')
kmf_1.fit(df_1['time'],df_1['default_time'],label='Yes')
kmf_0.plot_survival_function()
kmf_1.plot_survival_function()
plt.ylabel('Probability of survival (Probability of not defaulting)')
plt.xlabel('Months')
add_at_risk_counts(kmf_0,kmf_1)
plt.title(f'Survival function for groups in {column}')
plt.tight_layout()
plt.show()
else:
df_temp[column+'_discretized'] = pd.qcut(df[column],q=10,labels=False, duplicates='drop')
# Check if col is binary
if num_groups == 2:
df_0 = df_temp[df_temp[column+'_discretized']==group_values[0]]
df_1 = df_temp[df_temp[column+'_discretized']==group_values[1]]
kmf_0 = KaplanMeierFitter()
kmf_1 = KaplanMeierFitter()
kmf_0.fit(df_0['time'],df_0['default_time'],label=str(group_values[0])+'Quantile')
kmf_1.fit(df_1['time'],df_1['default_time'],label=str(group_values[1])+'Quantile')
kmf_0.plot_survival_function()
kmf_1.plot_survival_function()
plt.ylabel('Probability of survival (Probability of not defaulting)')
plt.xlabel('Months')
add_at_risk_counts(kmf_0,kmf_1)
plt.title(f'Survival function for groups in {column}_discretized')
plt.tight_layout()
plt.show()
elif num_groups == 3:
df_0 = df_temp[df_temp[column+'_discretized']==group_values[0]]
df_1 = df_temp[df_temp[column+'_discretized']==group_values[1]]
df_2 = df_temp[df_temp[column+'_discretized']==group_values[2]]
kmf_0 = KaplanMeierFitter()
kmf_1 = KaplanMeierFitter()
kmf_2 = KaplanMeierFitter()
kmf_0.fit(df_0['time'],df_0['default_time'],label=str(group_values[0])+'Quantile')
kmf_1.fit(df_1['time'],df_1['default_time'],label=str(group_values[1])+'Quantile')
kmf_2.fit(df_2['time'],df_2['default_time'],label=str(group_values[2])+'Quantile')
kmf_0.plot_survival_function()
kmf_1.plot_survival_function()
kmf_2.plot_survival_function()
plt.ylabel('Probability of survival (Probability of not defaulting)')
plt.xlabel('Months')
add_at_risk_counts(kmf_0,kmf_1,kmf_2)
plt.title(f'Survival function for groups in {column}_discretized')
plt.tight_layout()
plt.show()
elif num_groups == 4:
df_0 = df_temp[df_temp[column+'_discretized']==group_values[0]]
df_1 = df_temp[df_temp[column+'_discretized']==group_values[1]]
df_2 = df_temp[df_temp[column+'_discretized']==group_values[2]]
df_3 = df_temp[df_temp[column+'_discretized']==group_values[3]]
kmf_0 = KaplanMeierFitter()
kmf_1 = KaplanMeierFitter()
kmf_2 = KaplanMeierFitter()
kmf_3 = KaplanMeierFitter()
kmf_0.fit(df_0['time'],df_0['default_time'],label=str(group_values[0])+'Quantile')
kmf_1.fit(df_1['time'],df_1['default_time'],label=str(group_values[1])+'Quantile')
kmf_2.fit(df_2['time'],df_2['default_time'],label=str(group_values[2])+'Quantile')
kmf_3.fit(df_3['time'],df_3['default_time'],label=str(group_values[3])+'Quantile')
kmf_0.plot_survival_function()
kmf_1.plot_survival_function()
kmf_2.plot_survival_function()
kmf_3.plot_survival_function()
plt.ylabel('Probability of survival (Probability of not defaulting)')
plt.xlabel('Months')
add_at_risk_counts(kmf_0,kmf_1,kmf_2,kmf_3)
plt.title(f'Survival function for groups in {column}_discretized')
plt.tight_layout()
plt.show()
elif num_groups == 5:
df_0 = df_temp[df_temp[column+'_discretized']==group_values[0]]
df_1 = df_temp[df_temp[column+'_discretized']==group_values[1]]
df_2 = df_temp[df_temp[column+'_discretized']==group_values[2]]
df_3 = df_temp[df_temp[column+'_discretized']==group_values[3]]
df_4 = df_temp[df_temp[column+'_discretized']==group_values[4]]
kmf_0 = KaplanMeierFitter()
kmf_1 = KaplanMeierFitter()
kmf_2 = KaplanMeierFitter()
kmf_3 = KaplanMeierFitter()
kmf_4 = KaplanMeierFitter()
kmf_0.fit(df_0['time'],df_0['default_time'],label=str(group_values[0])+'Quantile')
kmf_1.fit(df_1['time'],df_1['default_time'],label=str(group_values[1])+'Quantile')
kmf_2.fit(df_2['time'],df_2['default_time'],label=str(group_values[2])+'Quantile')
kmf_3.fit(df_3['time'],df_3['default_time'],label=str(group_values[3])+'Quantile')
kmf_4.fit(df_3['time'],df_3['default_time'],label=str(group_values[4])+'Quantile')
kmf_0.plot_survival_function()
kmf_1.plot_survival_function()
kmf_2.plot_survival_function()
kmf_3.plot_survival_function()
kmf_4.plot_survival_function()
plt.ylabel('Probability of survival (Probability of not defaulting)')
plt.xlabel('Months')
add_at_risk_counts(kmf_0,kmf_1,kmf_2,kmf_3,kmf_4)
plt.title(f'Survival function for groups in {column}_discretized')
plt.tight_layout()
plt.show()
else:
print('Must provide 2-5 groups')
It is difficult to give feedback since your code is so long and there are functions that you haven't shared like
KaplanMeierFitter()
However, it looks like the code is so long because you have to handle cases with different numbers of groups and that causes a lot of code duplication. Could you handle these different cases with a loop?
Also there are hard-coded columns like
time
anddefault_time
which look like they have to be part of the dataframeIn the future for similar questions you should instead use https://codereview.stackexchange.com/
My final suggestion is to try and use existing libraries as that means less code for you to write, maintain, document etc.
I'd use the seaborn plotting library which builds on top of matplotlib and I think will be very helpful in your case. Here's an example of how it handles different numbers of groups. There is definitely a learning curve to seaborn
Feel free to run this code yourself and play with the different parameters and comment if there are parts you need help understanding, or if this doesn't seem like a suitable alternative to your code