Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 53 additions & 21 deletions src/symwannier/amn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,50 @@
import os
import itertools
import gzip
import logging

from symwannier.nnkp import Nnkp
from symwannier.sym import Sym
from symwannier.io_utils import open_text_or_gz

class Amn():
"""
amn file
"""Reader and processor for Amn files
Amn(k) = <psi_mk|g_n>
"""
def __init__(self, file_amn, nnkp, sym = None):

def __init__(self, file_amn, nnkp, sym = None, log=None):
"""Load Amn data from file and prepare symmetry handling.

Parameters
----------
file_amn : str
Path to amn file (optionally gzipped).
nnkp : Nnkp
Parsed nnkp object providing k-point and neighbor info.
sym : Sym, optional
Symmetry data. If provided, Amn is symmetrized/expanded accordingly.
log : logging.Logger, optional
Logger to use; if not provided a module logger is created.
Copy link

Copilot AI Jan 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring contains a typo: "Logger to use; if not provided a module logger is created" is missing a comma. It should read "Logger to use; if not provided, a module logger is created" for proper grammar.

Suggested change
Logger to use; if not provided a module logger is created.
Logger to use; if not provided, a module logger is created.

Copilot uses AI. Check for mistakes.
"""
self.log = log or logging.getLogger(__name__)
if not self.log.handlers:
logging.basicConfig(level=logging.INFO, format="%(message)s")

self.nnkp = nnkp
self.sym = sym

if os.path.exists(file_amn):
with open(file_amn) as fp:
self._read_amn(fp)
elif os.path.exists(file_amn + ".gz"):
with gzip.open(file_amn + ".gz", 'rt') as fp:
self._read_amn(fp)
else:
raise Exception("failed to read amn file: " + file_amn)
fp, used_path = open_text_or_gz(file_amn, desc="amn file")
self.log.debug(f"Reading amn from {used_path}")
with fp:
self._read_amn(fp)

def _read_amn(self, fp):
print(" Reading amn file")
"""Read amn contents from an open file-like object."""
self.log.info("Reading amn file")
lines = fp.readlines()
num_bands, nk, num_wann = [ int(x) for x in lines[1].split() ]
dat = np.genfromtxt(lines[2:]).reshape(nk, num_wann, num_bands, 5)
flat = np.fromstring("".join(lines[2:]), sep=" ")
Copy link

Copilot AI Jan 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using np.fromstring is deprecated as of NumPy 1.14 and will be removed in future versions. The recommended replacement is np.frombuffer for bytes. Consider using a more modern approach that doesn't rely on deprecated functions.

Suggested change
flat = np.fromstring("".join(lines[2:]), sep=" ")
flat_str = " ".join(lines[2:])
flat = np.fromiter((float(val) for val in flat_str.split()), dtype=float)

Copilot uses AI. Check for mistakes.
Comment on lines 33 to +50
Copy link

Copilot AI Jan 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Joining all remaining lines into a single string for parsing is inefficient for large amn files. This creates a potentially very large temporary string object before parsing. Consider using np.loadtxt or np.genfromtxt directly on the file pointer (after skipping headers) to avoid the memory overhead of concatenating all lines into one string.

Copilot uses AI. Check for mistakes.
dat = flat.reshape(nk, num_wann, num_bands, 5)
amn = np.transpose(dat[:,:,:,3] + 1j*dat[:,:,:,4], axes=(0,2,1))

self.num_bands = num_bands
Expand All @@ -50,9 +66,10 @@ def _read_amn(self, fp):
self.amn = self.symmetrize_expand(amn)

def symmetrize_Gk(self, amn, thr=0):
"""Symmetrize Amn (or U) on irreducible k-points using site symmetry G_k."""
amn_sym = self._symmetrize_Gk_internal(amn)
diff = np.sum(np.abs(amn_sym - amn))/self.sym.nks
print(" symmetrize Gk diff1 = {:15.5e}".format(diff))
self.log.info("symmetrize Gk diff1 = %.5e", diff)
Copy link

Copilot AI Jan 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logger formatting string uses old-style % formatting while most other logging calls in the file use f-strings or format method. For consistency, consider using the logging module's built-in parameter substitution (e.g., self.log.info("symmetrize Gk diff1 = %.5e", diff)) consistently throughout the file.

Suggested change
self.log.info("symmetrize Gk diff1 = %.5e", diff)
self.log.info(f"symmetrize Gk diff1 = {diff:.5e}")

Copilot uses AI. Check for mistakes.

#if thr > 0:
# for i in range(20):
Expand Down Expand Up @@ -124,6 +141,7 @@ def symmetrize_expand(self, Umat_irk):
return Umat

def Umat_symmetrize(self, Umat, Umat_opt = None):
"""Symmetrize Umat (and optionally combine with disentanglement rotations)."""
Umat_irk = Umat[self.sym.iks2ik[:]]
if Umat_opt is not None:
Umat_opt_irk = Umat_opt[self.sym.iks2ik[:]]
Expand All @@ -134,10 +152,11 @@ def Umat_symmetrize(self, Umat, Umat_opt = None):
else:
Umat_irk = self.symmetrize_Gk(Umat_irk)
Umat_new = self.symmetrize_expand(Umat_irk)
print(" symmetrize expand diff = {:15.5e}".format(np.sum(np.abs(Umat_new - Umat))/self.nk))
self.log.info("symmetrize expand diff = %.5e", np.sum(np.abs(Umat_new - Umat))/self.nk)
return Umat_new

def write_amn(self, file_amn):
"""Write Amn data back to disk in wannier90 format."""
with open(file_amn, "w") as fp:
fp.write("Amn created by amn.py\n")
fp.write("{} {} {}\n".format(self.num_bands, self.nk, self.num_wann))
Expand All @@ -146,12 +165,24 @@ def write_amn(self, file_amn):
fp.write("{0} {1} {2} {3.real:18.12f} {3.imag:18.12f}\n".format(n+1, m+1, ik+1, self.amn[ik,n,m]))

def Umat(self, index_win=None, check=True, Umat_opt=None):
"""
calculate initial Umat

