How to use a dropdown widget to highlight a selected categorical variable in a stacked bar chart?

I am learning matplotlib and ipywidgets and am trying to design an interactive bar chart in which the selected category is highlighted.

Data Example

Assume I have the following DataFrame:

import pandas as pd
import matplotlib.pyplot as plt

data = {"Production": [10000, 12000, 14000],
        "Sales": [9000, 10500, 12000]}
index = ["2017", "2018", "2019"]

df = pd.DataFrame(data=data, index=index)
df.plot.bar(stacked=True, rot=15, title="Annual Production Vs Annual Sales")
plt.show()

The resulting stacked bar chart looks like this:

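[Image: stacked bar chart titled "Annual Production Vs Annual Sales", with Production (blue) at the bottom and Sales (orange) stacked on top]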

What I am after

If Production is selected in the dropdown list, the blue bars should be highlighted by drawing a box (or frame) around them. The same should happen to the Sales bars when Sales is selected.
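The highlighting itself does not need the widget, so it can be sketched first with plain pandas and matplotlib. This is a minimal sketch, assuming the chart is drawn with `df.plot.bar(stacked=True)`, which adds one `BarContainer` per column to `ax.containers` in `df.columns` order; `highlight` is a hypothetical helper, not part of either library.

```python
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame(
    {"Production": [10000, 12000, 14000], "Sales": [9000, 10500, 12000]},
    index=["2017", "2018", "2019"],
)
# One BarContainer per column is added to ax.containers, in df.columns order
ax = df.plot.bar(stacked=True, rot=15, title="Annual Production Vs Annual Sales")


def highlight(ax, columns, selected):
    """Hypothetical helper: outline the bars of `selected` in red, clear the rest.

    Assumes ax.containers[j] holds the bars of columns[j].
    """
    for column, container in zip(columns, ax.containers):
        if column == selected:
            plt.setp(container, edgecolor="red", linewidth=3)
        else:
            plt.setp(container, edgecolor="none", linewidth=0)


highlight(ax, df.columns, "Production")
```

A dropdown callback would then only need to call such a helper with the selected column and redraw the figure.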

Question

I am not sure whether ipywidgets and matplotlib are enough to implement this feature, or whether another package is needed. If it is possible with those two packages, could anyone share some pointers? Thanks!

Accepted answer by Laurent

Here is one quick way to do it with ipywidgets and matplotlib:

import ipywidgets as widgets
import matplotlib.pyplot as plt
import pandas as pd
from IPython import display


def select(change=None):
    """Redraw the figure, outlining the bars of the selected column in red."""
    with out:
        display.clear_output(wait=True)
        for column, container in zip(df.columns, ax.containers):
            if column == dropdown.value:
                plt.setp(container, edgecolor="red", linewidth=3)
            else:
                plt.setp(container, edgecolor="none", linewidth=0)
        display.display(fig)


df = pd.DataFrame(
    data={"Production": [10000, 12000, 14000], "Sales": [9000, 10500, 12000]},
    index=["2017", "2018", "2019"],
)
# Initialize plot: one ax.bar call per column, so ax.containers[j]
# holds the three bars of df.columns[j] (labels now match the data)
fig, ax = plt.subplots()
colors = ["blue", "orange"]
bottom = 0
for column, color in zip(df.columns, colors):
    heights = df[column].to_numpy()
    ax.bar(x=list(df.index), height=heights, bottom=bottom, color=color, label=column)
    bottom = bottom + heights  # stack the next column on top

ax.legend(loc="upper left")
fig.suptitle("Annual Production Vs Annual Sales")
plt.close(fig)  # stop the inline backend from also showing a static copy

# Add dropdown menu above the plot; the plot is rendered into an Output widget
# so that callback output is visible in both classic Notebook and JupyterLab
dropdown = widgets.Dropdown(options=["All", *df.columns])
out = widgets.Output()
dropdown.observe(select, names="value")  # react only to value changes
display.display(dropdown, out)
select()  # draw the initial, unhighlighted figure

Running the above code in a single Jupyter notebook cell gives:

[Screenshots: the dropdown menu above the chart, with the bars of the selected category outlined in red]
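A detail worth knowing about `Dropdown.observe`: when it is called without `names`, the callback is registered for every trait, and a single selection in a `Dropdown` updates several synced traits (such as `index`, `value` and `label`), so the callback can run more than once per selection. Passing `names="value"` restricts it to value changes. The sketch below illustrates this with plain traitlets (the library ipywidgets is built on); `FakeDropdown` is a hypothetical stand-in, not an ipywidgets class.

```python
from traitlets import HasTraits, Int, Unicode


class FakeDropdown(HasTraits):
    """Hypothetical stand-in for a Dropdown with two traits that change together."""
    index = Int(0)
    value = Unicode("All")


calls_all, calls_value = [], []
w = FakeDropdown()
# No `names`: the handler fires for every trait change
w.observe(lambda change: calls_all.append(change["name"]))
# names="value": the handler fires only when `value` changes
w.observe(lambda change: calls_value.append(change["name"]), names="value")

# Simulate one user selection, which updates both traits
w.index = 1
w.value = "Production"

print(calls_all)    # ['index', 'value']
print(calls_value)  # ['value']
```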