I am trying to plot 4 SHAP dependence plots as 2x2 subplots, but I cannot get it to work. I have tried the following:
import matplotlib.pyplot as plt
import shap
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10,10))
ax1 = plt.subplot(221)
shap.dependence_plot('age', shap_values[1], X_train, ax=ax1)
ax2 = plt.subplot(222)
shap.dependence_plot('income', shap_values[1], X_train, ax=ax2)
ax3 = plt.subplot(223)
shap.dependence_plot('score', shap_values[1], X_train, ax=ax3)
And this:
plt.figure(figsize=(10,10))
plt.subplot(2,2,1)
shap.dependence_plot('age', shap_values[1], X_train)
plt.subplot(2,2,2)
shap.dependence_plot('income', shap_values[1], X_train)
plt.subplot(2,2,3)
shap.dependence_plot('score', shap_values[1], X_train)
plt.tight_layout()
plt.show()
Either way, the plots keep appearing one below another instead of in a 2x2 grid.
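You need to pass the argument show=False to shap.dependence_plot, together with ax= pointing at the subplot each plot should go into. By default, dependence_plot calls plt.show() right after drawing, so the figure is displayed before the other subplots have been filled. And when no ax is given, it creates a new figure on every call, which is why the plt.subplot version fails as well. In this notebook you can read the following about the show argument: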
"by passing show=False you can prevent shap.dependence_plot from calling the matplotlib show() function, and so you can keep customizing the plot before eventually calling show yourself"