amn[:num_bands, :num_wann, :nk]
m[:num_bands, :num_wann] => SVD => u[:num_bands, :num_bands], v[:num_wann, :num_wann]
Umat[:num_bands, :num_wann] = u[:num_bands, :num_wann] . v[:num_wann, :num_wann]
"""Construct initial U matrices from Amn via SVD (and optional disentanglement).
amn[:num_bands, :num_wann, :nk]
m[:num_bands, :num_wann] => SVD => u[:num_bands, :num_bands], v[:num_wann, :num_wann]
Umat[:num_bands, :num_wann] = u[:num_bands, :num_wann] . v[:num_wann, :num_wann]
Comment on lines +168 to +171
Copy link

Copilot AI Jan 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring incorrectly formats the amn shape description. The indentation is broken - it should either be part of the main docstring description or properly formatted as a Notes section. The current format with arbitrary indentation doesn't follow NumPy docstring style guidelines.

Copilot uses AI. Check for mistakes.

Parameters
----------
index_win : ndarray[bool], optional
Disentanglement window mask per k and band. Outside window elements are zeroed.
check : bool, optional
If True and no symmetry, verify unitarity per k-point.
Umat_opt : ndarray, optional
Disentanglement rotations; if provided, Amn is rotated before SVD.

Returns
-------
ndarray
Umat of shape (nk, num_bands or num_wann, num_wann).
"""
if Umat_opt is None: # Umat simply from Amn
Umat = np.zeros_like(self.amn)
Expand Down Expand Up @@ -182,6 +213,7 @@ def Umat(self, index_win=None, check=True, Umat_opt=None):
return Umat

def projection_sym_mat(self):
"""Return rotation matrices and lattice shifts for projecting Wannier centers."""
pos = self.nnkp.nw2r # pos[num_wann, 3]: position of each Wannier
Rmat = self.sym.rotmat # Rotation matrix

Expand Down
30 changes: 25 additions & 5 deletions src/symwannier/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,28 @@
import scipy.linalg
import os
import itertools
import logging

from symwannier.nnkp import Nnkp
from symwannier.sym import Sym

class Eig():
"""
eig file: e_n(k)
"""
def __init__(self, file_eig, sym = None):
"""Reader for eig files (eigenvalues e_n(k))."""
def __init__(self, file_eig, sym = None, log=None):
"""Load eigenvalue data from file.

Parameters
----------
file_eig : str
Path to eig file.
sym : Sym, optional
Symmetry data for IBZ expansion.
log : logging.Logger, optional
Logger instance.
"""
self.log = log or logging.getLogger(__name__)
if not self.log.handlers:
logging.basicConfig(level=logging.INFO, format="%(message)s")
self.sym = sym

if os.path.exists(file_eig):
Expand All @@ -22,7 +35,7 @@ def __init__(self, file_eig, sym = None):
dat = dat.reshape(nk, num_bands, 3)
eig = dat[:,:,2]
else:
raise Exception("failed to read eig file: " + file_eig)
raise FileNotFoundError(f"eig file not found: {file_eig}")

if self.sym is None:
self.nk = nk
Expand All @@ -38,6 +51,13 @@ def __init__(self, file_eig, sym = None):
self.eig[ik,:] = eig[iks,:]

def write_eig(self, file_eig):
"""Write eigenvalues to file in wannier90 format.

Parameters
----------
file_eig : str
Output file path.
"""
with open(file_eig, "w") as fp:
for ik, n in itertools.product( range(self.nk), range(self.num_bands) ):
fp.write("{:5d}{:5d}{:18.12f}\n".format(n+1, ik+1, self.eig[ik,n]))
33 changes: 23 additions & 10 deletions src/symwannier/expand_wannier_inputs.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#!/usr/bin/env python
"""
Expand Wannier input files using symmetry information:
prefix.immn, prefix.iamn, prefix.ieig, prefix.isym => prefix.mmn, prefix.amn, prefix.eig
Expand Wannier input files using symmetry information:
prefix.immn, prefix.iamn, prefix.ieig, prefix.isym => prefix.mmn, prefix.amn, prefix.eig
"""

import sys
import os
import argparse
import logging

from symwannier.nnkp import Nnkp
from symwannier.sym import Sym
Expand All @@ -15,10 +16,22 @@
from symwannier.eig import Eig

def main(argv=None, for_cli=False):
"""Expand symmetry-reduced Wannier input files to full k-point sets.

Parameters
----------
argv : list, optional
Command-line arguments.
for_cli : bool, optional
Whether called from CLI.
"""
if argv is None:
argv = sys.argv[1:]
progname = "symmwanier expand" if for_cli else "python expand_wannier_inputs.py"

logging.basicConfig(level=logging.INFO, format="%(message)s")
log = logging.getLogger(__name__)

parser = argparse.ArgumentParser(
prog=progname,
description="Expand Wannier input files using symmetry information."
Expand All @@ -31,22 +44,22 @@ def main(argv=None, for_cli=False):

prefix = parser.parse_args(argv).prefix

nnkp = Nnkp(file_nnkp=prefix+".nnkp")
sym = Sym(file_sym=prefix+".isym", nnkp=nnkp)
nnkp = Nnkp(file_nnkp=prefix+".nnkp", log=log)
sym = Sym(file_sym=prefix+".isym", nnkp=nnkp, log=log)

# Eig
print(" {0:s}.ieig => {0:s}.eig".format(prefix))
eig = Eig(file_eig=prefix+".ieig", sym=sym)
log.info(f"{prefix}.ieig => {prefix}.eig")
eig = Eig(file_eig=prefix+".ieig", sym=sym, log=log)
eig.write_eig(prefix+".eig")

# Amn
print(" {0:s}.iamn => {0:s}.amn".format(prefix))
amn = Amn(file_amn=prefix+".iamn", sym=sym, nnkp=nnkp)
log.info(f"{prefix}.iamn => {prefix}.amn")
amn = Amn(file_amn=prefix+".iamn", sym=sym, nnkp=nnkp, log=log)
amn.write_amn(prefix+".amn")

# Mmn
print(" {0:s}.immn => {0:s}.mmn".format(prefix))
mmn = Mmn(file_mmn=prefix+".immn", nnkp=nnkp, sym=sym)
log.info(f"{prefix}.immn => {prefix}.mmn")
mmn = Mmn(file_mmn=prefix+".immn", nnkp=nnkp, sym=sym, log=log)
mmn.write_mmn(prefix+".mmn")

if __name__ == '__main__':
Expand Down
35 changes: 35 additions & 0 deletions src/symwannier/io_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Shared I/O utilities for symwannier."""

import os
import gzip
from typing import Tuple, IO


def open_text_or_gz(path: str, desc: str = "file") -> Tuple[IO[str], str]:
"""Open a text file, falling back to gzip if ``.gz`` exists.

Parameters
----------
path : str
Base path (without .gz).
Copy link

Copilot AI Jan 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring's description of the path parameter states "Base path (without .gz)" but the function actually accepts paths with or without .gz extension and checks both. The description should clarify that the path can be either the base name or include .gz, or the implementation should enforce that .gz is not part of the input path.

Suggested change
Base path (without .gz).
Path to the file, with or without a ``.gz`` extension. The function
first tries the given path as-is, then falls back to ``path + ".gz"``.

Copilot uses AI. Check for mistakes.
desc : str, optional
Description used in error message.

Returns
-------
fp : IO[str]
Opened file-like object in text mode.
used_path : str
Actual path used (plain or .gz).
Comment on lines +20 to +23
Copy link

Copilot AI Jan 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation incorrectly describes the return type as IO[str], but the function uses type annotations from the typing module. The actual return type is a tuple Tuple[IO[str], str] as correctly shown in the return type annotation, but the docstring says "fp : IO[str]" and "used_path : str" in separate lines under Returns, which doesn't match standard NumPy docstring format for tuples. The Returns section should clarify this is a tuple or use the tuple notation.

Suggested change
fp : IO[str]
Opened file-like object in text mode.
used_path : str
Actual path used (plain or .gz).
fp_and_path : Tuple[IO[str], str]
A tuple ``(fp, used_path)`` where ``fp`` is the opened file-like
object in text mode and ``used_path`` is the actual path used
(plain or ``.gz``).

Copilot uses AI. Check for mistakes.

Raises
------
FileNotFoundError
If neither the plain file nor the gzipped file exists.
"""
if os.path.exists(path):
return open(path, "r"), path
gz_path = path + ".gz"
if os.path.exists(gz_path):
return gzip.open(gz_path, "rt"), gz_path
raise FileNotFoundError(f"{desc} not found: {path}(.gz)")
Loading