
Subplots with Matplotlib

As discussed in some of my previous articles, good visualisation of data is essential to getting the associated message across. One aspect of this is the need to plot multiple data sets, or visualise the same data set in different ways, on the same figure. For example, we may wish to show our data and the residuals left after subtracting a fit to that data in the same figure.

This can be effectively done using matplotlib and the associated subplots environments. Mastery of these tools is something that comes very much with practice and I do not claim to be an expert. However, I have some experience with the environment and I will share with you the basics in this article.

In order to use the matplotlib environment we will need to begin by importing matplotlib via,
import matplotlib.pyplot as plt
We can then proceed to explore what is on offer.

plt.subplot() and plt.subplots()

plt.subplots() and plt.subplot() are probably where most people begin to learn about combining sub-figures into a single figure. This was certainly my experience when learning Python, and while they are incredibly useful they do have some limitations.

To begin with, let's look at plt.subplot(). We can produce a simple subplot graph using the following code,
plt.subplot(211)
plt.subplot(212)
plt.show()
which displays the figure below. The numbers inside plt.subplot() indicate the number of rows, the number of columns and the index of the sub-figure. The sub-figures are numbered from left to right along each row, starting with the top row and working down, so in a grid with two rows and two columns the top row holds 1 and 2 and the bottom row holds 3 and 4. Underneath each plt.subplot() we write the relevant plt.plot(), plt.xlabel() etc. functions that we want to apply to that sub-figure.
Pretty straightforward, right? You can provide keyword arguments to plt.subplot() as well, and these might be worth investigating further via the link to the matplotlib documentation at the end of the article. However, I find that defining subplots in this manner is not always the best choice, and I tend to favour plt.subplots(), which is subtly different.
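To make the numbering convention for plt.subplot() concrete, here is a small sketch that fills a grid with two rows and two columns and titles each sub-figure with its index. The sine curves are placeholder data chosen purely for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 200)

# 2 rows, 2 columns: indices run 1..4, left to right along each row,
# then down to the next row.
fig = plt.figure(figsize=(6, 4))
for index in range(1, 5):
    ax = plt.subplot(2, 2, index)      # same as plt.subplot(220 + index)
    ax.plot(x, np.sin(index * x))      # placeholder curve for this sub-figure
    ax.set_title(f"index {index}")

fig.tight_layout()
```

Calling plt.show() afterwards displays the figure, with indices 1 and 2 on the top row and 3 and 4 on the bottom row.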

In my opinion plt.subplots() is much more flexible and usable than the above (although I may be biased, having slightly more experience with it). Where the two functions differ is that plt.subplots() produces all of the sub-figures at the same time, whereas with plt.subplot() they have to be created individually. This means we can do the following,
nrows, ncols = 2, 2
fig, axs = plt.subplots(nrows, ncols)
plt.show()
where all 4 of the plots are created together and their axes are placed in an array. For a grid with two rows and two columns this array is two-dimensional, so an individual sub-figure is addressed as axs[row, column]. This helps me keep track of what is happening with each sub-figure via functions like axs[0, 0].plot(). You can also provide plt.subplots() with arguments like `sharex', `sharey' and `figsize'. If we leave `figsize' at the default value then our figure will often appear cramped and won't display our data effectively. `sharex' and `sharey' allow you to ensure that all of the axes have the same scales across the subplots by setting their values to True. Alternatively they can be set to `col' or `row' to share the axis scales across columns or rows only. An example of this is as follows,
import numpy as np
x = np.linspace(3, 5, 100)
y = x**2

x1 = np.linspace(2, 5, 100)
y1 = x1**3

nrows, ncols = 2, 1
fig, axs = plt.subplots(nrows, ncols, sharex='col')
axs[0].plot(x, y)
axs[1].plot(x1, y1)
plt.show()
where I have used numpy to define some simple power law data. Note that I have deliberately defined the data over different ranges, but the resultant figure below has the same range of x values for both subplots. This is because I have set the `sharex' argument to `col', which also removes the x tick labels from the first subplot. Alternatively I could have set `sharex' equal to True, as in this instance there is only one column. The above example also illustrates how to index into the axs array to give each sub-figure detail. With a single column axs is one-dimensional, so axs[0] and axs[1] are enough. Similar functions to axs[1].plot(), like axs[1].set_xlabel(), can be used to give each sub-figure further specific properties.
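For larger grids it is often easier to loop over the axes than to index them one at a time. The following is a minimal sketch of that idea, assuming a grid with two rows and two columns and some made-up power laws; the `powers` list and the chosen `figsize' are illustrative only.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(2, 5, 100)
powers = [1, 2, 3, 4]  # illustrative exponents, one per sub-figure

# A larger figsize (width, height in inches) gives the four panels room.
fig, axs = plt.subplots(2, 2, sharex=True, sharey='row', figsize=(8, 6))

# axs is a 2D NumPy array here, so a single panel is axs[row, col];
# axs.flat walks through all panels row by row.
for ax, p in zip(axs.flat, powers):
    ax.plot(x, x**p)
    ax.set_title(f"y = x^{p}")

# Only the outer panels keep tick labels, so label those.
for ax in axs[-1, :]:
    ax.set_xlabel('x')
for ax in axs[:, 0]:
    ax.set_ylabel('y')

fig.tight_layout()
```

Here `sharey' is set to `row' so that each row keeps its own y scale, while `sharex' set to True gives every panel the same x range. Calling plt.show() afterwards displays the result.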

We might also want to remove the whitespace between the two sub-figures since they share the same x-axis, and this can be done with plt.subplots_adjust(hspace=0). We will also want to add x and y labels to the figure. If the y-axes correspond to different variables we can do this for each sub-figure as described above, using axs[i].set_ylabel(). Alternatively, if the y-axes of both sub-figures represent the same variable, we want a single global label for the pair that sits nicely in the middle of our figure. We can do this by encasing the figure inside a new subplot that has no frame, axis ticks or tick labels but does have axis labels. We use the fig.add_subplot() function to do this, as shown in the code below,
nrows, ncols = 2, 1
fig, axs = plt.subplots(nrows, ncols, sharex='col')
axs[0].plot(x, y)
axs[1].plot(x1, y1)
fig.add_subplot(111, frame_on=False)
plt.tick_params(labelcolor="none", bottom=False, left=False)
plt.xlabel('x')
plt.ylabel('y')
plt.subplots_adjust(hspace=0)
plt.savefig('Fig2.png')
plt.show()
plt.close()
where the arguments inside add_subplot() and tick_params() ensure that we don't see anything other than the assigned labels for the global figure. We can see the results in the figure below.
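As an aside, a shorter route to the same kind of global labels is available in more recent versions of matplotlib through the figure-level methods fig.supxlabel() and fig.supylabel(). The sketch below assumes matplotlib 3.4 or newer and reuses the power law data from above; it is an alternative to the frameless subplot approach, not a description of how the figure in this article was made.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(3, 5, 100)
x1 = np.linspace(2, 5, 100)

fig, axs = plt.subplots(2, 1, sharex='col')
axs[0].plot(x, x**2)
axs[1].plot(x1, x1**3)

# Figure-level labels; these methods need matplotlib 3.4 or newer.
fig.supxlabel('x')
fig.supylabel('y')
fig.subplots_adjust(hspace=0)
```

The labels are positioned relative to the whole figure rather than to a hidden axis, so their placement may differ slightly from the frameless subplot version.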


We can create more complex figures like the one below with a bit more work. The code for this figure can be found on my GitHub, linked at the end of the article. I have used the plt.subplots() function to define a $4\times4$ array of sub-figures and then removed the axes of the sub-figures in the top right of the graph using the axs[j, i].axis('off') function for i > j. I have also rotated the x-ticks and adjusted the positioning of the x-label so that it is not obscured.
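The layout described above can be sketched roughly as follows. This is not the code behind the figure: the random `data` array, the histograms on the diagonal and the scatter plots below it are all placeholder choices made for illustration, and only the use of axs[j, i].axis('off') for i > j and the rotated x-ticks follow the description.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=(200, 4))   # placeholder data: 4 made-up variables
n = data.shape[1]

fig, axs = plt.subplots(n, n, figsize=(8, 8))
for j in range(n):          # row index
    for i in range(n):      # column index
        ax = axs[j, i]
        if i > j:
            ax.axis('off')  # hide the panels above the diagonal
            continue
        if i == j:
            ax.hist(data[:, i], bins=20)
        else:
            ax.scatter(data[:, i], data[:, j], s=4)
        ax.tick_params(axis='x', labelrotation=45)  # rotate the x-ticks
        if j == n - 1:
            ax.set_xlabel(f"var {i}")

fig.tight_layout()
```

Calling plt.show() afterwards displays a lower-triangular grid, with the six panels above the diagonal left blank.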

One of the disadvantages of the plt.subplots() environment is that every sub-figure is given the same size by default, and resizing particular sub-figures is not as straightforward. For this we can use the GridSpec environment, which I will hopefully cover the basics of in a future post. However, for most purposes the subplots environment is sufficient and can improve the quality of your data visualisation significantly. I hope that this article has been informative. The code to produce the graphs shown can be found here.

Thanks for reading!

Further reading:

