How can I use pymatgen to plot a band structure projected onto selected atoms (rather than onto elements)?

The following code plots the band structure projected onto elements, but I want to project it onto selected atoms instead, and I don't know how to modify it.

import matplotlib.pyplot as plt
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSDOSPlotter,BSPlotter,BSPlotterProjected,DosPlotter

# Parse vasprun.xml only once. With parse_projected_eigen=True the same Vasprun
# object provides both the complete DOS and the projected band structure, so the
# (often very large) file does not have to be read twice. Both original names
# are kept as aliases of that single object.
dos_vasprun = bs_vasprun = Vasprun("./scf/nonscf/vasprun.xml", parse_projected_eigen=True)
dos_data = dos_vasprun.complete_dos
bs_data = bs_vasprun.get_band_structure(line_mode=True)

# Element-projected band structure and DOS side by side, saved as a PNG.
plt_1 = BSDOSPlotter(bs_projection="elements", dos_projection="elements", font='DejaVu Sans', fig_size=(16, 12))
plt_1.get_plot(bs=bs_data, dos=dos_data)
plt.savefig('plt_1.png')

Hi, can you clarify what you’d like to plot? Do you want to plot the projected band structure only for a select set of elements, rather than all the elements in the structure?

Hello, I want to select a few specific atoms and plot the band structure projected onto them. For example, I would select the two atoms Fe1 and Ti2 in the structure and show their projections on the band structure.

Perhaps you can try the code below:

import os
import numpy as np
import matplotlib.pyplot as plt
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotterProjected
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

