Then, what about the following code?
import os
import numpy as np
import matplotlib.pyplot as plt
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotterProjected
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
def analyze_band_calculation() -> Tuple[Optional[Dict[str, List[int]]], Optional[object]]:
    """
    Read ``vasprun.xml`` from the current directory and summarize it.

    Prints a per-site table of the final structure, the number of atoms
    per element, whether more than one spin channel is present, and the
    number of bands / k-points.

    Returns:
        A ``(element_sites, bands)`` tuple where

        - ``element_sites`` maps each element symbol to the list of site
          indices (positions in ``structure.sites``) holding that element;
        - ``bands`` is the object returned by
          ``Vasprun.get_band_structure()``.

        ``(None, None)`` is returned when a required file is missing or
        when reading/parsing raises; the error is printed, not re-raised.
    """
    print("Step 1: Checking files and reading settings")
    # NOTE: POSCAR is only checked for presence here; the structure itself
    # is taken from vasprun.xml below.
    required_files = ['POSCAR', 'vasprun.xml']
    for file in required_files:
        if not os.path.exists(file):
            print(f"× Error: Cannot find {file}")
            return None, None

    try:
        print("Reading vasprun.xml...")
        vasprun = Vasprun("vasprun.xml", parse_projected_eigen=True)
        bands = vasprun.get_band_structure()
        structure = vasprun.final_structure
        print("✓ Successfully read structure data")

        element_sites = defaultdict(list)
        print("Structure analysis:")
        print("-" * 60)
        print("Index Element Fractional Coordinates Cartesian Coordinates")
        print("-" * 60)
        for i, site in enumerate(structure.sites):
            element = site.specie.symbol
            element_sites[element].append(i)
            frac_coords = site.frac_coords
            cart_coords = site.coords
            print(f"{i:3d} {element:<3} {frac_coords} {cart_coords}")

        print("\nElement statistics:")
        for element, sites in element_sites.items():
            print(f"{element}: {len(sites)} atoms at indices {sites}")

        # Everything that dereferences ``bands`` lives behind the None
        # guard, so the guard actually protects the attribute accesses.
        if bands is not None:
            # More than one key in ``bands.bands`` means two spin channels.
            is_spin_polarized = len(bands.bands) > 1
            print(f"\nSpin polarization: {'Yes' if is_spin_polarized else 'No'}")
            print("✓ Successfully read band data")
            for spin in bands.bands:
                print(f" - Number of bands (Spin {spin}): {len(bands.bands[spin])}")
            print(f" - Number of k-points: {len(bands.kpoints)}")
            if hasattr(bands, 'projections') and bands.projections:
                print("✓ Found projection data")
            else:
                print("× Warning: No projection data found")

        return element_sites, bands
    except Exception as e:
        # Top-level boundary for the read step: report and signal failure
        # through the (None, None) return value instead of propagating.
        print(f"× Error: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None
def _collect_orbital_weights(bands, spin, site_index):
    """Flatten energies, k-path distances and s/p/d weights of one site for one spin."""
    # pymatgen stores projections[spin] with axes
    # [band, kpoint, orbital, ion]; the ion axis is LAST.
    site_proj = np.asarray(bands.projections[spin])[:, :, :, site_index]
    energies = np.asarray(bands.bands[spin]) - bands.efermi
    n_bands = energies.shape[0]

    def _orbital_sum(start, stop):
        # Slicing (not fancy indexing) so a short orbital axis yields 0
        # weight instead of raising IndexError.
        return site_proj[:, :, start:stop].sum(axis=2).ravel()

    # ravel() is band-major / k-point-minor, so the distance array is
    # repeated once per band to stay aligned with the energies.
    return {
        'distances': np.tile(np.asarray(bands.distance), n_bands),
        'energies': energies.ravel(),
        's_weights': _orbital_sum(0, 1),  # s
        'p_weights': _orbital_sum(1, 4),  # three p orbitals
        'd_weights': _orbital_sum(4, 9),  # five d orbitals
    }


def _draw_projection_panel(ax, data, weights, orbital):
    """Draw one weight-colored scatter panel with shared axis styling."""
    scatter = ax.scatter(
        data['distances'],
        data['energies'],
        c=weights,
        cmap='jet',
        s=20,
        alpha=0.7,
    )
    ax.set_xlabel('Wave Vector')
    ax.set_ylabel('$E - E_f$ (eV)')
    ax.set_ylim(-4, 4)
    plt.colorbar(scatter, ax=ax, label=f'{orbital}-orbital weight')
    # Dashed line marks the Fermi level (energies are already shifted).
    ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)


