Stack Plot with a Color Map matplotlib

I would like to draw a stack plot with a colormap, as shown in Figure 5 of this paper. Here is a screenshot of that figure:

[Image: screenshot of Figure 5 from the paper]

Currently, I am able to draw a scatter plot of a similar nature.

[Image: my current scatter plot]

I would like to convert this scatter plot to a stack plot with a colormap, but I am a bit lost on how to do this. My initial guess is that for each (x, y) point I need a list of z points on the colormap spectrum. I wonder, however, if there's a simpler way to do this. Here's my code to generate the scatter plot with a colormap:

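import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# x, y and z are 1-D arrays of equal length holding my data
cm = plt.get_cmap('RdYlBu')
plt.xscale('log')
plt.yscale('log')
norm = matplotlib.colors.Normalize(vmin=np.min(z), vmax=np.max(z))
sc = plt.scatter(x, y, c=z, marker='x', norm=norm, s=35, cmap=cm)
plt.colorbar(sc)
plt.show()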

Edit

I feel I need to find a way to convert the z-array to multiple z-arrays - one for each bin on the color bar. I can then simply create a stacked area chart from these derived z-arrays.
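
One way this idea could be sketched, assuming the stacked quantity is the number of points per x bin falling into each color band (that reading of the figure is an assumption, as are the synthetic `x` and `z` arrays and the bin counts `n_xbins` and `n_zbins`):

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
from matplotlib.colors import BoundaryNorm

# Synthetic stand-ins for the real x and z arrays (assumption).
rng = np.random.default_rng(0)
n = 5000
x = 10 ** rng.uniform(0, 3, n)   # positive values between 1 and 1000
z = rng.uniform(0.0, 1.0, n)     # the value that is mapped to color

n_xbins, n_zbins = 30, 8         # hypothetical bin counts

# Log-spaced x edges to match a log x axis; linear z edges, one band per color.
x_edges = np.logspace(0, 3, n_xbins + 1)
z_edges = np.linspace(z.min(), z.max(), n_zbins + 1)

# counts[i, j] = number of points in x bin i whose z lies in color band j
counts, _, _ = np.histogram2d(x, z, bins=[x_edges, z_edges])

# Geometric bin centers suit a log axis.
x_centers = np.sqrt(x_edges[:-1] * x_edges[1:])

# One color per z band, sampled evenly from the colormap.
cmap = plt.get_cmap('RdYlBu')
colors = [cmap(v) for v in np.linspace(0, 1, n_zbins)]

fig, ax = plt.subplots()
ax.stackplot(x_centers, counts.T, colors=colors)  # one stacked layer per z band
ax.set_xscale('log')

# Discrete colorbar whose bands follow the z bin edges.
sm = ScalarMappable(norm=BoundaryNorm(z_edges, cmap.N), cmap=cmap)
fig.colorbar(sm, ax=ax)
# Call plt.show() to display the figure.
```

Each row of `counts.T` is one of the derived arrays, so a different per-band quantity (for example a sum of y values via the `weights` argument) could be swapped in without changing the plotting part.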

Edit 2

I followed Rutger's code and was able to produce this graph for my data. I wonder why there's an issue with the axes limits.

[Image: plot produced from my data with Rutger's code]

There is 1 answer

Answer by Rutger Kassies

It seems from your example scatter plot that you have a lot of points. Plotting them individually hides a large part of the data, because overlapping markers cover each other and only the ones drawn on top remain visible. With this much data, aggregating before plotting gives a much clearer picture.

The example below shows how you can bin and average your data by using a 2D histogram. Plotting the result as either an image or a contour is fairly straightforward once your data is in an appropriate format for visual display.

Aggregating the data before plotting also improves performance and helps avoid "Array too big" or other memory-related errors.
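
As a minimal illustration of the bin-and-average idea (the four points and the 2 x 2 grid are made up for this sketch, not taken from the question), dividing a weighted histogram by an unweighted one gives the mean value per bin:

```python
import numpy as np

# Four points; the first two fall into the same bin.
x = np.array([0.2, 0.3, 0.8, 0.7])
y = np.array([0.1, 0.4, 0.9, 0.2])
v = np.array([1.0, 3.0, 10.0, 4.0])   # value to average per bin

edges = [0.0, 0.5, 1.0]               # two bins per axis

# Number of points per bin, and sum of v per bin.
counts, _, _ = np.histogram2d(x, y, bins=[edges, edges])
sums, _, _ = np.histogram2d(x, y, bins=[edges, edges], weights=v)

# Mean per bin; empty bins are 0/0 and become NaN.
with np.errstate(invalid='ignore'):
    avg = sums / counts

print(counts)
print(avg)
```

The bin holding the values 1.0 and 3.0 averages to 2.0, and the one bin without points is NaN. The full example below applies the same two calls to 100000 points.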

import numpy as np
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 3, figsize=(15,5), subplot_kw={'aspect': 1})

n = 100000

x = np.random.randn(n)
y = np.random.randn(n)+5
data_values = y * x

# Normal scatter, like your example
ax[0].scatter(x, y, c=data_values, marker='x', alpha=.2)
ax[0].set_xlim(-5,5)


# Get the extent to scale the other plots in a similar fashion
xrng = list(ax[0].get_xbound())
yrng = list(ax[0].get_ybound())

# number of bins used for aggregation
# (must be an integer; np.histogram2d rejects a float bin count)
n_bins = 130

# create the histograms
counts, xedge, yedge = np.histogram2d(x, y, bins=(n_bins,n_bins), range=[xrng,yrng])
sums, xedge, yedge = np.histogram2d(x, y, bins=(n_bins,n_bins), range=[xrng,yrng], weights=data_values)

# empty bins give 0/0 = NaN; silence the resulting RuntimeWarning
with np.errstate(invalid='ignore', divide='ignore'):
    data_avg = sums / counts

ax[1].imshow(data_avg.T, origin='lower', interpolation='none', extent=xrng+yrng)

xbin_size = (xrng[1] - xrng[0])  / n_bins # the range divided by n_bins
ybin_size = (yrng[1] - yrng[0])  / n_bins # the range divided by n_bins

# create x,y coordinates for the histogram
# coordinates should be shifted from edge to center
xgrid, ygrid = np.meshgrid(xedge[1:] - (xbin_size / 2) , yedge[1:] - (ybin_size / 2))

ax[2].contourf(xgrid, ygrid, data_avg.T)

ax[0].set_title('Scatter')
ax[1].set_title('2D histogram with imshow')
ax[2].set_title('2D histogram with contourf')

plt.show()

[Image: the three resulting panels: scatter, 2D histogram with imshow, 2D histogram with contourf]
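
Regarding the axis limits mentioned in Edit 2: the plot is not reproduced here, so the cause can only be guessed. One possibility, assuming the data are drawn on log-scaled axes as in the question's scatter code, is that `imshow` with `extent` assumes equally sized bins, which does not fit a log axis. A sketch of an alternative using log-spaced bin edges and `pcolormesh` (the synthetic `x`, `y`, `z` arrays and `n_bins = 40` are placeholders):

```python
import numpy as np
import matplotlib.pyplot as plt

# Synthetic, strictly positive stand-ins for the real data (assumption).
rng = np.random.default_rng(1)
n = 20000
x = 10 ** rng.normal(1.0, 0.5, n)
y = 10 ** rng.normal(2.0, 0.5, n)
z = np.log10(x) + np.log10(y)

n_bins = 40  # hypothetical bin count

# Bin edges equally spaced in log space, spanning the data range.
x_edges = np.geomspace(x.min(), x.max(), n_bins + 1)
y_edges = np.geomspace(y.min(), y.max(), n_bins + 1)

counts, _, _ = np.histogram2d(x, y, bins=[x_edges, y_edges])
sums, _, _ = np.histogram2d(x, y, bins=[x_edges, y_edges], weights=z)

# Mean z per bin; mask the empty bins (0/0 = NaN) so they are left blank.
with np.errstate(invalid='ignore'):
    avg = np.ma.masked_invalid(sums / counts)

fig, ax = plt.subplots()
# pcolormesh takes the bin edges directly, so unequal bin widths are fine.
mesh = ax.pcolormesh(x_edges, y_edges, avg.T, cmap='RdYlBu')
ax.set_xscale('log')
ax.set_yscale('log')
# Pin the limits to the binned range explicitly.
ax.set_xlim(x_edges[0], x_edges[-1])
ax.set_ylim(y_edges[0], y_edges[-1])
fig.colorbar(mesh, ax=ax)
# Call plt.show() to display the figure.
```

Setting the limits from the bin edges keeps the view on the binned region regardless of how autoscaling behaves.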