Plot an array of figures in Matplotlib


I have a function that creates a figure, and I would like to call it repeatedly inside a loop with different parameters, collect the resulting figures, and plot them together in a grid. This is how I would do it in Julia:

using Plots

plots = Array{Plots.Plot{Plots.GRBackend},2}(undef, 3,3)

for i in 1:3
    for j in 1:3
        plots[i,j] = scatter(rand(10), rand(10))
        title!(plots[i,j], "Plot $(i),$(j)")
    end
end

plot(plots..., layout=(3,3))
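For comparison, here is a rough Matplotlib counterpart of this Julia snippet. It is only a sketch and assumes the plotting code is allowed to draw straight onto existing Axes, which is exactly what the function described below does not do yet.

```python
import matplotlib.pyplot as plt
import numpy as np

# One figure holding a 3x3 grid of Axes, analogous to layout=(3,3) in Plots.jl
fig, axes = plt.subplots(3, 3)

for i in range(3):
    for j in range(3):
        # Draw directly onto the (i, j) Axes instead of building separate plots
        axes[i, j].scatter(np.random.rand(10), np.random.rand(10))
        # Python indices are 0-based, so shift them to match the Julia titles
        axes[i, j].set_title(f"Plot {i + 1},{j + 1}")

fig.tight_layout()
plt.show()
```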

However, I have to write Python. Currently I have a function that creates a new figure and returns it. I am reluctant to change this function's call signature (e.g. to pass in an axes object), since it is already used in a different context. Below is a minimal example. For some reason the individual figures are displayed here even though I am not calling plt.display(); in the main code they are not. Here, the final 3x3 figure is empty.

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(3,3)

def plottingfunction(x,y):
    plt.figure()
    plt.scatter(x, y)
    return plt.gcf()

for i in range(3):
    for j in range(3):
        x = np.random.rand(10)
        y = np.random.rand(10)
        ax[i,j] = plottingfunction(x,y)

plt.show()
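The grid stays empty because ax[i,j] = ... only replaces an element of the NumPy array returned by plt.subplots; it does not move anything into fig. The Axes that actually belong to fig are left untouched, as this small check illustrates:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(3, 3)
original = ax[0, 0]                # the Axes that really lives inside fig

# Same kind of assignment as in the question: rebinding an array element
ax[0, 0] = plt.figure()

print(ax[0, 0] is original)        # False: the array slot now holds a Figure
print(fig.axes[0] is original)     # True: fig still owns its original Axes
print(len(fig.axes))               # 9
print(len(original.collections))   # 0: nothing was ever drawn on it
```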

So how do I plot already existing figures, for example ones collected in an array, in a grid using Python and Matplotlib?
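If the function really must keep returning a brand-new Figure, one workaround is to rasterize each returned figure and show the images in a grid with imshow. As far as I know, Matplotlib has no supported way to move an existing Axes into another figure, so this sketch goes through an in-memory PNG instead. The helper names figure_to_image and show_figures_in_grid are hypothetical, and the result is a bitmap: the sub-plots lose interactivity and vector quality.

```python
import io

import matplotlib.pyplot as plt
import numpy as np


def plottingfunction(x, y):
    # The question's function, unchanged: it creates and returns a new figure
    plt.figure()
    plt.scatter(x, y)
    return plt.gcf()


def figure_to_image(fig, dpi=100):
    """Hypothetical helper: render a Figure to an RGBA array via an in-memory PNG."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    plt.close(fig)                         # the source figure is no longer needed
    buf.seek(0)
    return plt.imread(buf, format="png")   # float array of shape (height, width, 4)


def show_figures_in_grid(figs, nrows, ncols):
    """Hypothetical helper: place rendered figures into an nrows x ncols grid."""
    grid_fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows))
    for axis, fig in zip(axes.flat, figs):
        axis.imshow(figure_to_image(fig))
        axis.axis("off")                   # hide the pixel-coordinate ticks
    return grid_fig


figs = [plottingfunction(np.random.rand(10), np.random.rand(10)) for _ in range(9)]
grid = show_figures_in_grid(figs, 3, 3)
plt.show()
```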


There are 2 answers

Answer by RuthC

How about making the ax parameter optional? That way you can pass an axes for this use case without breaking your existing calls.

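import matplotlib.pyplot as plt


def plottingfunction(x, y, ax=None):
    if ax is None:
        # Existing use case: set up a new figure and axes
        fig, ax = plt.subplots()
    else:
        # New use case: draw into the given axes and return its figure
        fig = ax.get_figure()

    # Plot the data
    ax.scatter(x, y)

    return fig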
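A usage sketch for this version: the same function serves both the old call (no ax, a new figure comes back) and the 3x3 grid (ax given, the grid's figure comes back).

```python
import matplotlib.pyplot as plt
import numpy as np


def plottingfunction(x, y, ax=None):
    # RuthC's version: ax is optional, so existing calls keep working
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()
    ax.scatter(x, y)
    return fig


# Existing use case: no ax given, so a new figure is created and returned
single_fig = plottingfunction(np.random.rand(10), np.random.rand(10))

# New use case: draw into each cell of a 3x3 grid
grid_fig, axes = plt.subplots(3, 3)
for i in range(3):
    for j in range(3):
        returned = plottingfunction(np.random.rand(10), np.random.rand(10), ax=axes[i, j])
        # The returned figure is the grid figure, not a new one
        assert returned is grid_fig

plt.show()
```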
Answer by Anonymous

Perhaps you can play around with setting and getting the current axes (plt.sca and plt.gca). It doesn't make much sense to me to do it this way, but it seems to work for this minimal example.

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(3, 3)


def plottingfunction(x, y):
    axis = plt.gca()
    axis.scatter(x, y)


for i in range(3):
    for j in range(3):
        plt.sca(ax[i, j])
        x = np.random.rand(10)
        y = np.random.rand(10)
        plottingfunction(x, y)


plt.show()
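This works because plt.sca makes the given Axes, and its figure, current, and plt.gca then returns it. The same check shows why the question's original plottingfunction cannot be reused unchanged this way: its plt.figure() call creates and activates a new figure, so anything drawn afterwards lands there instead of in the grid.

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 2)

plt.sca(ax[1, 0])
print(plt.gca() is ax[1, 0])   # True: the chosen Axes is current
print(plt.gcf() is fig)        # True: its figure became current too

new_fig = plt.figure()         # what the question's plottingfunction does first
print(plt.gcf() is fig)        # False: the new, empty figure is now current
print(plt.gca() in fig.axes)   # False: gca() creates an Axes in new_fig instead
```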

EDIT: It pains me to do this, but the same trick also works with two separate grids:

import matplotlib.pyplot as plt
import numpy as np

fig1, ax1 = plt.subplots(3, 3)
fig2, ax2 = plt.subplots(3, 3)


def plottingfunction1(x, y):
    axis = plt.gca()
    axis.scatter(x, y)


def plottingfunction2(x, y):
    axis = plt.gca()
    axis.plot(x, y)


for ax_ind in range(0, 9):
    x = np.random.rand(10)
    y = np.random.rand(10)

    plt.sca(fig1.axes[ax_ind])
    plottingfunction1(x, y)

    plt.sca(fig2.axes[ax_ind])
    plottingfunction2(x, y)

plt.show()
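Indexing fig1.axes[ax_ind] works because plt.subplots adds its Axes in row-major order, so fig.axes lists them in the same order as the flattened array it returns, as long as nothing else (a colorbar, for example) has added Axes to the figure. A quick check, which also suggests the array's .flat iterator as an alternative to the flat index:

```python
import matplotlib.pyplot as plt

fig1, ax1 = plt.subplots(3, 3)

# fig.axes is a plain list in creation order; ax1.flat walks the array row by row
same_order = all(a is b for a, b in zip(fig1.axes, ax1.flat))
print(same_order)                   # True
print(fig1.axes[5] is ax1[1, 2])    # True: flat index 5 is row 1, column 2
```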

EDIT 2: Create the figures and axes as needed and pass the axes to the plotting functions explicitly:

import matplotlib.pyplot as plt
import numpy as np


def plottingfunction1(x, y, axis):
    axis.scatter(x, y)


def plottingfunction2(x, y, axis):
    axis.plot(x, y)


_, ax1 = plt.subplots(3, 3)
_, ax2 = plt.subplots(3, 3)

for i in range(0, 3):
    for j in range(0, 3):
        x = np.random.rand(10)
        y = np.random.rand(10)

        plottingfunction1(x, y, axis=ax1[i, j])
        plottingfunction2(x, y, axis=ax2[i, j])

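# Reuse an axes from the first grid: overlay a line plot on the scatter at row 1, column 2
plottingfunction2(x, y, axis=ax1[1, 2])

plt.show()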
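Passing the axes explicitly also fits the wish not to change the existing signature: keep an axes-based core function plus a thin wrapper with the old plottingfunction(x, y) signature that creates its own figure. This is a sketch; the name scatter_on is hypothetical.

```python
import matplotlib.pyplot as plt
import numpy as np


def scatter_on(axis, x, y):
    """Hypothetical core function: draws onto whichever Axes it is given."""
    axis.scatter(x, y)
    return axis


def plottingfunction(x, y):
    """Old signature kept for existing callers: creates and returns a new figure."""
    fig, axis = plt.subplots()
    scatter_on(axis, x, y)
    return fig


# Existing callers keep getting a standalone figure
standalone = plottingfunction(np.random.rand(10), np.random.rand(10))

# New code reuses the core function on a 3x3 grid
grid_fig, axes = plt.subplots(3, 3)
for axis in axes.flat:
    scatter_on(axis, np.random.rand(10), np.random.rand(10))

plt.show()
```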