matplotlib savefig performance, saving multiple pngs within loop


I'm hoping to find a way to optimise the following situation. I have a large contour plot created with matplotlib's imshow. From this plot I want to generate a large number of png images, where each image shows a small section of the plot, obtained by changing the x and y limits and the aspect ratio.

So no plot data is changing in the loop, only the axis limits and the aspect ratio are changing between each png image.

The following MWE creates 70 png images in a "figs" folder demonstrating the simplified idea. About 80% of the runtime is taken up by fig.savefig('figs/'+filename).
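One way to check a figure like the 80% quoted above is to time the savefig calls separately from the rest of the loop. The following is only a minimal sketch: the helper name time_savefig_share is hypothetical, the axis limits are arbitrary, and the image is written to an in-memory buffer so that disk I/O is not included in the measurement.

```python
import io
import time

import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
import numpy as np


def time_savefig_share(n_frames=5):
    """Return (seconds spent in savefig, total loop seconds)."""
    fig = plt.figure()
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
    ax.imshow(np.random.rand(25, 25), interpolation="nearest")

    save_time = 0.0
    loop_start = time.perf_counter()
    for k in range(n_frames):
        # Only the view changes between frames, as in the MWE.
        ax.set_xlim(0, 5 + k)
        ax.set_ylim(0, 5 + k)
        t0 = time.perf_counter()
        fig.savefig(io.BytesIO(), format="png")  # in-memory, no disk I/O
        save_time += time.perf_counter() - t0
    total = time.perf_counter() - loop_start
    plt.close(fig)
    return save_time, total
```

Dividing the first returned value by the second gives the fraction of the loop spent in savefig on a given machine.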

I've looked into the following without coming up with an improvement:

  • An alternative to matplotlib with a focus on speed -- I've struggled to find any examples/documentation of contour/surface plots with similar requirements
  • Multiprocessing -- Similar questions I've seen here appear to require fig = plt.figure() and ax.imshow to be called within the loop, since fig and ax can't be pickled. In my case this will be more expensive than any speed gains achieved by implementing multiprocessing.

I'd appreciate any insight or suggestions you might have.

import numpy as np
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
import time, os

def make_plot(x, y, fig, ax):
    aspect = np.random.random(1)+y/2.0-x
    xrand = np.random.random(2)*x
    xlim = [min(xrand), max(xrand)]
    yrand = np.random.random(2)*y
    ylim = [min(yrand), max(yrand)]
    filename = '{:d}_{:d}.png'.format(x,y)

    ax.set_aspect(abs(aspect[0]))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    fig.savefig('figs/'+filename)

if not os.path.isdir('figs'):
    os.makedirs('figs')
data = np.random.rand(25, 25)

fig = plt.figure()
ax = fig.add_axes([0., 0., 1., 1.])
# in the real case, imshow is an expensive calculation which can't be put inside the loop
ax.imshow(data, interpolation='nearest')

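# time.clock() was removed in Python 3.8; perf_counter() is its replacement for timing
tstart = time.perf_counter()
for i in range(1, 8):
    for j in range(3, 13):
        make_plot(i, j, fig, ax)

print('took {:.2f} seconds'.format(time.perf_counter()-tstart))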

Answer by ImportanceOfBeingErnest (best answer)

Since the bottleneck in this case is the call to plt.savefig(), there is not much room for optimization. Internally the figure is rendered from scratch on every call, and that takes a while. Reducing the number of vertices to be drawn might reduce the time a bit.

The time to run your code on my machine (Win 8, i5 with 4 cores 3.5GHz) is 2.5 seconds. This seems not too bad. One can get a little improvement by using Multiprocessing.

A note about Multiprocessing: It may seem surprising that using the state machine of pyplot inside multiprocessing should work at all. But it does. And in this case here, since every image is based on the same figure and axes object, one does not even have to create new figures and axes.

I modified an answer I gave here a while ago for your case and the total time is roughly halved using multiprocessing and 5 processes on 4 cores. I appended a barplot which shows the effect of multiprocessing.

import numpy as np
#import matplotlib as mpl
#mpl.use('agg') # use of agg seems to slow things down a bit
import matplotlib.pyplot as plt
import multiprocessing
import time, os

def make_plot(d):
    # time.clock() was removed in Python 3.8; time.time() is wall-clock
    # time, so values from different worker processes can be compared
    start = time.time()
    x, y = d
    # using aspect in this way causes a warning for me
    # aspect = np.random.random(1)+y/2.0-x
    xrand = np.random.random(2)*x
    xlim = [min(xrand), max(xrand)]
    yrand = np.random.random(2)*y
    ylim = [min(yrand), max(yrand)]
    filename = '{:d}_{:d}.png'.format(x, y)
    # reuse the figure and axes that already exist in this process
    ax = plt.gca()
    # ax.set_aspect(abs(aspect[0]))
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.savefig('figs/'+filename)
    stop = time.time()
    return np.array([x, y, start, stop])

if not os.path.isdir('figs'):
    os.makedirs('figs')
data = np.random.rand(25, 25)

fig = plt.figure()
ax = fig.add_axes([0., 0., 1., 1.])
ax.imshow(data, interpolation='nearest')


# one (x, y) tuple per image to be saved
some_list = []
for i in range(1, 8):
    for j in range(3, 13):
        some_list.append((i, j))


if __name__ == "__main__":
    multiprocessing.freeze_support()
    tstart = time.time()
    print(tstart)
    num_proc = 5
    p = multiprocessing.Pool(num_proc)

    nu = p.map(make_plot, some_list)

    tooktime = 'Plotting of {} frames took {:.2f} seconds'
    tooktime = tooktime.format(len(some_list), time.time()-tstart)
    print(tooktime)
    nu = np.array(nu)

    # bar plot: one bar per image, spanning its start and stop time
    plt.close("all")
    fig, ax = plt.subplots(figsize=(8, 5))
    plt.suptitle(tooktime)
    # subtract tstart so the x axis shows seconds since the pool was started
    ax.barh(np.arange(len(some_list)), nu[:, 3]-nu[:, 2],
            height=np.ones(len(some_list)), left=nu[:, 2]-tstart, align="center")
    ax.set_xlabel("time [s]")
    ax.set_ylabel("image number")
    ax.set_ylim([-1, 70])
    plt.tight_layout()
    plt.savefig(__file__+".png")
    plt.show()

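[Image: bar plot produced by the code above, showing the start and stop time of each saved image; x axis "time [s]", y axis "image number"]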
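As a variation on the approach above, the per-process figure can be made explicit with a Pool initializer, so that the expensive imshow call runs once per worker rather than once per image, and only plain tuples are sent to the workers. This is a rough sketch rather than a tested drop-in replacement: the names _state, init_worker, save_frame and save_all are hypothetical, and the random 25x25 array stands in for the real data.

```python
import multiprocessing
import os

import matplotlib
matplotlib.use("agg")
import matplotlib.pyplot as plt
import numpy as np

_state = {}  # per-process storage for the figure and axes (hypothetical name)


def init_worker(seed=0):
    """Runs once in each worker: build the expensive figure a single time."""
    data = np.random.RandomState(seed).rand(25, 25)  # stand-in for the real data
    fig = plt.figure()
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
    ax.imshow(data, interpolation="nearest")
    _state["fig"], _state["ax"] = fig, ax


def save_frame(job):
    """Change the view of the worker's own figure and save it.

    job is a plain tuple, so nothing from matplotlib has to be pickled.
    """
    out_dir, x, y, xlim, ylim = job
    ax = _state["ax"]
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    path = os.path.join(out_dir, "{:d}_{:d}.png".format(x, y))
    _state["fig"].savefig(path)
    return path


def save_all(jobs, num_proc=4):
    """Distribute the jobs over num_proc workers; returns the saved paths."""
    with multiprocessing.Pool(num_proc, initializer=init_worker) as pool:
        return pool.map(save_frame, jobs)
```

save_all should be called from inside an if __name__ == "__main__": block, as in the code above. Whether this beats the single-process loop depends on how the cost of building the figure once per worker compares with the total time spent in savefig.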