DOS plots - plotting subplots for pdos and change axis label, figure size

Hi everyone,

I have a couple of questions related to the pymatgen.electronic_structure.dos and pymatgen.electronic_structure.plotter modules.

I am trying to generate the DOS and PDOS for an interface, and I would like to plot the PDOS as subplots and change the size of the DOS plots, like in the picture below:

[Screenshot, 2021-08-04: the desired PDOS subplot layout]

I tried
dos_dict = {"Al": Al_dos.get_site_spd_dos(71),
"Fe_sur": Al_dos.get_site_spd_dos(70),
"Fe_sur": Al_dos.get_site_spd_dos(69),
"Fe_sur": Al_dos.get_site_spd_dos(68),
"Fe_sur": Al_dos.get_site_spd_dos(67),
"Fe_sur": Al_dos.get_site_spd_dos(66),
"C_sur": Al_dos.get_site_spd_dos(32),
"C_sur": Al_dos.get_site_spd_dos(34),
"C_sur": Al_dos.get_site_spd_dos(35),
}
for orbital, dos in dos_dict.items():
dosplot = DosPlotter(sigma=0.1)
dosplot.add_dos_dict(dos)
plt = dosplot.get_plot()
plt.subplot(1, 1, 1)
plt.xlabel("Energy (eV)")
plt.ylabel("Total DOS (au/eV)")
but I got an empty plot and three PDOS plots.

Does anyone know how to fix this issue?
Thank you for your time!


It’s hard to tell without seeing the loop indentation, but I think DosPlotter.get_plot() creates a new Matplotlib figure every time it is called (in pymatgen versions from around this time it returns the matplotlib.pyplot module itself, which is why plt.xlabel() works on the result). So each pass through your for loop would overwrite the previous value of plt with a new figure. (Or, if the loop body ends before get_plot(), you’d be replacing dosplot with a new DosPlotter on every iteration without ever using it.)
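Separately from the figure handling, the dict literal in the question repeats the keys "Fe_sur" and "C_sur". A Python dict literal silently keeps only the last value for a repeated key, which would leave three entries and could explain why exactly three PDOS plots appear. A quick check, using plain integers as stand-ins for the Al_dos.get_site_spd_dos(...) results:

```python
# Stand-in for the question's dict: integers replace Al_dos.get_site_spd_dos(i)
site_dos = {
    "Al": 71,
    "Fe_sur": 70,
    "Fe_sur": 69,  # overwrites the value stored for "Fe_sur" above
    "C_sur": 32,
    "C_sur": 35,   # overwrites the value stored for "C_sur" above
}
print(len(site_dos))  # 3
print(site_dos)       # {'Al': 71, 'Fe_sur': 69, 'C_sur': 35}

# One way to keep every site: give each entry a unique label
sites = {"Al": [71], "Fe_sur": [70, 69, 68, 67, 66], "C_sur": [32, 34, 35]}
unique = {f"{name}_{i}": i for name, idxs in sites.items() for i in idxs}
print(len(unique))  # 9
```

With real data, the values would be the per-site DOS dicts instead of integers. Also worth checking: as far as I know, CompleteDos.get_site_spd_dos() is documented as taking a site object (for example complete_dos.structure[71]) rather than a bare integer index.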

I’m not aware of an easy built-in way to do this with Pymatgen alone, without digging into the underlying Matplotlib objects. It might help to create a Matplotlib Figure first; then, in your for loop, generate each individual DOS figure with DosPlotter.get_plot() and add that plot’s Axes as a subplot of the starting Figure (see the matplotlib.pyplot.subplot page in the Matplotlib 3.4.3 documentation).
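Along the same lines, but skipping the step of moving Axes from one figure to another (which Matplotlib is not really designed for), one option is to create the figure and its subplots first and draw each PDOS straight onto its own Axes. A minimal sketch: plot_pdos_panels is a hypothetical helper, and the synthetic Gaussian curves stand in for real data. With a pymatgen Dos, the energies would presumably be dos.energies - dos.efermi and the densities something like dos.densities[Spin.up], or dos.get_smeared_densities(0.1)[Spin.up] for smearing.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_pdos_panels(pdos, figsize=(6, 8)):
    """Draw each entry of `pdos` on its own subplot of one figure.

    pdos: dict of {label: (energies, densities)} with equal-length 1D arrays.
    Returns the Figure and a 1D array of Axes (one per entry).
    """
    # squeeze=False keeps a 2D Axes array even when there is only one panel
    fig, axs = plt.subplots(len(pdos), 1, sharex=True, figsize=figsize, squeeze=False)
    for ax, (label, (energies, densities)) in zip(axs[:, 0], pdos.items()):
        ax.plot(energies, densities, label=label)
        ax.axvline(0.0, color="k", linestyle="--", linewidth=1)  # Fermi level at 0 eV
        ax.set_ylabel("PDOS")
        ax.legend(loc="upper right")
    axs[-1, 0].set_xlabel("Energy (eV)")  # x label only on the bottom panel
    return fig, axs[:, 0]


# Synthetic peaks standing in for the Al / Fe_sur / C_sur projections
e = np.linspace(-5, 5, 201)
data = {
    "Al": (e, np.exp(-((e + 1.0) ** 2))),
    "Fe_sur": (e, np.exp(-((e - 0.5) ** 2) / 0.2)),
    "C_sur": (e, 0.5 * np.exp(-((e + 3.0) ** 2))),
}
fig, axes = plot_pdos_panels(data, figsize=(6, 8))
```

Afterwards, fig.savefig(...) works as usual; figsize sets the overall size and set_xlabel/set_ylabel the axis labels.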

I am dealing with this exact same problem. The closest thing I’ve found is this gist, and I haven’t been able to get it to arrange plots in a grid yet. It might be faster to reimplement pymatgen.electronic_structure.plotter to accept an axes instance onto which to plot, and then do something like

import matplotlib.pyplot as plt

fig, axs = plt.subplots(4, 1, figsize=(8, 14))

for ax, (orbital, dos) in zip(axs.flatten(), dos_dict.items()):
    dosplot = YourBetterDosPlotter(sigma=0.1)
    dosplot.add_dos_dict(dos)
    dosplot.your_better_dos_plot_function(ax=ax)
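Accepting an axes instance is the usual Matplotlib convention: a plotting function takes an optional ax argument and only falls back to pyplot's current Axes when none is given. A minimal sketch of that convention; plot_dos is a hypothetical name, not a pymatgen function:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np


def plot_dos(energies, densities, label=None, ax=None, **line_kwargs):
    """Plot one DOS curve onto `ax`, or onto the current Axes if ax is None."""
    if ax is None:
        ax = plt.gca()  # fall back to pyplot's current Axes
    ax.plot(energies, densities, label=label, **line_kwargs)
    ax.set_xlabel("Energy (eV)")
    ax.set_ylabel("DOS")
    return ax  # return the Axes so the caller can keep customizing it


# Usage: one curve per panel of a 2 x 2 grid
e = np.linspace(-5, 5, 101)
fig, axs = plt.subplots(2, 2, figsize=(8, 6))
for i, ax in enumerate(axs.flat):
    plot_dos(e, np.exp(-((e - i) ** 2)), label=f"panel {i}", ax=ax)
```

Returning the Axes instead of the pyplot module means the caller decides the figure layout, and the function never creates a figure of its own.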

Okay, getting a little further. Here’s a function I adapted from the gist:

import matplotlib.pyplot as plt


def move_axes(ax, new_ax, fig, nrows, ncols, index):
    """
    Move an Axes object from a figure to a new pyplot managed Figure in
    the specified subplot.
    
    Adapted from https://gist.github.com/salotz/8b4542d7fe9ea3e2eacc1a2eef2532c5.
    """

    # get a reference to the old figure context so we can release it
    old_fig = ax.figure

    # remove the Axes from its original Figure context
    try:
        ax.remove()
    except ValueError:
        pass

    # set the pointer from the Axes to the new figure
    ax.figure = fig
    
    # Get reference to original axes in the new figure
    new_axes = new_ax
         

    # add the Axes to the new figure's registry of axes. fig.axes returns
    # a new list on every access, so appending to it has no effect;
    # fig.add_axes(ax) is the call that actually registers the Axes
    fig.add_axes(ax)

    # then to actually show the Axes in the new figure we have to make
    # a subplot with the positions etc for the Axes to go, so make a
    # subplot which will have a dummy Axes
    dummy_ax = fig.add_subplot(nrows, ncols, index)
    
    # Add some label text to alter position
    dummy_ax.set_xlabel("test")
    dummy_ax.set_ylabel("test")
    dummy_ax.set_title("test")   

    # then copy the relevant data from the dummy to the ax
    ax.set_position(dummy_ax.get_position())
    # ax.set_position(new_axes.get_position())

    # then remove the dummy and original axes
    dummy_ax.remove()    
    new_axes.remove()

    # close the figure the original axis was bound to
    plt.close(old_fig)
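The index passed to move_axes() goes straight to fig.add_subplot(nrows, ncols, index), which counts from 1 at the top-left panel and increases left to right, then row by row. These two hypothetical helpers convert between that index and a 0-based (row, col) cell, which can help when checking where each panel should land:

```python
def subplot_index_to_cell(index, ncols):
    """Convert a 1-based, row-major subplot index to a 0-based (row, col)."""
    if index < 1:
        raise ValueError("subplot indices start at 1")
    return divmod(index - 1, ncols)


def cell_to_subplot_index(row, col, ncols):
    """Inverse of subplot_index_to_cell."""
    return row * ncols + col + 1


# In a 5 x 5 grid
print(subplot_index_to_cell(1, 5))   # (0, 0): top-left
print(subplot_index_to_cell(7, 5))   # (1, 1): second row, second column
print(subplot_index_to_cell(25, 5))  # (4, 4): bottom-right
```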

And here is the attempt to plot a grid of these plots:

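import matplotlib.pyplot as plt
from pymatgen.electronic_structure import plotter as dos_plotter
from pymatgen.electronic_structure.core import OrbitalType

# Okay, try making a grid of DOS plots
# Not working...
# (shifted_dos_data is a dict of {label: CompleteDos} defined elsewhere)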
fig, axs = plt.subplots(5, 5, figsize=(20, 20))
k = 1
for (label, test_cdos), ax in zip(shifted_dos_data.items(), axs.flatten()):
    o_dos = test_cdos.get_element_spd_dos(el="O")[OrbitalType.p]
    plotter = dos_plotter.DosPlotter()
    plotter.add_dos(label=label, dos=o_dos)
    the_plot = plotter.get_plot()
    test_fig, test_ax = the_plot.gcf(), the_plot.gca()

    test_fig.set_size_inches(8, 6)
    test_ax.set_title(label)

    move_axes(test_ax, ax, fig, nrows=5, ncols=5, index=k)
    k += 1
    
fig.set_size_inches(20, 20)
fig.savefig("grid-test.pdf", bbox_inches="tight", pad_inches=0.1)

It clearly plots them in a grid, but I can’t seem to position them so that they fill the figure. Result in a Jupyter notebook:

And printed to PDF:

So does inserting the new Axes somehow stop the figure from being resized? Or is there some global canvas within the figure that has to be enlarged before the grid of subplots will spread out?
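One way around the resizing question is not to move Axes at all: create the final figure with plt.subplots() at its final size and draw each panel directly on its Axes (for example with a DosPlotter-like class that accepts an ax). That also removes the need for the manual k counter, and any panels beyond the number of datasets can be hidden. A sketch in which synthetic curves stand in for the O p-projected DOS data; the dataset labels and count are placeholders:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

nrows, ncols = 5, 5
e = np.linspace(-5, 5, 101)
# Placeholder data: 18 synthetic curves, fewer than the 25 panels available
datasets = {f"site {i}": np.exp(-((e - (i % 5) + 2) ** 2)) for i in range(18)}

# Create the final figure once, at its final size
fig, axs = plt.subplots(nrows, ncols, figsize=(20, 20))
flat_axes = axs.flatten()

# Draw each dataset straight onto its own panel; zip stops at the shorter input
for ax, (label, dens) in zip(flat_axes, datasets.items()):
    ax.plot(e, dens)
    ax.set_title(label)
    ax.set_xlabel("Energy (eV)")

# Hide the panels that received no data
for ax in flat_axes[len(datasets):]:
    ax.set_visible(False)

fig.tight_layout()  # let Matplotlib spread the panels over the figure
```

Saving then works as in the original loop, e.g. fig.savefig("grid-test.pdf", bbox_inches="tight").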

So I bit the bullet and edited the class so that get_plot() draws onto an Axes you pass in. Put this in dos_plotter_axes.py:

from collections import OrderedDict
from monty.json import jsanitize
from pymatgen.electronic_structure.core import OrbitalType, Spin
import numpy as np

class DosPlotterAxes:
    """
    Modified Pymatgen class for plotting DOSs. Note that the interface is
    extremely flexible given that there are many different ways in which people
    want to view DOS. The typical usage is:

        # Initializes plotter with some optional args. Defaults are usually
        # fine,
        plotter = DosPlotterAxes()

        # Adds a DOS with a label.
        plotter.add_dos("Total DOS", dos)

        # Alternatively, you can add a dict of DOSs. This is the typical
        # form returned by CompleteDos.get_spd/element/others_dos().
        plotter.add_dos_dict({"dos1": dos1, "dos2": dos2})
        plotter.add_dos_dict(complete_dos.get_spd_dos())
    """

    def __init__(self, zero_at_efermi=True, stack=False, sigma=None):
        """
        Args:
            zero_at_efermi: Whether to shift all Dos to have zero energy at the
                fermi energy. Defaults to True.
            stack: Whether to plot the DOS as a stacked area graph
            sigma: A float specifying a standard deviation for Gaussian smearing
                the DOS for nicer looking plots. Defaults to None for no
                smearing.
        """
        self.zero_at_efermi = zero_at_efermi
        self.stack = stack
        self.sigma = sigma
        self._doses = OrderedDict()

    def add_dos(self, label, dos):
        """
        Adds a dos for plotting.

        Args:
            label:
                label for the DOS. Must be unique.
            dos:
                Dos object
        """
        energies = dos.energies - dos.efermi if self.zero_at_efermi else dos.energies
        densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities
        efermi = dos.efermi
        self._doses[label] = {
            "energies": energies,
            "densities": densities,
            "efermi": efermi,
        }

    def add_dos_dict(self, dos_dict, key_sort_func=None):
        """
        Add a dictionary of doses, with an optional sorting function for the
        keys.

        Args:
            dos_dict: dict of {label: Dos}
            key_sort_func: function used to sort the dos_dict keys.
        """
        if key_sort_func:
            keys = sorted(dos_dict.keys(), key=key_sort_func)
        else:
            keys = dos_dict.keys()
        for label in keys:
            self.add_dos(label, dos_dict[label])

    def get_dos_dict(self):
        """
        Returns the added doses as a json-serializable dict. Note that if you
        have specified smearing for the DOS plot, the densities returned will
        be the smeared densities, not the original densities.

        Returns:
            dict: Dict of dos data. Generally of the form
            {label: {'energies':..., 'densities': {'up':...}, 'efermi':efermi}}
        """
        return jsanitize(self._doses)

    def get_plot(self, ax, xlim=None, ylim=None):
        """
        Get a matplotlib plot showing the DOS.

        Args:
            ax : the matplotlib axes onto which to plot.
            xlim: Specifies the x-axis limits. Set to None for automatic
                determination.
            ylim: Specifies the y-axis limits.
        """

        ncolors = max(3, len(self._doses))
        ncolors = min(9, ncolors)

        import palettable

        # pylint: disable=E1101
        colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

        y = None
        alldensities = []
        allenergies = []        

        # Note that this complicated processing of energies is to allow for
        # stacked plots in matplotlib.
        for key, dos in self._doses.items():
            energies = dos["energies"]
            densities = dos["densities"]
            if not y:
                y = {
                    Spin.up: np.zeros(energies.shape),
                    Spin.down: np.zeros(energies.shape),
                }
            newdens = {}
            for spin in [Spin.up, Spin.down]:
                if spin in densities:
                    if self.stack:
                        y[spin] += densities[spin]
                        newdens[spin] = y[spin].copy()
                    else:
                        newdens[spin] = densities[spin]
            allenergies.append(energies)
            alldensities.append(newdens)

        keys = list(self._doses.keys())
        keys.reverse()
        alldensities.reverse()
        allenergies.reverse()
        allpts = []
        for i, key in enumerate(keys):
            x = []
            y = []
            for spin in [Spin.up, Spin.down]:
                if spin in alldensities[i]:
                    densities = list(int(spin) * alldensities[i][spin])
                    energies = list(allenergies[i])
                    if spin == Spin.down:
                        energies.reverse()
                        densities.reverse()
                    x.extend(energies)
                    y.extend(densities)
            allpts.extend(list(zip(x, y)))
            if self.stack:
                ax.fill(x, y, color=colors[i % ncolors], label=str(key))
            else:
                ax.plot(x, y, color=colors[i % ncolors], label=str(key), linewidth=3)
            if not self.zero_at_efermi:
                # read the current limits without overwriting the ylim argument
                cur_ylim = ax.get_ylim()
                ax.plot(
                    [self._doses[key]["efermi"], self._doses[key]["efermi"]],
                    cur_ylim,
                    color=colors[i % ncolors],
                    linestyle="--",
                    linewidth=2,
                )

        if xlim:
            ax.set_xlim(xlim)
        if ylim:
            ax.set_ylim(ylim)
        else:
            xlim = ax.get_xlim()
            relevanty = [p[1] for p in allpts if xlim[0] < p[0] < xlim[1]]
            ax.set_ylim((min(relevanty), max(relevanty)))

        if self.zero_at_efermi:
            ylim = ax.get_ylim()
            ax.plot([0, 0], ylim, "k--", linewidth=2)

        ax.set_xlabel("Energies (eV)")
        ax.set_ylabel("Density of states")

        ax.axhline(y=0, color="k", linestyle="--", linewidth=2)
        ax.legend()                        
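With stack=True, the first loop in get_plot() adds each DOS onto a running total y[spin] and stores a copy, so the n-th curve is the sum of the first n densities, i.e. a cumulative sum. The second loop reverses the order, so the full total is filled first and each smaller partial sum is painted over it, which gives the stacked-area look. The accumulation step, isolated with toy arrays that stand in for the Spin.up densities:

```python
import numpy as np

# Three toy densities on a shared energy grid (placeholders for real DOS arrays)
densities = [
    np.array([0.0, 1.0, 2.0, 1.0]),
    np.array([0.5, 0.5, 0.5, 0.5]),
    np.array([1.0, 0.0, 1.0, 0.0]),
]

# Running-total loop, mirroring the stack=True branch of get_plot()
running = np.zeros_like(densities[0])
stacked_loop = []
for d in densities:
    running += d
    stacked_loop.append(running.copy())  # copy, or every entry would alias `running`

# Equivalent NumPy formulation: cumulative sum down the stacked rows
stacked_cumsum = np.cumsum(np.vstack(densities), axis=0)

print(stacked_cumsum[-1])  # [1.5 1.5 3.5 1.5]
```

The .copy() mirrors y[spin].copy() in the class; without it, every stored curve would end up equal to the final total.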
        

Then, with this file importable from your working directory or elsewhere on your PYTHONPATH, you can make a plot like this:

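import matplotlib.pyplot as plt
from dos_plotter_axes import DosPlotterAxes

# Make our DOS plotter
# (o_dos is a Dos object, e.g. the O p-projected DOS from the grid snippet)
plotter = DosPlotterAxes(zero_at_efermi=False, stack=False)
plotter.add_dos(label="test", dos=o_dos)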

# Make the plot
fig, ax = plt.subplots(2, 1, sharex=True, figsize=(8, 6))
fig.subplots_adjust(hspace=0, left=0.18)

# Plot on each separate axis
plotter.get_plot(ax=ax[0])
plotter.get_plot(ax=ax[1])

# Clean things up
ax[0].set_xlabel(None)
ax[0].set_ylabel(None)
fig.text(x=0.1, y=0.5, s=ax[1].get_ylabel(), va="center", ha="center", rotation=90)
ax[1].set_ylabel(None)

Looks like this:

[Image: two stacked DOS panels sharing a single x-axis]

That way, you still get the features of the DosPlotter class, like stacked DOS plots, but now you can have them as subplots, too.
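For the label cleanup step, Matplotlib 3.4 and later also provide Figure.supylabel() for one y label shared by the whole figure, which could replace the manual fig.text(...) call, and Axes.label_outer() clears the x label and tick labels on all but the bottom panel. A sketch in which placeholder curves take the place of plotter.get_plot(ax=...) so it runs without pymatgen:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

e = np.linspace(-5, 5, 101)
fig, ax = plt.subplots(2, 1, sharex=True, figsize=(8, 6))
fig.subplots_adjust(hspace=0, left=0.18)

for a, shift in zip(ax, (0.0, 1.0)):
    a.plot(e, np.exp(-((e - shift) ** 2)))  # placeholder for plotter.get_plot(ax=a)
    a.set_xlabel("Energies (eV)")
    a.label_outer()  # keep the x label only on the bottom panel

# One y label for the whole figure (Matplotlib >= 3.4)
ylabel = fig.supylabel("Density of states")
```

Since DosPlotterAxes.get_plot() sets a y label on each panel, those per-panel labels would still need clearing (for example with ax[i].set_ylabel(None), as above) when a figure-level label is used.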
