Plotly: How to insert a categorical variable into a parallel coordinates plot?


So far, I have tried this:

import pandas as pd
import plotly.graph_objects as go

df = pd.read_csv('https://raw.githubusercontent.com/vyaduvanshi/helper-files/master/parallel_coordinates.csv')

dimensions = list([dict(range=[df['gm_Retail & Recreation'].min(),df['gm_Retail & Recreation'].max()],
                        label='Retail & Recreation', values=df['gm_Retail & Recreation']),
                  dict(range=[df['gm_Grocery & Pharmacy'].min(),df['gm_Grocery & Pharmacy'].max()],
                       label='Grocery & Pharmacy', values=df['gm_Grocery & Pharmacy']),
                  dict(range=[df['gm_Parks'].min(),df['gm_Parks'].max()],
                       label='Parks', values=df['gm_Parks']),
                  dict(range=[df['gm_Transit Stations'].min(),df['gm_Transit Stations'].max()],
                       label='Transit Stations', values=df['gm_Transit Stations']),
                  dict(range=[df['gm_Workplaces'].min(),df['gm_Workplaces'].max()],
                       label='Workplaces', values=df['gm_Workplaces']),
                  dict(range=[df['gm_Residential'].min(),df['gm_Residential'].max()],
                       label='Residential', values=df['gm_Residential']),])
#                   dict(range=[0,len(df)], values=df['country'],
#                       label='Country')])

fig = go.Figure(data=go.Parcoords(line = dict(color = '#ff0000',
                   colorscale = 'Electric',
                   showscale = True,
                   cmin = -4000,
                   cmax = -100), dimensions=dimensions))
fig.show()

And it returns this:

[image: screenshot of the resulting plot]

What I am looking to do is assign these lines to a last column, which would be the country column (categorical). (My attempt is commented out in the code snippet.) I am trying to work out how I could link these values to the categorical countries. Using the index could be one way, perhaps? I also want to color-code the lines by country, for which a list of varying colors could help, I guess. I am stuck and could use some help.


There are 2 answers

vestland On BEST ANSWER

In your case, you can do so by letting a dummy variable represent each unique element in df['country']. Your dataset is in long format, so each country appears on many rows and its dummy value will be repeated. But don't worry, the code below sorts that out by building the mapping from the unique countries and merging it back onto the data. Then you can specify your last dimension as:

dict(range=[0,df['dummy'].max()],
                   tickvals = dfg['dummy'], ticktext = dfg['country'],
                   label='Country', values=df['dummy']),
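
The dummy variable described above can also be produced with pd.factorize, which numbers categories in order of first appearance. This is only a sketch of that alternative, not the method used in this answer; the small DataFrame is a made-up stand-in for the real CSV, and the name country_dim is chosen here for illustration.

```python
import pandas as pd

# Hypothetical stand-in for the long-format data in the question
df = pd.DataFrame({
    "country": ["India", "Norway", "India", "Brazil", "Norway"],
    "gm_Parks": [-40.0, 12.0, -35.0, -20.0, 15.0],
})

# factorize returns one integer code per row (0, 1, 2, ... in order of
# first appearance) plus the unique labels those codes refer to
codes, uniques = pd.factorize(df["country"])
df["dummy"] = codes

# Dimension dict in the same shape as the one used for go.Parcoords
country_dim = dict(
    range=[0, int(df["dummy"].max())],
    tickvals=list(range(len(uniques))),  # one tick per country code
    ticktext=list(uniques),              # label each tick with the country name
    label="Country",
    values=df["dummy"],
)
```

The resulting dict can be appended to the dimensions list in place of the merge-based version.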

Finally, assign a color range for the lines using, for example:

line = dict(color = df['dummy'],
                   colorscale = [[0,'rgba(200,0,0,0.1)'],[0.5,'rgba(0,200,0,0.1)'],[1,'rgba(0,0,200,0.1)']])
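
The three-stop colorscale above blends between colors, so countries with neighbouring codes get similar shades. If you would rather have one flat color per country, a stepped colorscale is one option. The sketch below is an assumption-laden illustration: discrete_colorscale is a hypothetical helper written for this example (not a Plotly function), and the color list and codes are invented.

```python
def discrete_colorscale(colors):
    """Build a stepped colorscale: each color fills an equal-width band of [0, 1]."""
    n = len(colors)
    scale = []
    for i, color in enumerate(colors):
        # Repeating the color at both ends of its band keeps the band flat
        scale.append([i / n, color])
        scale.append([(i + 1) / n, color])
    return scale


# Hypothetical palette, one entry per country code 0, 1, 2
colors = ["rgba(200,0,0,0.3)", "rgba(0,200,0,0.3)", "rgba(0,0,200,0.3)"]
n = len(colors)
scale = discrete_colorscale(colors)

# With cmin/cmax half a step outside the code range, code k is normalised
# to (k + 0.5) / n, which is the middle of band k
line = dict(color=[0, 1, 2, 1, 0], colorscale=scale, cmin=-0.5, cmax=n - 0.5)
```

In the real figure, color would be df['dummy'] and the palette would need one entry per country.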

Plot:

[image: the resulting plot]

Complete code:

import pandas as pd
import plotly.graph_objects as go

df = pd.read_csv('https://raw.githubusercontent.com/vyaduvanshi/helper-files/master/parallel_coordinates.csv')
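# Map each unique country to an integer code (its position among the unique values)
dfg = pd.DataFrame({'country':df['country'].unique()})
dfg['dummy'] = dfg.index
# Attach the integer code to every row of the long-format data
df = pd.merge(df, dfg, on = 'country', how='left')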


dimensions = list([dict(range=[df['gm_Retail & Recreation'].min(),df['gm_Retail & Recreation'].max()],
                        label='Retail & Recreation', values=df['gm_Retail & Recreation']),
                  dict(range=[df['gm_Grocery & Pharmacy'].min(),df['gm_Grocery & Pharmacy'].max()],
                       label='Grocery & Pharmacy', values=df['gm_Grocery & Pharmacy']),
                  dict(range=[df['gm_Parks'].min(),df['gm_Parks'].max()],
                       label='Parks', values=df['gm_Parks']),
                  dict(range=[df['gm_Transit Stations'].min(),df['gm_Transit Stations'].max()],
                       label='Transit Stations', values=df['gm_Transit Stations']),
                  dict(range=[df['gm_Workplaces'].min(),df['gm_Workplaces'].max()],
                       label='Workplaces', values=df['gm_Workplaces']),
                  dict(range=[df['gm_Residential'].min(),df['gm_Residential'].max()],
                       label='Residential', values=df['gm_Residential']),
                   
                  dict(range=[0,df['dummy'].max()],
                       tickvals = dfg['dummy'], ticktext = dfg['country'],
                       label='Country', values=df['dummy']),
                  
                  ])

fig = go.Figure(data=go.Parcoords(line = dict(color = df['dummy'],
                   colorscale = [[0,'rgba(200,0,0,0.1)'],[0.5,'rgba(0,200,0,0.1)'],[1,'rgba(0,0,200,0.1)']]), dimensions=dimensions))
fig.show()
Andrei Margeloiu On

Use df.infer_objects() to automatically infer the datatypes of each column.
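
For reference, here is a small sketch of what infer_objects() does, using a made-up frame rather than the real CSV. It soft-converts object-dtype columns that actually hold numbers, while a column of country names stays non-numeric, so an integer code would still be needed to place the countries on a Parcoords axis.

```python
import pandas as pd

# Hypothetical frame: numbers stored as generic Python objects, plus a text column
raw = pd.DataFrame({
    "gm_Parks": pd.Series([-40, 12, -35], dtype=object),
    "country": ["India", "Norway", "India"],
})

# infer_objects tries to find a better dtype for each object column
inferred = raw.infer_objects()

# The numeric column is recovered; the country names remain text
parks_is_numeric = pd.api.types.is_numeric_dtype(inferred["gm_Parks"])
country_is_numeric = pd.api.types.is_numeric_dtype(inferred["country"])
```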