def analyze_band_calculation(
    vasprun_path: str = "vasprun.xml",
) -> Tuple[Optional[Dict[str, List[int]]], Optional[object]]:
    """
    Analyze a band-structure calculation and return per-element site info plus the band object.

    Reads ``vasprun_path`` with projected eigenvalues enabled. Prints a table of
    all sites (0-based index, element, fractional and Cartesian coordinates),
    per-element site statistics, and basic band-structure information.

    Args:
        vasprun_path: Path to the VASP ``vasprun.xml`` (default: ``vasprun.xml``
            in the current working directory).

    Returns:
        Tuple ``(element_sites, bands)``:
            - element_sites: mapping element symbol -> list of 0-based site indices
            - bands: the object returned by ``Vasprun.get_band_structure()``
        Both are None if the file is missing or reading it fails.
    """
    print("Step 1: 检查文件和读取设置")

    # Structure and bands are both read from vasprun.xml. A POSCAR is never
    # used, so only the vasprun file is required.
    if not os.path.exists(vasprun_path):
        print(f"× 错误: 找不到{vasprun_path}文件")
        return None, None

    try:
        print("正在读取vasprun.xml...")
        vasprun = Vasprun(vasprun_path, parse_projected_eigen=True)
        bands = vasprun.get_band_structure()

        # Structure information
        structure = vasprun.final_structure
        print("✓ 成功读取结构数据")

        # Group 0-based site indices by element and print a site table
        element_sites = defaultdict(list)
        print("结构分析:")
        print("-" * 60)
        print("索引  元素  分数坐标                        笛卡尔坐标")
        print("-" * 60)
        for i, site in enumerate(structure.sites):
            element = site.specie.symbol
            element_sites[element].append(i)
            print(f"{i:3d}   {element:<3}   {site.frac_coords} {site.coords}")

        # Per-element statistics
        print("元素统计:")
        print("-" * 30)
        for element, sites in element_sites.items():
            print(f"{element}: {len(sites)}个原子, 位于索引 {sites}")

        # Band-structure settings
        print("Step 2: 检查能带计算设置")
        if bands is not None:
            print("✓ 成功读取能带数据")
            first_channel = next(iter(bands.bands.values()))
            print(f"  - 能带数量: {len(first_channel)}")
            print(f"  - k点数量: {len(bands.kpoints)}")

            if hasattr(bands, 'projections') and len(bands.projections) > 0:
                print("✓ 找到投影数据")
                # Suggest a label -> 0-based site index mapping (first atom of each element)
                print("可以使用以下格式创建投影:")
                print("projection = {")
                for element, sites in element_sites.items():
                    print(f"    '{element}1': {sites[0]},  # 第一个{element}原子")
                print("}")
            else:
                print("× 警告: 没有找到投影数据")

        return element_sites, bands

    except Exception as e:
        # Top-level script boundary: report the error with its traceback, then signal failure
        print(f"× 错误: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None

def plot_projected_bands(bands, element_sites, selected_sites=None):
    """
    Create orbital-projected band-structure scatter plots for individual atoms.

    For every selected site, the s-, p- and d-orbital weights of that site are
    drawn as coloured scatter points over the bands, in an E - E_f window of
    -4..4 eV. All spin channels present in ``bands`` are merged into the same
    plot. Each site produces one combined 1x3 figure and three single-orbital
    figures.

    Args:
        bands: pymatgen band-structure object with ``projections``,
            ``distance``, ``kpoints``, ``bands`` and ``efermi`` attributes, e.g.
            from ``Vasprun(..., parse_projected_eigen=True).get_band_structure()``.
        element_sites: Mapping element symbol -> list of 0-based site indices.
            Used only when ``selected_sites`` is None. In that case the first
            site of every element is plotted and labelled with the element symbol.
        selected_sites: Optional mapping label -> 0-based site index, e.g.
            ``{'Fe1': 0, 'Ti2': 5}``, to plot arbitrary individual atoms. The
            label appears in plot titles and output file names.

    Output files (current directory):
        ``projection_<label>_combined.png`` and ``projection_<label>_<orbital>.png``.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    if not (bands is not None
            and hasattr(bands, 'projections')
            and len(bands.projections) > 0
            and len(bands.kpoints) >= 10):
        print("数据不完整或k点数量不足,无法创建投影图")
        return

    def orbital_groups(n_orbitals):
        """Map 's'/'p'/'d' to the orbital-axis columns summed for that channel."""
        if n_orbitals >= 9:
            # lm-decomposed columns in pymatgen Orbital order:
            # s, py, pz, px, dxy, dyz, dz2, dxz, dx2-y2 (f columns, if any, follow)
            return {'s': [0], 'p': [1, 2, 3], 'd': [4, 5, 6, 7, 8]}
        # NOTE(review): fewer columns are assumed to be one column per l (s, p, d).
        # A missing channel gets an empty selection and therefore zero weight.
        return {orb: ([i] if i < n_orbitals else []) for i, orb in enumerate('spd')}

    def draw(ax, distances, energies, weights, orbital):
        """Scatter one orbital channel onto ``ax`` using the shared axis styling."""
        scatter = ax.scatter(
            distances,
            energies,
            c=weights,
            cmap='jet',
            s=20,
            alpha=0.7,
            vmin=0,
            vmax=weights.max() if weights.max() > 0 else 1,
        )
        ax.set_xlabel('Wave Vector')
        ax.set_ylabel('$E - E_f$ (eV)')
        ax.set_ylim(-4, 4)
        plt.colorbar(scatter, ax=ax, label=f'{orbital}-orbital weight')
        ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)

    print("\n开始创建投影图...")

    if selected_sites is None:
        targets = {element: sites[0] for element, sites in element_sites.items()}
    else:
        targets = dict(selected_sites)

    for label, site in targets.items():
        print(f"\n处理 {label} 的投影...")

        try:
            distances, energies = [], []
            weights = {'s': [], 'p': [], 'd': []}

            # Merge all spin channels into one data set
            for spin in bands.bands:
                # pymatgen stores projections as [band, kpoint, orbital, ion],
                # so the atom is selected on the LAST axis and orbitals on the third.
                proj = np.asarray(bands.projections[spin])
                n_ions = proj.shape[3]
                if not 0 <= site < n_ions:
                    raise IndexError(f"site index {site} out of range (0..{n_ions - 1})")
                site_proj = proj[:, :, :, site]  # [band, kpoint, orbital]

                band_energies = np.asarray(bands.bands[spin]) - bands.efermi  # [band, kpoint]
                n_bands, n_kpts = band_energies.shape

                # Flatten band-major: same point order as a band -> kpoint double loop
                k_dist = np.asarray(bands.distance, dtype=float)[:n_kpts]
                distances.append(np.tile(k_dist, n_bands))
                energies.append(band_energies.ravel())
                for orbital, idx in orbital_groups(proj.shape[2]).items():
                    cols = np.asarray(idx, dtype=int)
                    weights[orbital].append(site_proj[:, :, cols].sum(axis=2).ravel())

            distances = np.concatenate(distances)
            energies = np.concatenate(energies)
            weights = {orb: np.concatenate(w) for orb, w in weights.items()}

            orbital_data = [(weights['s'], 's'), (weights['p'], 'p'), (weights['d'], 'd')]

            # Combined figure: s, p, d side by side
            _, axes = plt.subplots(1, 3, figsize=(16, 4))
            plt.subplots_adjust(wspace=0.4)  # extra spacing between subplots
            for ax, (orb_weights, orbital) in zip(axes, orbital_data):
                draw(ax, distances, energies, orb_weights, orbital)
                ax.set_title(f'{label} {orbital}-orbital')
            plt.suptitle(f'{label} Orbital Projections')
            plt.tight_layout()

            filename = f"projection_{label}_combined.png"
            plt.savefig(filename, dpi=300, bbox_inches='tight')
            plt.close()
            print(f"✓ 成功创建综合投影图 {filename}")

            # One figure per orbital channel
            for orb_weights, orbital in orbital_data:
                _, ax = plt.subplots(figsize=(8, 6))
                draw(ax, distances, energies, orb_weights, orbital)
                plt.title(f'{label} {orbital}-orbital projection')
                plt.tight_layout()

                filename = f"projection_{label}_{orbital}.png"
                plt.savefig(filename, dpi=300, bbox_inches='tight')
                plt.close()
                print(f"✓ 成功创建 {filename}")

        except Exception as e:
            print(f"× 创建 {label} 投影图失败: {str(e)}")
            import traceback
            traceback.print_exc()
            plt.close()  # discard a half-built figure so it does not leak into the next site

def analyze_and_plot():
    """
    Run the end-to-end workflow for one calculation.

    Reads the calculation, prints a short band summary (metal or not, band gap
    and its type for non-metals), creates the projection plots, and finally
    prints the Fermi level.
    """
    element_sites, bands = analyze_band_calculation()

    if bands is None:
        print("无法获取能带数据")
        return

    # Basic band-structure summary
    metallic = bands.is_metal()
    print("\n能带信息:")
    print(f"是否是金属: {metallic}")
    if not metallic:
        gap = bands.get_band_gap()
        print(f"带隙: {gap['energy']:.3f} eV")
        print(f"带隙类型: {'直接' if gap['direct'] else '间接'}")

    # Projected band plots
    plot_projected_bands(bands, element_sites)

    # Fermi level
    print(f"\n费米能级: {bands.efermi:.3f} eV")

if __name__ == "__main__":
    # Script entry point: analyse ./vasprun.xml and write the projection plots.
    analyze_and_plot()

Thank you for your reply! I tried modifying the code, but it fails for spin-polarized band structures.

In that case, what about the following version, which handles the spin-up and spin-down channels separately?

import os
import numpy as np
import matplotlib.pyplot as plt
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotterProjected
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

def analyze_band_calculation(
    vasprun_path: str = "vasprun.xml",
) -> Tuple[Optional[Dict[str, List[int]]], Optional[object]]:
    """
    Analyze band structure calculation results, handling spin polarization.

    Reads ``vasprun_path`` with projected eigenvalues enabled. Prints a site
    table (0-based index, element, fractional and Cartesian coordinates),
    per-element statistics, whether the run is spin-polarized, the band count
    per spin channel and the k-point count.

    Args:
        vasprun_path: Path to the VASP ``vasprun.xml`` (default: ``vasprun.xml``
            in the current working directory).

    Returns:
        Tuple ``(element_sites, bands)``:
            - element_sites: mapping element symbol -> list of 0-based site indices
            - bands: the object returned by ``Vasprun.get_band_structure()``
        Both are None if the file is missing or reading it fails.
    """
    print("Step 1: Checking files and reading settings")

    # Structure and bands both come from vasprun.xml. POSCAR is never read,
    # so it is not required.
    if not os.path.exists(vasprun_path):
        print(f"× Error: Cannot find {vasprun_path}")
        return None, None

    try:
        print("Reading vasprun.xml...")
        vasprun = Vasprun(vasprun_path, parse_projected_eigen=True)
        bands = vasprun.get_band_structure()

        structure = vasprun.final_structure
        print("✓ Successfully read structure data")

        # Group 0-based site indices by element and print a site table
        element_sites = defaultdict(list)
        print("Structure analysis:")
        print("-" * 60)
        print("Index  Element  Fractional Coordinates          Cartesian Coordinates")
        print("-" * 60)
        for i, site in enumerate(structure.sites):
            element = site.specie.symbol
            element_sites[element].append(i)
            print(f"{i:3d}   {element:<3}   {site.frac_coords} {site.coords}")

        print("\nElement statistics:")
        for element, sites in element_sites.items():
            print(f"{element}: {len(sites)} atoms at indices {sites}")

        if bands is not None:
            # More than one spin channel in bands.bands means a spin-polarized run.
            # Checked only after the None guard, so it can never dereference None.
            is_spin_polarized = len(bands.bands) > 1
            print(f"\nSpin polarization: {'Yes' if is_spin_polarized else 'No'}")

            print("✓ Successfully read band data")
            for spin, channel in bands.bands.items():
                print(f"  - Number of bands (Spin {spin}): {len(channel)}")
            print(f"  - Number of k-points: {len(bands.kpoints)}")

            if hasattr(bands, 'projections') and bands.projections:
                print("✓ Found projection data")
            else:
                print("× Warning: No projection data found")

        return element_sites, bands

    except Exception as e:
        # Top-level script boundary: report the error with its traceback, then signal failure
        print(f"× Error: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None

def plot_projected_bands(bands, element_sites, selected_sites=None):
    """
    Create orbital-projected band-structure plots for individual atoms, with spin handling.

    For every selected site, the s-, p- and d-orbital weights of that site are
    drawn as coloured scatter points over the bands, in an E - E_f window of
    -4..4 eV.

    Without spin polarization, one figure is written per orbital. With spin
    polarization, each orbital figure has spin-up and spin-down panels, and one
    combined s/p/d figure is written per spin channel.

    Args:
        bands: pymatgen band-structure object with ``projections``,
            ``distance``, ``kpoints``, ``bands`` and ``efermi`` attributes, e.g.
            from ``Vasprun(..., parse_projected_eigen=True).get_band_structure()``.
        element_sites: Mapping element symbol -> list of 0-based site indices.
            Used only when ``selected_sites`` is None. In that case the first
            site of every element is plotted and labelled with the element symbol.
        selected_sites: Optional mapping label -> 0-based site index, e.g.
            ``{'Fe1': 0, 'Ti2': 5}``, to plot arbitrary individual atoms. The
            label appears in plot titles and output file names.

    Output files (current directory):
        ``projection_<label>_<orbital>_{spin|nospin}.png`` and, when
        spin-polarized, ``projection_<label>_combined_spin_{up|down}.png``.
    """
    if not (bands is not None
            and hasattr(bands, 'projections')
            and bands.projections
            and len(bands.kpoints) >= 10):
        print("Incomplete data or insufficient k-points for plotting")
        return

    def orbital_groups(n_orbitals):
        """Map 's'/'p'/'d' to the orbital-axis columns summed for that channel."""
        if n_orbitals >= 9:
            # lm-decomposed columns in pymatgen Orbital order:
            # s, py, pz, px, dxy, dyz, dz2, dxz, dx2-y2 (f columns, if any, follow)
            return {'s': [0], 'p': [1, 2, 3], 'd': [4, 5, 6, 7, 8]}
        # NOTE(review): fewer columns are assumed to be one column per l (s, p, d).
        # A missing channel gets an empty selection and therefore zero weight.
        return {orb: ([i] if i < n_orbitals else []) for i, orb in enumerate('spd')}

    def collect(spin, site):
        """Return flattened distances, energies and s/p/d weights of ``site`` for one spin."""
        # pymatgen stores projections as [band, kpoint, orbital, ion], so the
        # atom is selected on the LAST axis and orbitals on the third.
        proj = np.asarray(bands.projections[spin])
        n_ions = proj.shape[3]
        if not 0 <= site < n_ions:
            raise IndexError(f"site index {site} out of range (0..{n_ions - 1})")
        site_proj = proj[:, :, :, site]  # [band, kpoint, orbital]

        band_energies = np.asarray(bands.bands[spin]) - bands.efermi  # [band, kpoint]
        n_bands, n_kpts = band_energies.shape
        k_dist = np.asarray(bands.distance, dtype=float)[:n_kpts]

        # Flatten band-major: same point order as a band -> kpoint double loop
        data = {
            'distances': np.tile(k_dist, n_bands),
            'energies': band_energies.ravel(),
        }
        for orbital, idx in orbital_groups(proj.shape[2]).items():
            cols = np.asarray(idx, dtype=int)
            data[f'{orbital}_weights'] = site_proj[:, :, cols].sum(axis=2).ravel()
        return data

    def draw(ax, data, weight_key, orbital):
        """Scatter one orbital channel of ``data`` onto ``ax`` with shared styling."""
        scatter = ax.scatter(
            data['distances'],
            data['energies'],
            c=data[weight_key],
            cmap='jet',
            s=20,
            alpha=0.7
        )
        ax.set_xlabel('Wave Vector')
        ax.set_ylabel('$E - E_f$ (eV)')
        ax.set_ylim(-4, 4)
        ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
        plt.colorbar(scatter, ax=ax, label=f'{orbital}-orbital weight')

    def save(fig, filename, message):
        """Lay out, write and close ``fig``, then report the file name."""
        fig.tight_layout()
        fig.savefig(filename, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"{message} {filename}")

    print("\nStarting projection plots...")
    spins = list(bands.bands.keys())
    is_spin_polarized = len(spins) > 1
    orbital_types = [('s', 's_weights'), ('p', 'p_weights'), ('d', 'd_weights')]

    if selected_sites is None:
        targets = {element: sites[0] for element, sites in element_sites.items()}
    else:
        targets = dict(selected_sites)

    for label, site in targets.items():
        print(f"\nProcessing projections for {label}...")

        try:
            spin_data = {spin: collect(spin, site) for spin in spins}

            # One figure per orbital type
            for orbital, weight_key in orbital_types:
                if is_spin_polarized:
                    fig, (ax_up, ax_down) = plt.subplots(1, 2, figsize=(16, 6))
                    fig.suptitle(f'{label} {orbital}-orbital projection (Spin-polarized)')
                    for ax, spin, title in ((ax_up, spins[0], 'Spin Up'),
                                            (ax_down, spins[1], 'Spin Down')):
                        draw(ax, spin_data[spin], weight_key, orbital)
                        ax.set_title(title)
                else:
                    fig, ax = plt.subplots(figsize=(8, 6))
                    draw(ax, spin_data[spins[0]], weight_key, orbital)
                    ax.set_title(f'{label} {orbital}-orbital projection')

                suffix = 'spin' if is_spin_polarized else 'nospin'
                save(fig, f"projection_{label}_{orbital}_{suffix}.png", "✓ Created")

            # Combined s/p/d figure for each spin channel
            if is_spin_polarized:
                for spin, spin_label in zip(spins, ['up', 'down']):
                    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
                    fig.subplots_adjust(wspace=0.4)
                    for ax, (orbital, weight_key) in zip(axes, orbital_types):
                        draw(ax, spin_data[spin], weight_key, orbital)
                        ax.set_title(f'{orbital}-orbital')
                    fig.suptitle(f'{label} Orbital Projections (Spin {spin_label})')
                    save(fig, f"projection_{label}_combined_spin_{spin_label}.png",
                         "✓ Created combined projection")

        except Exception as e:
            print(f"× Failed to create projections for {label}: {str(e)}")
            import traceback
            traceback.print_exc()
            plt.close()  # discard a half-built figure so it does not leak into the next site

def analyze_and_plot():
    """
    Run the end-to-end workflow for one calculation.

    Reads the calculation, prints whether it is metallic (plus gap size and type
    otherwise), produces the projection plots, and reports the Fermi level.
    """
    element_sites, bands = analyze_band_calculation()

    if bands is None:
        print("Unable to obtain band structure data")
        return

    metallic = bands.is_metal()
    print("\nBand structure information:")
    print(f"Is metal: {metallic}")
    if not metallic:
        gap = bands.get_band_gap()
        print(f"Band gap: {gap['energy']:.3f} eV")
        print(f"Gap type: {'direct' if gap['direct'] else 'indirect'}")

    plot_projected_bands(bands, element_sites)
    print(f"\nFermi level: {bands.efermi:.3f} eV")

if __name__ == "__main__":
    # Script entry point: analyse ./vasprun.xml and write the spin-aware projection plots.
    analyze_and_plot()