def _plot_single_orbital(element, orbital, weight_key, spins, spin_data):
    """Save the figure for one orbital type (one panel per spin channel); return its filename."""
    is_spin_polarized = len(spins) > 1
    if is_spin_polarized:
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        panels = list(zip(axes, spins[:2], ('Spin Up', 'Spin Down')))
    else:
        fig, ax = plt.subplots(figsize=(8, 6))
        panels = [(ax, spins[0], f'{element} {orbital}-orbital projection')]
    try:
        if is_spin_polarized:
            fig.suptitle(f'{element} {orbital}-orbital projection (Spin-polarized)')
        for ax, spin, title in panels:
            data = spin_data[spin]
            _draw_projection_panel(ax, data, data[weight_key], orbital)
            # Title is set on the axes explicitly rather than through
            # pyplot's "current axes", which a colorbar may have touched.
            ax.set_title(title)
        fig.tight_layout()
        filename = f"projection_{element}_{orbital}_{'spin' if is_spin_polarized else 'nospin'}.png"
        fig.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even when drawing/saving fails.
        plt.close(fig)
    return filename


def _plot_combined_orbitals(element, spin_label, data):
    """Save the s/p/d side-by-side figure for one spin channel; return its filename."""
    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
    try:
        fig.subplots_adjust(wspace=0.4)
        orbital_data = [
            (data['s_weights'], 's'),
            (data['p_weights'], 'p'),
            (data['d_weights'], 'd'),
        ]
        for ax, (weights, orbital) in zip(axes, orbital_data):
            _draw_projection_panel(ax, data, weights, orbital)
            ax.set_title(f'{orbital}-orbital')
        fig.suptitle(f'{element} Orbital Projections (Spin {spin_label})')
        fig.tight_layout()
        filename = f"projection_{element}_combined_spin_{spin_label}.png"
        fig.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    return filename


def plot_projected_bands(bands, element_sites):
    """
    Create orbital-projected band structure plots, one set per element.

    For every element in ``element_sites`` the s, p and d weights of that
    element's FIRST listed site are scattered over the band structure
    (energy relative to ``bands.efermi`` versus ``bands.distance``) and
    saved as ``projection_<element>_<orbital>_<spin|nospin>.png``. When two
    spin channels are present, an additional combined s/p/d figure is
    saved per channel as ``projection_<element>_combined_spin_<up|down>.png``.

    Args:
        bands: Band structure object exposing ``bands``, ``projections``,
            ``kpoints``, ``efermi`` and ``distance``.
        element_sites: Mapping of element symbol to list of site indices.

    Returns:
        None. Nothing is plotted (a message is printed instead) when
        projections are missing, fewer than 10 k-points are present, or
        the band structure carries no ``distance`` attribute. A failure
        for one element is printed with its traceback and does not stop
        the remaining elements.
    """
    has_projections = bool(bands) and bool(getattr(bands, 'projections', None))
    if not (has_projections and len(bands.kpoints) >= 10):
        print("Incomplete data or insufficient k-points for plotting")
        return
    if not hasattr(bands, 'distance'):
        # The x axis is taken from ``bands.distance``; without it every
        # element would fail with the same AttributeError.
        print("Band structure has no k-path distances; cannot plot projections")
        return

    print("\nStarting projection plots...")
    spins = list(bands.bands.keys())
    is_spin_polarized = len(spins) > 1
    orbital_types = [('s', 's_weights'), ('p', 'p_weights'), ('d', 'd_weights')]

    for element, sites in element_sites.items():
        print(f"\nProcessing projections for {element}...")
        try:
            # Only the first site of each element is projected.
            first_site = sites[0]
            spin_data = {
                spin: _collect_orbital_weights(bands, spin, first_site)
                for spin in spins
            }

            # One figure per orbital type.
            for orbital, weight_key in orbital_types:
                filename = _plot_single_orbital(
                    element, orbital, weight_key, spins, spin_data
                )
                print(f"✓ Created {filename}")

            # One combined s/p/d figure per spin channel.
            if is_spin_polarized:
                for spin, spin_label in zip(spins, ['up', 'down']):
                    filename = _plot_combined_orbitals(
                        element, spin_label, spin_data[spin]
                    )
                    print(f"✓ Created combined projection {filename}")
        except Exception as e:
            # Best-effort per element: report and move on to the next one.
            print(f"× Failed to create projections for {element}: {str(e)}")
            import traceback
            traceback.print_exc()
def analyze_and_plot():
    """
    Run the full pipeline: read the calculation, report, then plot.

    Calls ``analyze_band_calculation()``; when it yields a band structure,
    prints whether the system is a metal (and, if not, the band gap energy
    and whether it is direct), calls ``plot_projected_bands()``, and
    finally prints the Fermi level. Prints a failure message instead when
    no band structure was obtained.
    """
    element_sites, bands = analyze_band_calculation()

    # Guard clause: nothing to report or plot without band data.
    if bands is None:
        print("Unable to obtain band structure data")
        return

    metallic = bands.is_metal()
    print("\nBand structure information:")
    print(f"Is metal: {metallic}")

    # A band gap is only reported for non-metals.
    if not metallic:
        gap = bands.get_band_gap()
        gap_kind = 'direct' if gap['direct'] else 'indirect'
        print(f"Band gap: {gap['energy']:.3f} eV")
        print(f"Gap type: {gap_kind}")

    plot_projected_bands(bands, element_sites)
    print(f"\nFermi level: {bands.efermi:.3f} eV")
# Script entry point: run the whole workflow only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    analyze_and_plot()