diff --git a/.gitignore b/.gitignore index 525619c..8a215d5 100644 --- a/.gitignore +++ b/.gitignore @@ -63,6 +63,11 @@ cmake-build-debug/ __pycache__/ *.pyc +# Python package metadata +*.egg-info/ +dist/ +*.egg + # Build Output build/ diff --git a/README.md b/README.md index 8c1fa7e..282ea17 100644 --- a/README.md +++ b/README.md @@ -88,14 +88,30 @@ cmake --build build ### Python Implementation -The Python implementation requires no installation. Simply ensure Python 3.10+ and NumPy are available: +Requires Python 3.10+ and NumPy. + +**Install with pip (recommended):** ```bash -# Check Python version -python3 --version # Should be 3.10 or later +cd python +pip install -e . # runtime dependencies (numpy) are installed automatically +pip install -e ".[dev]" # also install development dependencies (pytest, pytest-cov) +``` + +After installation, the `stdface` command becomes available: + +```bash +stdface stan.in +stdface stan.in --solver mVMC +stdface -v +``` + +**Run without installing:** -# Install NumPy if needed -pip install numpy +Use the wrapper script at the project root: + +```bash +./stdface stan.in ``` ## Quick Start @@ -120,16 +136,21 @@ nelec = 4 ./hphi_dry.out stan.in ``` -**Python Implementation:** +**Python Implementation (after pip install):** ```bash -PYTHONPATH=python python3 python/__main__.py stan.in +stdface stan.in ``` -To select a solver (Python): +**Python Implementation (without installation):** ```bash -PYTHONPATH=python python3 python/__main__.py stan.in --solver mVMC -PYTHONPATH=python python3 python/__main__.py stan.in --solver UHF -PYTHONPATH=python python3 python/__main__.py stan.in --solver HWAVE +./stdface stan.in +``` + +To select a solver: +```bash +stdface stan.in --solver mVMC +stdface stan.in --solver UHF +stdface stan.in --solver HWAVE ``` 3. Input files for the target solver are generated in the current directory. @@ -140,9 +161,10 @@ Both implementations produce identical output files. The `python/` directory contains a fully-featured Python port of StdFace that produces byte-identical output to the C implementation. The Python codebase has been refactored into idiomatic Python with: -- **Modular architecture**: Organized into `lattice/` and `writer/` subpackages -- **Comprehensive testing**: 1,252 unit tests and 83 integration tests -- **Python idioms**: Enums, dict dispatch, context managers, and helper functions +- **Plugin architecture**: Solvers and lattices are self-registering plugins -- new ones can be added without modifying dispatch logic +- **Modular architecture**: Organized into `lattice/`, `solvers/`, and `writer/` subpackages +- **Comprehensive testing**: 1,268 unit tests and 83 integration tests +- **Python idioms**: Enums, ABC, context managers, type hints, and helper functions - **Full feature parity**: Supports all lattices, models, and solvers ### Python Project Structure @@ -150,25 +172,35 @@ The `python/` directory contains a fully-featured Python port of StdFace that pr ``` python/ __main__.py # CLI entry point - stdface_main.py # Main logic - stdface_vals.py # Data structures - stdface_model_util.py # Shared utilities - keyword_parser.py # Keyword parsing - param_check.py # Parameter validation - lattice/ # Lattice implementations - chain_lattice.py - square_lattice.py - honeycomb_lattice.py - kagome.py - wannier90.py - ... - writer/ # Solver-specific writers - common_writer.py - hphi_writer.py - mvmc_writer.py - ... + stdface/ + plugin.py # SolverPlugin ABC + registry + core/ + stdface_main.py # Main logic + stdface_vals.py # Data structures + keyword_parser.py # Keyword parsing + param_check.py # Parameter validation + lattice/ # Lattice plugins + __init__.py # LatticePlugin ABC + registry + chain_lattice.py # ChainPlugin + square_lattice.py # SquarePlugin + kagome.py # KagomePlugin + wannier90.py # Wannier90Plugin + ... + solvers/ # Solver plugins + hphi/_plugin.py # HPhiPlugin + mvmc/_plugin.py # MVMCPlugin + uhf/_plugin.py # UHFPlugin + hwave/_plugin.py # HWavePlugin + writer/ # Shared output writers + common_writer.py + interaction_writer.py + ... ``` +### Adding New Solvers or Lattices + +See [docs/tutorial_plugin.md](docs/tutorial_plugin.md) for a step-by-step guide with examples. + ### Running Python Tests **Unit Tests:** diff --git a/docs/tutorial_plugin.md b/docs/tutorial_plugin.md new file mode 100644 index 0000000..8172f4c --- /dev/null +++ b/docs/tutorial_plugin.md @@ -0,0 +1,499 @@ +# Plugin Tutorial: Adding New Solvers and Lattices + +StdFace uses a plugin architecture where solvers and lattices are self-registering modules. This tutorial explains how to add a new solver or lattice by creating a new plugin module and wiring it into the existing discovery mechanism. + +## Table of Contents + +- [Overview](#overview) +- [Part 1: Adding a New Lattice](#part-1-adding-a-new-lattice) +- [Part 2: Adding a New Solver](#part-2-adding-a-new-solver) +- [Appendix: Existing Plugins Reference](#appendix-existing-plugins-reference) + +--- + +## Overview + +### How the Plugin System Works + +``` +User input (stan.in) + | + v +stdface_main() ---> get_lattice("chain") ---> ChainPlugin.setup(StdI) + | | + | v + | (builds geometry & interactions) + | + +---> get_plugin("HPhi") ---> HPhiPlugin.write(StdI) + | + v + (writes .def files) +``` + +1. The input file specifies `lattice = ...` and `model = ...`. +2. `stdface_main` looks up the lattice plugin by name and calls `setup(StdI)`. +3. After lattice construction, the solver plugin's `write(StdI)` generates output files. + +Both registries use **auto-registration**: when a module is imported, it creates a plugin instance and registers it. No central dispatch table needs to be edited. + +--- + +## Part 1: Adding a New Lattice + +### Example: Adding a "dice" (T3) lattice + +The dice lattice is a 2D lattice with 3 sites per unit cell. We will create a plugin for it. + +### Step 1: Create the lattice module + +Create `python/stdface/lattice/dice.py`: + +```python +""" +Standard mode for the dice (T3) lattice. + +License +------- +HPhi-mVMC-StdFace - Common input generator +Copyright (C) 2015 The University of Tokyo +(GPL v3) +""" +from __future__ import annotations + +import math +import numpy as np + +from ..core.stdface_vals import StdIntList, ModelType +from ..core.param_check import ( + print_val_d, print_val_i, + not_used_j, not_used_d, not_used_i, +) +from .input_params import input_spin_nn, input_spin, input_hopp, input_coulomb_v +from .interaction_builder import ( + compute_max_interactions, malloc_interactions, + add_neighbor_interaction, add_local_terms, +) +from .site_util import init_site, set_label, set_local_spin_flags, lattice_gp + + +def dice(StdI: StdIntList) -> None: + """Set up the Hamiltonian for the dice (T3) lattice. + + Parameters + ---------- + StdI : StdIntList + Structure containing model parameters (modified in-place). + """ + with lattice_gp(StdI) as fp: + + # --- (1) Geometry --- + StdI.NsiteUC = 3 # 3 sites per unit cell + + print(" @ Lattice Size & Shape\n") + + StdI.a = print_val_d("a", StdI.a, 1.0) + StdI.length[0] = print_val_d("Wlength", StdI.length[0], StdI.a) + StdI.length[1] = print_val_d("Llength", StdI.length[1], StdI.a) + StdI.direct[0, 0] = print_val_d("Wx", StdI.direct[0, 0], StdI.length[0]) + StdI.direct[0, 1] = print_val_d("Wy", StdI.direct[0, 1], 0.0) + StdI.direct[1, 0] = print_val_d("Lx", StdI.direct[1, 0], + StdI.length[1] * 0.5) + StdI.direct[1, 1] = print_val_d("Ly", StdI.direct[1, 1], + StdI.length[1] * math.sqrt(3.0) / 2.0) + + StdI.phase[0] = print_val_d("phase0", StdI.phase[0], 0.0) + StdI.phase[1] = print_val_d("phase1", StdI.phase[1], 0.0) + + init_site(StdI, fp, 2) + + # tau positions for 3 sites in the unit cell + StdI.tau[0, :] = [0.0, 0.0, 0.0] + StdI.tau[1, :] = [1.0 / 3.0, 1.0 / 3.0, 0.0] + StdI.tau[2, :] = [2.0 / 3.0, 2.0 / 3.0, 0.0] + + # --- (2) Hamiltonian parameters --- + print("\n @ Hamiltonian \n") + + # ... (validate and set parameters, same pattern as kagome.py) + # ... (omitted for brevity) + + # --- (3) Local spin flags --- + print("\n @ Numerical conditions\n") + set_local_spin_flags(StdI, StdI.W * StdI.L) + + # --- (4) Allocate interactions --- + ntransMax, nintrMax = compute_max_interactions(StdI, nbond=6) + malloc_interactions(StdI, ntransMax, nintrMax) + + # --- (5) Build interactions --- + for iW in range(StdI.W): + for iL in range(StdI.L): + isite = iW + iL * StdI.W + add_local_terms(StdI, isite, iL) + + # Define bonds here... + # add_neighbor_interaction(StdI, fp, iW, iL, ...) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + + +class DicePlugin(LatticePlugin): + """Plugin for the 2D dice (T3) lattice.""" + + @property + def name(self) -> str: + return "dice" + + @property + def aliases(self) -> list[str]: + return ["dice", "t3", "dicelattice"] + + @property + def ndim(self) -> int: + return 2 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the dice() function.""" + dice(StdI) + + # No boost() override needed -- default no-op is fine. + + +register_lattice(DicePlugin()) +``` + +**Key points:** +- The `LatticePlugin` subclass defines `name`, `aliases`, `ndim`, and `setup()`. +- `register_lattice()` is called at module level -- the plugin registers itself when the module is imported. +- If the lattice supports HPhi Boost mode, override `boost(self, StdI)`. + +### Step 2: Register the module for auto-discovery + +Edit `python/stdface/lattice/__init__.py` and add your module to `_discover_lattices()`: + +```python +def _discover_lattices() -> None: + """Import all lattice modules to trigger auto-registration.""" + from . import chain_lattice, square_lattice, ladder, triangular_lattice + from . import honeycomb_lattice, kagome, orthorhombic, fc_ortho, pyrochlore + from . import wannier90 + from . import dice # <-- Add this line +``` + +That's it. Now `lattice = dice` (or `lattice = t3`) will work in input files. + +### Step 3: Add tests + +Create `test/unit/test_dice.py`: + +```python +"""Unit tests for the dice lattice plugin.""" +from __future__ import annotations + +import pytest +from stdface.lattice import get_lattice, LatticePlugin + + +class TestDicePlugin: + """Tests for the dice lattice plugin.""" + + def test_registered(self): + plugin = get_lattice("dice") + assert isinstance(plugin, LatticePlugin) + + def test_name(self): + assert get_lattice("dice").name == "dice" + + def test_ndim(self): + assert get_lattice("dice").ndim == 2 + + def test_aliases(self): + p = get_lattice("dice") + assert get_lattice("t3") is p + assert get_lattice("dicelattice") is p + + def test_setup_callable(self): + assert callable(get_lattice("dice").setup) +``` + +### Step 4: Verify + +```bash +# Unit tests +python3 -m pytest test/unit/ -q + +# Integration test (if you have a reference stan.in for the dice lattice) +/path/to/StdFace/stdface stan.in +``` + +--- + +## Part 2: Adding a New Solver + +### Example: Adding a "MySolver" solver + +Suppose you want to add a new solver that writes its own output format. + +### Step 1: Create the solver package + +``` +python/stdface/solvers/mysolver/ + __init__.py + _plugin.py + writer.py # (optional) solver-specific output functions +``` + +### Step 2: Implement the plugin + +Create `python/stdface/solvers/mysolver/_plugin.py`: + +```python +"""MySolver plugin. + +Encapsulates MySolver-specific keyword parsing, field reset tables, +and Expert-mode file writing. +""" +from __future__ import annotations + +from ...plugin import SolverPlugin, register +from ...core.stdface_vals import StdIntList, NaN_i, NaN_d +from ...core.keyword_parser import ( + store_with_check_dup_i, store_with_check_dup_d, + store_with_check_dup_s, +) + + +class MySolverPlugin(SolverPlugin): + """Plugin for the MySolver backend.""" + + # --- Required abstract properties --- + + @property + def name(self) -> str: + """Canonical solver name (used in --solver CLI argument).""" + return "MySolver" + + @property + def keyword_table(self) -> dict[str, tuple]: + """Solver-specific keywords recognized in stan.in. + + Each entry maps a lowercased keyword to a tuple: + (store_function, field_name, [extra_args...]) + + For example, if MySolver accepts 'maxiter = 1000' in stan.in: + "maxiter": (store_with_check_dup_i, "MaxIter") + """ + return { + "maxiter": (store_with_check_dup_i, "MaxIter"), + "threshold": (store_with_check_dup_d, "Threshold"), + "outputfile": (store_with_check_dup_s, "OutputFile"), + } + + @property + def reset_scalars(self) -> list[tuple[str, object]]: + """Fields to reset to sentinel values at initialization. + + These are set via setattr(StdI, field_name, value) during + _reset_vals(). Use NaN_i for integers, NaN_d for floats. + """ + return [ + ("MaxIter", NaN_i), + ("Threshold", NaN_d), + ] + + @property + def reset_arrays(self) -> list[tuple[str, object]]: + """Array fields to reset (filled via arr[...] = value). + + Return an empty list if there are no array fields. + """ + return [] + + # --- Template method steps (override as needed) --- + + def write_solver_specific(self, StdI: StdIntList) -> None: + """Write MySolver-specific output files. + + This is called as part of the write() template method, + after locspn, trans, interactions, and modpara. + """ + _write_mysolver_config(StdI) + + def write_green(self, StdI: StdIntList) -> None: + """Override if MySolver needs different Green's function output. + + For example, if it only needs greenone.def: + """ + from ...writer.common_writer import print_1_green + print_1_green(StdI) + + # --- Optional lifecycle hooks --- + + def post_lattice(self, StdI: StdIntList) -> None: + """Called after lattice construction, before file writing. + + Use this for solver-specific validation or derived quantities. + """ + # Example: set default MaxIter if not specified + if StdI.MaxIter == NaN_i: + StdI.MaxIter = 500 + + +def _write_mysolver_config(StdI: StdIntList) -> None: + """Write mysolver_config.def.""" + with open("mysolver_config.def", "w") as fp: + fp.write(f"MaxIter = {StdI.MaxIter}\n") + fp.write(f"Threshold = {StdI.Threshold}\n") + if hasattr(StdI, 'OutputFile'): + fp.write(f"OutputFile = {StdI.OutputFile}\n") + + +# Auto-register on import +register(MySolverPlugin()) +``` + +### Step 3: Make it discoverable + +Create `python/stdface/solvers/mysolver/__init__.py`: + +```python +"""MySolver plugin package.""" +from . import _plugin # noqa: F401 — triggers auto-registration +``` + +Edit `python/stdface/solvers/__init__.py` to import the new package: + +```python +"""Built-in solver plugins for StdFace.""" +from __future__ import annotations + +from . import hphi, mvmc, uhf, hwave # noqa: F401 +from . import mysolver # noqa: F401 <-- Add this line +``` + +### Step 4: Add tests + +Create `test/unit/test_mysolver_plugin.py`: + +```python +"""Unit tests for the MySolver plugin.""" +from __future__ import annotations + +import pytest +from stdface.plugin import get_plugin, SolverPlugin + + +class TestMySolverPlugin: + """Tests for the MySolver plugin.""" + + def test_registered(self): + plugin = get_plugin("MySolver") + assert isinstance(plugin, SolverPlugin) + + def test_name(self): + assert get_plugin("MySolver").name == "MySolver" + + def test_keyword_table(self): + plugin = get_plugin("MySolver") + assert "maxiter" in plugin.keyword_table + assert "threshold" in plugin.keyword_table + + def test_reset_scalars(self): + plugin = get_plugin("MySolver") + names = [name for name, _ in plugin.reset_scalars] + assert "MaxIter" in names + assert "Threshold" in names +``` + +### Step 5: Use it + +```bash +./stdface stan.in --solver MySolver +``` + +--- + +## Plugin API Reference + +### LatticePlugin (lattice/__init__.py) + +| Member | Type | Required | Description | +|--------|------|----------|-------------| +| `name` | `str` (property) | Yes | Canonical name (e.g. `"chain"`) | +| `aliases` | `list[str]` (property) | Yes | All recognized aliases | +| `ndim` | `int` (property) | Yes | Spatial dimensions (1, 2, or 3) | +| `setup(StdI)` | method | Yes | Build lattice geometry and interactions | +| `boost(StdI)` | method | No | HPhi Boost mode (default: no-op) | + +Registration: call `register_lattice(MyPlugin())` at module level. + +### SolverPlugin (plugin.py) + +| Member | Type | Required | Description | +|--------|------|----------|-------------| +| `name` | `str` (property) | Yes | Canonical name (e.g. `"HPhi"`) | +| `keyword_table` | `dict` (property) | Yes | Input keyword dispatch table | +| `reset_scalars` | `list[tuple]` (property) | Yes | Scalar field reset table | +| `reset_arrays` | `list[tuple]` (property) | Yes | Array field reset table | +| `write(StdI)` | method | No | Template method (calls steps below) | +| `write_locspn(StdI)` | method | No | Write locspn.def | +| `write_trans(StdI)` | method | No | Write trans.def | +| `write_interactions(StdI)` | method | No | Write interaction files | +| `check_and_write_modpara(StdI)` | method | No | Write modpara.def | +| `write_solver_specific(StdI)` | method | No | Solver-specific files | +| `write_green(StdI)` | method | No | Write Green's function files | +| `write_namelist(StdI)` | method | No | Write namelist.def | +| `post_lattice(StdI)` | method | No | Hook after lattice construction | + +Registration: call `register(MyPlugin())` at module level. + +### Template Method Flow (SolverPlugin.write) + +``` +write() + |-- write_locspn() # locspn.def + |-- write_trans() # trans.def + |-- write_interactions() # coulombinter.def, etc. + |-- check_and_write_modpara() # modpara.def + |-- write_solver_specific() # solver-specific files + |-- check_output_mode() # validate output settings + |-- write_green() # greenone.def, greentwo.def + |-- write_namelist() # namelist.def +``` + +Override individual steps to customize output. Override `write()` entirely for a completely different output sequence. + +--- + +## Appendix: Existing Plugins Reference + +### Lattice Plugins + +| Plugin Class | Module | name | aliases | ndim | +|---|---|---|---|---| +| `ChainPlugin` | `chain_lattice.py` | `chain` | chain, chainlattice | 1 | +| `LadderPlugin` | `ladder.py` | `ladder` | ladder, ladderlattice | 1 | +| `SquarePlugin` | `square_lattice.py` | `tetragonal` | tetragonal, tetragonallattice, square, squarelattice | 2 | +| `TriangularPlugin` | `triangular_lattice.py` | `triangular` | triangular, triangularlattice | 2 | +| `HoneycombPlugin` | `honeycomb_lattice.py` | `honeycomb` | honeycomb, honeycomblattice | 2 | +| `KagomePlugin` | `kagome.py` | `kagome` | kagome, kagomelattice | 2 | +| `OrthorhombicPlugin` | `orthorhombic.py` | `orthorhombic` | orthorhombic, simpleorthorhombic, cubic, simplecubic | 3 | +| `FCOrthoPlugin` | `fc_ortho.py` | `fco` | face-centeredorthorhombic, fcorthorhombic, fco, face-centeredcubic, fccubic, fcc | 3 | +| `PyrochlorePlugin` | `pyrochlore.py` | `pyrochlore` | pyrochlore | 3 | +| `Wannier90Plugin` | `wannier90.py` | `wannier90` | wannier90 | 3 | + +Lattices with Boost support: chain, honeycomb, kagome, ladder. + +### Solver Plugins + +| Plugin Class | Module | name | +|---|---|---| +| `HPhiPlugin` | `solvers/hphi/_plugin.py` | `HPhi` | +| `MVMCPlugin` | `solvers/mvmc/_plugin.py` | `mVMC` | +| `UHFPlugin` | `solvers/uhf/_plugin.py` | `UHF` | +| `HWavePlugin` | `solvers/hwave/_plugin.py` | `HWAVE` | diff --git a/python/README.md b/python/README.md index 1b53cf1..a54a05f 100644 --- a/python/README.md +++ b/python/README.md @@ -9,45 +9,83 @@ for HPhi, mVMC, UHF, and H-wave. - Python 3.10 or later - NumPy +## Installation + +```bash +cd python +pip install -e . # install with runtime dependencies (numpy) +pip install -e ".[dev]" # also install dev dependencies (pytest, pytest-cov) +``` + +After installation, the `stdface` command becomes available: + +```bash +stdface stan.in +stdface stan.in --solver mVMC +stdface -v +``` + +The editable install (`-e`) is recommended for development. Source changes take effect immediately without reinstalling. + +## Architecture + +StdFace uses a **plugin architecture** for both solvers and lattices. +New solvers and lattices can be added without modifying any existing dispatch logic. + +### Plugin System Overview + +``` +stdface/ + plugin.py # SolverPlugin ABC + solver registry + lattice/__init__.py # LatticePlugin ABC + lattice registry +``` + +- **SolverPlugin** (`plugin.py`): Defines how a solver parses keywords, resets fields, and writes output files. Each solver (HPhi, mVMC, UHF, H-wave) is a plugin registered at import time. +- **LatticePlugin** (`lattice/__init__.py`): Defines lattice geometry, aliases, and the setup/boost methods. Each lattice (chain, square, kagome, etc.) is a plugin registered at import time. + +See [docs/tutorial_plugin.md](../docs/tutorial_plugin.md) for a step-by-step guide on adding new solvers and lattices. + ## Directory Structure ``` python/ __main__.py # CLI entry point (port of dry.c) - stdface_main.py # Main logic (port of StdFace_main.c) - stdface_vals.py # StdIntList dataclass (port of StdFace_vals.h) - stdface_model_util.py # Shared utilities (port of StdFace_ModelUtil.c) - version.py # Version information - keyword_parser.py # Keyword parsing subsystem - param_check.py # Parameter validation utilities + stdface/ + plugin.py # SolverPlugin ABC + solver registry + core/ + stdface_main.py # Main logic (port of StdFace_main.c) + stdface_vals.py # StdIntList dataclass (port of StdFace_vals.h) + keyword_parser.py # Keyword parsing subsystem + param_check.py # Parameter validation utilities + lattice/ # Lattice plugins + __init__.py # LatticePlugin ABC + lattice registry + chain_lattice.py # 1D chain (ChainPlugin) + square_lattice.py # 2D square (SquarePlugin) + ladder.py # 2-leg ladder (LadderPlugin) + triangular_lattice.py # 2D triangular (TriangularPlugin) + honeycomb_lattice.py # 2D honeycomb (HoneycombPlugin) + kagome.py # 2D kagome (KagomePlugin) + orthorhombic.py # 3D orthorhombic (OrthorhombicPlugin) + fc_ortho.py # 3D face-centered orthorhombic (FCOrthoPlugin) + pyrochlore.py # 3D pyrochlore (PyrochlorePlugin) + wannier90.py # Wannier90 interface (Wannier90Plugin) + boost_output.py # Boost output utilities + geometry_output.py # Geometry output functions + input_params.py # Input parameter resolution + interaction_builder.py # Interaction building utilities + site_util.py # Site utility functions + solvers/ # Solver plugins + __init__.py # Auto-imports all solver plugins + hphi/_plugin.py # HPhi plugin (HPhiPlugin) + mvmc/_plugin.py # mVMC plugin (MVMCPlugin) + uhf/_plugin.py # UHF plugin (UHFPlugin) + hwave/_plugin.py # H-wave plugin (HWavePlugin) + writer/ # Output writers (shared) + common_writer.py # Common output functions + interaction_writer.py # Interaction file writer + export_wannier90.py # Wannier90 format export history/ refactoring_log.md # Refactoring change log - lattice/ # Lattice implementations - __init__.py - chain_lattice.py # 1D chain lattice - square_lattice.py # 2D square lattice - ladder.py # 2-leg ladder lattice - triangular_lattice.py # 2D triangular lattice - honeycomb_lattice.py # 2D honeycomb lattice - kagome.py # 2D kagome lattice - orthorhombic.py # 3D orthorhombic lattice - fc_ortho.py # Face-centered orthorhombic lattice - pyrochlore.py # Pyrochlore lattice - wannier90.py # Wannier90 input reader - boost_output.py # Boost output utilities - geometry_output.py # Geometry output functions - input_params.py # Input parameter resolution - interaction_builder.py # Interaction building utilities - site_util.py # Site utility functions - writer/ # Solver-specific writers - __init__.py - common_writer.py # Common output functions - hphi_writer.py # HPhi-specific writer - mvmc_writer.py # mVMC-specific writer - mvmc_variational.py # mVMC variational functions - interaction_writer.py # Interaction file writer - solver_writer.py # Solver writer base classes - export_wannier90.py # Wannier90 format export ``` ## Usage @@ -57,15 +95,15 @@ python/ From the project root: ```bash -PYTHONPATH=python python3 python/__main__.py stan.in +./stdface stan.in ``` To select a solver: ```bash -PYTHONPATH=python python3 python/__main__.py stan.in --solver mVMC -PYTHONPATH=python python3 python/__main__.py stan.in --solver UHF -PYTHONPATH=python python3 python/__main__.py stan.in --solver HWAVE +./stdface stan.in --solver mVMC +./stdface stan.in --solver UHF +./stdface stan.in --solver HWAVE ``` The default solver is HPhi. @@ -73,7 +111,7 @@ The default solver is HPhi. ### Print Version ```bash -PYTHONPATH=python python3 python/__main__.py -v +./stdface -v ``` ### Calling from Python @@ -82,7 +120,7 @@ PYTHONPATH=python python3 python/__main__.py -v import sys sys.path.insert(0, "python") -from stdface_main import stdface_main +from stdface.core.stdface_main import stdface_main stdface_main("stan.in", solver="HPhi") ``` @@ -142,8 +180,8 @@ python3 -m pytest test/unit/ --cov=python --cov-report=html ``` **Test Coverage:** -- 1,252 unit tests covering all modules -- Tests for lattice implementations, writers, parsers, and utilities +- 1,268 unit tests covering all modules +- Tests for lattice implementations, writers, parsers, plugin registries, and utilities - All tests maintain byte-identical output with C version ### Integration Tests @@ -158,8 +196,7 @@ bash test/run_all_integration.sh # Copy test input to a working directory and run cp test/hphi/lanczos_hubbard_square/stan.in /tmp/work/ cd /tmp/work -PYTHONPATH=/path/to/StdFace/python python3 -c \ - "from stdface_main import stdface_main; stdface_main('stan.in', solver='HPhi')" +/path/to/StdFace/stdface stan.in # Compare outputs against reference diff /path/to/StdFace/test/hphi/lanczos_hubbard_square/ref/modpara.def modpara.def @@ -174,34 +211,35 @@ diff /path/to/StdFace/test/hphi/lanczos_hubbard_square/ref/modpara.def modpara.d The Python implementation has been refactored from a direct C translation into idiomatic Python: -- **Modular design**: Code organized into logical subpackages (`lattice/`, `writer/`) -- **Python idioms**: Enums, dict dispatch, context managers, type hints +- **Plugin architecture**: Solvers and lattices are self-registering plugins +- **Modular design**: Code organized into logical subpackages (`lattice/`, `solvers/`, `writer/`) +- **Python idioms**: Enums, ABC, context managers, type hints - **Reduced duplication**: Helper functions extracted to eliminate code duplication -- **Comprehensive testing**: 1,252 unit tests with full coverage +- **Comprehensive testing**: 1,268 unit tests with full coverage - **Documentation**: NumPy-style docstrings throughout -See `history/refactoring_log.md` for detailed refactoring history (Steps 1-77). +See `history/refactoring_log.md` for detailed refactoring history. ## Mapping to C Sources -| C Source File | Python File | +| C Source File | Python Module | |---|---| | `dry.c` | `__main__.py` | -| `StdFace_main.c` | `stdface_main.py` (refactored into multiple modules) | -| `StdFace_vals.h` | `stdface_vals.py` | -| `StdFace_ModelUtil.c/h` | `stdface_model_util.py` | -| `version.h` | `version.py` | -| `ChainLattice.c` | `lattice/chain_lattice.py` | -| `SquareLattice.c` | `lattice/square_lattice.py` | -| `Ladder.c` | `lattice/ladder.py` | -| `TriangularLattice.c` | `lattice/triangular_lattice.py` | -| `HoneycombLattice.c` | `lattice/honeycomb_lattice.py` | -| `Kagome.c` | `lattice/kagome.py` | -| `Orthorhombic.c` | `lattice/orthorhombic.py` | -| `FCOrtho.c` | `lattice/fc_ortho.py` | -| `Pyrochlore.c` | `lattice/pyrochlore.py` | -| `Wannier90.c` | `lattice/wannier90.py` | -| `export_wannier90.c` | `writer/export_wannier90.py` | +| `StdFace_main.c` | `stdface/core/stdface_main.py` | +| `StdFace_vals.h` | `stdface/core/stdface_vals.py` | +| `StdFace_ModelUtil.c/h` | `stdface/core/stdface_model_util.py` | +| `version.h` | `stdface/version.py` | +| `ChainLattice.c` | `stdface/lattice/chain_lattice.py` | +| `SquareLattice.c` | `stdface/lattice/square_lattice.py` | +| `Ladder.c` | `stdface/lattice/ladder.py` | +| `TriangularLattice.c` | `stdface/lattice/triangular_lattice.py` | +| `HoneycombLattice.c` | `stdface/lattice/honeycomb_lattice.py` | +| `Kagome.c` | `stdface/lattice/kagome.py` | +| `Orthorhombic.c` | `stdface/lattice/orthorhombic.py` | +| `FCOrtho.c` | `stdface/lattice/fc_ortho.py` | +| `Pyrochlore.c` | `stdface/lattice/pyrochlore.py` | +| `Wannier90.c` | `stdface/lattice/wannier90.py` | +| `export_wannier90.c` | `stdface/writer/export_wannier90.py` | Note: The Python implementation has been significantly refactored beyond the original C structure for better maintainability and Pythonic code style. diff --git a/python/__main__.py b/python/__main__.py index d3549d0..6382f28 100644 --- a/python/__main__.py +++ b/python/__main__.py @@ -1,80 +1,9 @@ -"""Command-line entry point for the StdFace standard-mode input generator. +"""Entry point for ``python python/__main__.py``. -This module is the Python translation of ``dry.c``. It can be invoked as:: - - python -m stdface stan.in # default solver: HPhi - python -m stdface stan.in --solver mVMC - -or, equivalently, via the installed console script (if packaged). - -License -------- -HPhi-mVMC-StdFace - Common input generator -Copyright (C) 2015 The University of Tokyo - -This program is free software: you can redistribute it and/or modify -it under the terms of the GNU General Public License as published by -the Free Software Foundation, either version 3 of the License, or -(at your option) any later version. +Delegates to ``stdface.__main__.main()``. """ -from __future__ import annotations - -import argparse +from stdface.__main__ import main import sys -from version import print_version -from stdface_main import stdface_main - - -def main(argv: list[str] | None = None) -> int: - """Parse command-line arguments and run the standard-mode generator. - - Parameters - ---------- - argv : list of str or None - Command-line arguments. ``None`` means ``sys.argv[1:]``. - - Returns - ------- - int - Exit status (0 = success, 1 = usage error). - """ - parser = argparse.ArgumentParser( - prog="stdface", - description="StdFace: standard-mode input generator for HPhi / mVMC / UHF / H-wave", - ) - parser.add_argument( - "-v", "--version", - action="store_true", - help="print version and exit", - ) - parser.add_argument( - "input_file", - nargs="?", - default=None, - help="standard-mode input file (e.g. stan.in)", - ) - parser.add_argument( - "--solver", - choices=["HPhi", "mVMC", "UHF", "HWAVE"], - default="HPhi", - help="target solver (default: HPhi)", - ) - - args = parser.parse_args(argv) - - if args.version: - print_version() - return 0 - - if args.input_file is None: - print_version() - parser.print_usage() - return 1 - - stdface_main(args.input_file, solver=args.solver) - return 0 - - if __name__ == "__main__": sys.exit(main()) diff --git a/python/history/refactoring_log.md b/python/history/refactoring_log.md index d61a827..1913e19 100644 --- a/python/history/refactoring_log.md +++ b/python/history/refactoring_log.md @@ -8113,3 +8113,83 @@ Each function body is now 2-3 lines shorter. **Tests**: - Unit: 1260 passed - Integration: 83/83 passed + +## Step 189 — 2026-01-29 + +**Files**: Multiple — major restructuring +**Change**: Implemented Solver Plugin Architecture: +1. Restructured `python/` into `python/stdface/` package with `core/`, `lattice/`, `writer/`, `solvers/` subpackages +2. Created `pyproject.toml` for `pip install -e .` support +3. Defined `SolverPlugin` ABC and plugin registry in `stdface/plugin.py` +4. Extracted 4 solver plugins: `solvers/hphi.py`, `solvers/mvmc.py`, `solvers/uhf.py`, `solvers/hwave.py` +5. Each plugin encapsulates solver-specific: keyword table, reset tables, write method, post_lattice hook +6. Refactored `stdface_main.py` to delegate to plugins for field resets, keyword parsing, post-lattice, and writing +7. Backward-compatible shim modules at old `python/` locations for existing imports +8. Added entry_points support in pyproject.toml for future external plugins +**Phase**: 4 — Architecture (plugin system) + +**Tests**: +- Unit: 1260 passed +- Integration: verified HPhi solver produces correct output + +## Step 190 — 2026-01-29 + +**Files**: `stdface/writer/` → `stdface/solvers/` +**Change**: Moved solver-specific writer modules into solver plugin directory: +- `writer/hphi_writer.py` → `solvers/hphi_writer.py` +- `writer/mvmc_writer.py` → `solvers/mvmc_writer.py` +- `writer/mvmc_variational.py` → `solvers/mvmc_variational.py` +- `writer/export_wannier90.py` → `solvers/export_wannier90.py` +- `writer/solver_writer.py` replaced with legacy wrapper delegating to plugins +- `writer/` retains only shared code: `common_writer.py`, `interaction_writer.py` +- Plugin files now import writer helpers from sibling modules within `solvers/` +- Backward-compatible shims use `sys.modules` aliasing for monkeypatch support +**Phase**: 4 — Architecture (plugin system) + +**Tests**: +- Unit: 1260 passed +- Integration: 83/83 passed + +## Step 191 — 2026-01-29 + +**Files**: `python/`, `test/unit/` +**Change**: Removed all backward-compatibility shim modules: +- Deleted top-level shims: `stdface_vals.py`, `param_check.py`, `keyword_parser.py`, `stdface_main.py`, `stdface_model_util.py`, `version.py` +- Deleted old `lattice/` and `writer/` directories (contained only shims) +- Deleted `writer/solver_writer.py`, `writer/hphi_writer.py`, etc. shims from `stdface/writer/` +- Updated all test imports to use new `stdface.*` paths directly +- Retained only `python/__main__.py` as entry point for integration tests +- Rewrote `test_solver_writer.py` to test plugin system instead of legacy Writer classes +**Phase**: 4 — Architecture (cleanup) + +**Tests**: +- Unit: 1260 passed +- Integration: 83/83 passed + +--- + +### Step 189: Lattice plugin architecture (2026-01-29) + +**Files changed**: +- `python/stdface/lattice/__init__.py` — added `LatticePlugin` ABC, registry (`register_lattice`, `get_lattice`, `get_all_lattices`) +- `python/stdface/lattice/chain_lattice.py` — added `ChainPlugin` (auto-registered) +- `python/stdface/lattice/square_lattice.py` — added `SquarePlugin` +- `python/stdface/lattice/triangular_lattice.py` — added `TriangularPlugin` +- `python/stdface/lattice/honeycomb_lattice.py` — added `HoneycombPlugin` (with boost) +- `python/stdface/lattice/kagome.py` — added `KagomePlugin` (with boost) +- `python/stdface/lattice/ladder.py` — added `LadderPlugin` (with boost) +- `python/stdface/lattice/orthorhombic.py` — added `OrthorhombicPlugin` +- `python/stdface/lattice/fc_ortho.py` — added `FCOrthoPlugin` +- `python/stdface/lattice/pyrochlore.py` — added `PyrochlorePlugin` +- `python/stdface/lattice/wannier90.py` — added `Wannier90Plugin` +- `python/stdface/core/stdface_main.py` — replaced `LATTICE_DISPATCH`/`BOOST_DISPATCH` dicts with proxy objects over the plugin registry; `_build_lattice_and_boost` now uses `get_lattice()` directly +- `python/stdface/solvers/hphi/_plugin.py` — `post_lattice` uses `get_lattice().boost()` instead of `BOOST_DISPATCH` +- `test/unit/test_lattice_dispatch.py` — rewritten to test plugin registry + proxies + +**Why**: Apply the same plugin/template pattern used for solvers to lattices, enabling new lattices to be added without modifying dispatch tables in `stdface_main.py`. + +**Phase**: 4 — Architecture (lattice plugins) + +**Tests**: +- Unit: 1268 passed +- Integration: pending diff --git a/python/lattice/__init__.py b/python/lattice/__init__.py deleted file mode 100644 index 0985bcc..0000000 --- a/python/lattice/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Lattice construction subpackage. - -This package contains modules for building lattice geometries and their -associated interaction terms. Each lattice type is implemented in its own -module; shared utilities (site initialisation, geometry output, interaction -building, input parameter helpers) are also included. -""" -from __future__ import annotations - -from . import chain_lattice, square_lattice, ladder, triangular_lattice -from . import honeycomb_lattice, kagome, orthorhombic, fc_ortho, pyrochlore -from . import wannier90 diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 0000000..db50111 --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = ["setuptools>=64"] +build-backend = "setuptools.build_meta" + +[project] +name = "stdface" +version = "0.5.0" +description = "Standard-mode input generator for HPhi / mVMC / UHF / H-wave" +readme = "README.md" +license = "GPL-3.0-or-later" +requires-python = ">=3.10" +dependencies = ["numpy"] + +[project.optional-dependencies] +dev = ["pytest", "pytest-cov"] + +[project.scripts] +stdface = "stdface.__main__:main" + +[project.entry-points."stdface.solvers"] +hphi = "stdface.solvers.hphi:HPhiPlugin" +mvmc = "stdface.solvers.mvmc:MVMCPlugin" +uhf = "stdface.solvers.uhf:UHFPlugin" +hwave = "stdface.solvers.hwave:HWavePlugin" + +[project.entry-points."stdface.lattices"] +chain = "stdface.lattice.chain_lattice:ChainPlugin" +square = "stdface.lattice.square_lattice:SquarePlugin" +triangular = "stdface.lattice.triangular_lattice:TriangularPlugin" +honeycomb = "stdface.lattice.honeycomb_lattice:HoneycombPlugin" +kagome = "stdface.lattice.kagome:KagomePlugin" +ladder = "stdface.lattice.ladder:LadderPlugin" +orthorhombic = "stdface.lattice.orthorhombic:OrthorhombicPlugin" +fc_ortho = "stdface.lattice.fc_ortho:FCOrthoPlugin" +pyrochlore = "stdface.lattice.pyrochlore:PyrochlorePlugin" +wannier90 = "stdface.lattice.wannier90:Wannier90Plugin" + +[tool.setuptools.packages.find] +include = ["stdface*"] diff --git a/python/stdface/__init__.py b/python/stdface/__init__.py new file mode 100644 index 0000000..69c7afd --- /dev/null +++ b/python/stdface/__init__.py @@ -0,0 +1,6 @@ +"""StdFace: Standard-mode input generator for HPhi / mVMC / UHF / H-wave. + +This package provides tools for generating Expert-mode definition files +from a simplified Standard-mode input file. +""" +from __future__ import annotations diff --git a/python/stdface/__main__.py b/python/stdface/__main__.py new file mode 100644 index 0000000..dc174a6 --- /dev/null +++ b/python/stdface/__main__.py @@ -0,0 +1,80 @@ +"""Command-line entry point for the StdFace standard-mode input generator. + +This module is the Python translation of ``dry.c``. It can be invoked as:: + + python -m stdface stan.in # default solver: HPhi + python -m stdface stan.in --solver mVMC + +or, equivalently, via the installed console script (if packaged). + +License +------- +HPhi-mVMC-StdFace - Common input generator +Copyright (C) 2015 The University of Tokyo + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. +""" +from __future__ import annotations + +import argparse +import sys + +from .core.version import print_version +from .core.stdface_main import stdface_main + + +def main(argv: list[str] | None = None) -> int: + """Parse command-line arguments and run the standard-mode generator. + + Parameters + ---------- + argv : list of str or None + Command-line arguments. ``None`` means ``sys.argv[1:]``. + + Returns + ------- + int + Exit status (0 = success, 1 = usage error). + """ + parser = argparse.ArgumentParser( + prog="stdface", + description="StdFace: standard-mode input generator for HPhi / mVMC / UHF / H-wave", + ) + parser.add_argument( + "-v", "--version", + action="store_true", + help="print version and exit", + ) + parser.add_argument( + "input_file", + nargs="?", + default=None, + help="standard-mode input file (e.g. stan.in)", + ) + parser.add_argument( + "--solver", + choices=["HPhi", "mVMC", "UHF", "HWAVE"], + default="HPhi", + help="target solver (default: HPhi)", + ) + + args = parser.parse_args(argv) + + if args.version: + print_version() + return 0 + + if args.input_file is None: + print_version() + parser.print_usage() + return 1 + + stdface_main(args.input_file, solver=args.solver) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/python/stdface/core/__init__.py b/python/stdface/core/__init__.py new file mode 100644 index 0000000..985c691 --- /dev/null +++ b/python/stdface/core/__init__.py @@ -0,0 +1,2 @@ +"""Core modules for the StdFace standard-mode input generator.""" +from __future__ import annotations diff --git a/python/keyword_parser.py b/python/stdface/core/keyword_parser.py similarity index 99% rename from python/keyword_parser.py rename to python/stdface/core/keyword_parser.py index e9947c3..920cceb 100644 --- a/python/keyword_parser.py +++ b/python/stdface/core/keyword_parser.py @@ -9,8 +9,8 @@ import cmath import math -from stdface_vals import StdIntList, SolverType, NaN_i, UNSET_STRING -from param_check import exit_program +from .stdface_vals import StdIntList, SolverType, NaN_i, UNSET_STRING +from .param_check import exit_program _TRIM_TABLE = str.maketrans("", "", " :;\"\\\b\v\n\0") diff --git a/python/param_check.py b/python/stdface/core/param_check.py similarity index 99% rename from python/param_check.py rename to python/stdface/core/param_check.py index 2e2486d..833cdf1 100644 --- a/python/param_check.py +++ b/python/stdface/core/param_check.py @@ -45,7 +45,7 @@ import numpy as np -from stdface_vals import NaN_i +from .stdface_vals import NaN_i # --------------------------------------------------------------------------- diff --git a/python/stdface_main.py b/python/stdface/core/stdface_main.py similarity index 81% rename from python/stdface_main.py rename to python/stdface/core/stdface_main.py index 13b7a28..6d43fee 100644 --- a/python/stdface_main.py +++ b/python/stdface/core/stdface_main.py @@ -33,83 +33,103 @@ from collections.abc import Callable from typing import NamedTuple -from stdface_vals import ( +from .stdface_vals import ( StdIntList, ModelType, SolverType, MethodType, NaN_i, NaN_d, NaN_c, UNSET_STRING, ) -from param_check import exit_program -from writer.hphi_writer import ( - large_value as _large_value, +from .param_check import exit_program +from ..solvers.hphi.writer import ( vector_potential as _vector_potential, ) -from writer.common_writer import ( +from ..writer.common_writer import ( unsupported_system as _unsupported_system, ) -from writer.solver_writer import get_solver_writer -from keyword_parser import ( +from .keyword_parser import ( trim_space_quote as _trim_space_quote, parse_common_keyword as _parse_common_keyword, parse_solver_keyword as _parse_solver_keyword, + _apply_keyword_table, ) -from lattice import ( - chain_lattice, - square_lattice, - ladder, - triangular_lattice, - honeycomb_lattice, - kagome, - orthorhombic, - fc_ortho, - pyrochlore, - wannier90 as wannier90_mod, -) +from ..lattice import get_lattice as _get_lattice # --------------------------------------------------------------------------- -# Lattice dispatch tables +# Lattice dispatch (via plugin registry) # --------------------------------------------------------------------------- -# Maps every recognised lattice alias to its builder function. -# Each function has the signature ``(StdI: StdIntList) -> None``. -LATTICE_DISPATCH: dict[str, Callable[[StdIntList], None]] = { - "chain": chain_lattice.chain, - "chainlattice": chain_lattice.chain, - "face-centeredorthorhombic": fc_ortho.fc_ortho, - "fcorthorhombic": fc_ortho.fc_ortho, - "fco": fc_ortho.fc_ortho, - "face-centeredcubic": fc_ortho.fc_ortho, - "fccubic": fc_ortho.fc_ortho, - "fcc": fc_ortho.fc_ortho, - "honeycomb": honeycomb_lattice.honeycomb, - "honeycomblattice": honeycomb_lattice.honeycomb, - "kagome": kagome.kagome, - "kagomelattice": kagome.kagome, - "ladder": ladder.ladder, - "ladderlattice": ladder.ladder, - "orthorhombic": orthorhombic.orthorhombic, - "simpleorthorhombic": orthorhombic.orthorhombic, - "cubic": orthorhombic.orthorhombic, - "simplecubic": orthorhombic.orthorhombic, - "pyrochlore": pyrochlore.pyrochlore, - "tetragonal": square_lattice.tetragonal, - "tetragonallattice": square_lattice.tetragonal, - "square": square_lattice.tetragonal, - "squarelattice": square_lattice.tetragonal, - "triangular": triangular_lattice.triangular, - "triangularlattice": triangular_lattice.triangular, - "wannier90": wannier90_mod.wannier90, -} -"""Maps lattice name aliases to their builder functions.""" - -BOOST_DISPATCH: dict[str, Callable[[StdIntList], None]] = { - "chain": chain_lattice.chain_boost, - "chainlattice": chain_lattice.chain_boost, - "honeycomb": honeycomb_lattice.honeycomb_boost, - "honeycomblattice": honeycomb_lattice.honeycomb_boost, - "kagome": kagome.kagome_boost, - "kagomelattice": kagome.kagome_boost, - "ladder": ladder.ladder_boost, - "ladderlattice": ladder.ladder_boost, -} -"""Maps lattice name aliases to their boost builder functions (HPhi only).""" +# Backward-compatible dict-like objects that delegate to the lattice plugin +# registry. New code should use ``get_lattice(name).setup(StdI)`` and +# ``get_lattice(name).boost(StdI)`` directly. + + +class _LatticeDispatchProxy: + """Dict-like proxy over the lattice plugin registry (setup methods).""" + + def __getitem__(self, key: str) -> Callable[[StdIntList], None]: + return _get_lattice(key).setup + + def get(self, key: str, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key: str) -> bool: + try: + _get_lattice(key) + return True + except KeyError: + return False + + def items(self): + from ..lattice import get_all_lattices + result = [] + for plugin in get_all_lattices(): + for alias in plugin.aliases: + result.append((alias, plugin.setup)) + return result + + +class _BoostDispatchProxy: + """Dict-like proxy over the lattice plugin registry (boost methods).""" + + # Only lattices whose boost() is overridden (not the ABC default) + _BOOST_LATTICES = {"chain", "chainlattice", "honeycomb", "honeycomblattice", + "kagome", "kagomelattice", "ladder", "ladderlattice"} + + def __getitem__(self, key: str) -> Callable[[StdIntList], None]: + if key not in self._BOOST_LATTICES: + raise KeyError(key) + return _get_lattice(key).boost + + def get(self, key: str, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key: str) -> bool: + return key in self._BOOST_LATTICES + + def items(self): + result = [] + seen = set() + for alias in self._BOOST_LATTICES: + try: + plugin = _get_lattice(alias) + if id(plugin) not in seen: + seen.add(id(plugin)) + for a in plugin.aliases: + if a in self._BOOST_LATTICES: + result.append((a, plugin.boost)) + except KeyError: + pass + return result + + +LATTICE_DISPATCH = _LatticeDispatchProxy() +"""Dict-like proxy that maps lattice aliases to setup functions via the plugin registry.""" + +BOOST_DISPATCH = _BoostDispatchProxy() +"""Dict-like proxy that maps lattice aliases to boost functions via the plugin registry.""" # --------------------------------------------------------------------------- # Model name normalisation table @@ -386,7 +406,7 @@ class _ModelConfig(NamedTuple): def _apply_field_resets(StdI: StdIntList, solver: SolverType) -> None: - """Apply solver-specific field resets from the data tables. + """Apply solver-specific field resets from the plugin registry. Parameters ---------- @@ -395,9 +415,14 @@ def _apply_field_resets(StdI: StdIntList, solver: SolverType) -> None: solver : SolverType The solver whose field-reset tables to apply. """ - for name, value in _SOLVER_RESET_SCALARS.get(solver, ()): + from ..plugin import get_plugin + try: + plugin = get_plugin(solver) + except KeyError: + return + for name, value in plugin.reset_scalars: setattr(StdI, name, value) - for name, value in _SOLVER_RESET_ARRAYS.get(solver, ()): + for name, value in plugin.reset_arrays: arr = getattr(StdI, name) arr[...] = value @@ -521,21 +546,19 @@ def _build_lattice_and_boost(StdI: StdIntList, solver: str) -> None: If the lattice is not recognised. """ lattice = StdI.lattice - lattice_builder = LATTICE_DISPATCH.get(lattice) - if lattice_builder is not None: - lattice_builder(StdI) - else: + try: + lattice_plugin = _get_lattice(lattice) + except KeyError: _unsupported_system(StdI.model, StdI.lattice) + else: + lattice_plugin.setup(StdI) - if solver == SolverType.HPhi: - _large_value(StdI) - - if StdI.lBoost == 1: - boost_builder = BOOST_DISPATCH.get(lattice) - if boost_builder is not None: - boost_builder(StdI) - else: - _unsupported_system(StdI.model, StdI.lattice) + from ..plugin import get_plugin + try: + plugin = get_plugin(solver) + plugin.post_lattice(StdI) + except KeyError: + pass # =================================================================== @@ -543,6 +566,37 @@ def _build_lattice_and_boost(StdI: StdIntList, solver: str) -> None: # =================================================================== +def _parse_solver_keyword_via_plugin( + keyword: str, value: str, StdI: StdIntList, solver: str +) -> bool: + """Parse a solver-specific keyword using the plugin registry. + + Falls back to the legacy ``parse_solver_keyword`` if no plugin is found. + + Parameters + ---------- + keyword : str + The lowered keyword string. + value : str + The raw value string from the input file. + StdI : StdIntList + The global parameter structure, modified in place. + solver : str + The solver type. + + Returns + ------- + bool + True if the keyword was recognised, False otherwise. + """ + from ..plugin import get_plugin + try: + plugin = get_plugin(solver) + return _apply_keyword_table(plugin.keyword_table, keyword, value, StdI) + except KeyError: + return _parse_solver_keyword(keyword, value, StdI, solver) + + def _parse_input_file(fname: str, StdI: StdIntList, solver: str) -> None: """Open and parse a Standard-mode input file into *StdI*. @@ -592,7 +646,7 @@ def _parse_input_file(fname: str, StdI: StdIntList, solver: str) -> None: print(f" KEYWORD : {keyword:<20s} | VALUE : {value} ") if not _parse_common_keyword(keyword, value, StdI): - if not _parse_solver_keyword(keyword, value, StdI, solver): + if not _parse_solver_keyword_via_plugin(keyword, value, StdI, solver): print("ERROR ! Unsupported Keyword in Standard mode!") exit_program(-1) @@ -668,8 +722,9 @@ def stdface_main(fname: str, solver: str = "HPhi") -> None: print("###### Print Expert input files ######") print("") - writer = get_solver_writer(solver) - writer.write(StdI) + from ..plugin import get_plugin + plugin = get_plugin(solver) + plugin.write(StdI) # ------------------------------------------------------------------ # Finalise diff --git a/python/stdface_model_util.py b/python/stdface/core/stdface_model_util.py similarity index 77% rename from python/stdface_model_util.py rename to python/stdface/core/stdface_model_util.py index e773c80..9555434 100644 --- a/python/stdface_model_util.py +++ b/python/stdface/core/stdface_model_util.py @@ -36,7 +36,7 @@ from __future__ import annotations -from param_check import ( # noqa: F401 – re-exported for backward compatibility +from .param_check import ( # noqa: F401 – re-exported for backward compatibility exit_program, print_val_d, print_val_dd, @@ -47,17 +47,17 @@ not_used_i, required_val_i, ) -from lattice.input_params import ( # noqa: F401 – re-exported for backward compatibility +from ..lattice.input_params import ( # noqa: F401 – re-exported for backward compatibility input_spin_nn, input_spin, input_coulomb_v, input_hopp, ) -from lattice.geometry_output import ( # noqa: F401 – re-exported for backward compatibility +from ..lattice.geometry_output import ( # noqa: F401 – re-exported for backward compatibility print_xsf, print_geometry, ) -from lattice.interaction_builder import ( # noqa: F401 – re-exported for backward compatibility +from ..lattice.interaction_builder import ( # noqa: F401 – re-exported for backward compatibility trans, hopping, hubbard_local, @@ -67,7 +67,7 @@ coulomb, malloc_interactions, ) -from lattice.site_util import ( # noqa: F401 – re-exported for backward compatibility +from ..lattice.site_util import ( # noqa: F401 – re-exported for backward compatibility _fold_site, init_site, find_site, diff --git a/python/stdface_vals.py b/python/stdface/core/stdface_vals.py similarity index 100% rename from python/stdface_vals.py rename to python/stdface/core/stdface_vals.py diff --git a/python/version.py b/python/stdface/core/version.py similarity index 100% rename from python/version.py rename to python/stdface/core/version.py diff --git a/python/stdface/lattice/__init__.py b/python/stdface/lattice/__init__.py new file mode 100644 index 0000000..c4d9d9f --- /dev/null +++ b/python/stdface/lattice/__init__.py @@ -0,0 +1,158 @@ +"""Lattice construction subpackage. + +This package contains modules for building lattice geometries and their +associated interaction terms. Each lattice type is implemented in its own +module; shared utilities (site initialisation, geometry output, interaction +building, input parameter helpers) are also included. + +The :class:`LatticePlugin` ABC and plugin registry provide a uniform +interface for lattice discovery and dispatch. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..core.stdface_vals import StdIntList + + +class LatticePlugin(ABC): + """Abstract base class for lattice plugins. + + Each lattice plugin declares its geometry metadata and provides + a :meth:`setup` method that builds the lattice Hamiltonian terms. + Lattices that support HPhi Boost mode override :meth:`boost`. + + Attributes + ---------- + name : str + Canonical lattice name (e.g. ``"chain"``). + aliases : list[str] + All recognised name aliases (including the canonical name). + ndim : int + Number of spatial dimensions (1, 2, or 3). + """ + + @property + @abstractmethod + def name(self) -> str: + """Canonical lattice name.""" + + @property + @abstractmethod + def aliases(self) -> list[str]: + """All recognised name aliases for this lattice.""" + + @property + @abstractmethod + def ndim(self) -> int: + """Number of spatial dimensions (1, 2, or 3).""" + + @abstractmethod + def setup(self, StdI: StdIntList) -> None: + """Build the lattice geometry and Hamiltonian terms. + + Parameters + ---------- + StdI : StdIntList + The parameter structure (modified in place). + """ + + def boost(self, StdI: StdIntList) -> None: + """Build HPhi Boost-mode data for this lattice. + + The default implementation does nothing. Override in lattices + that support Boost (chain, honeycomb, kagome, ladder). + + Parameters + ---------- + StdI : StdIntList + The parameter structure (modified in place). + """ + + +# --------------------------------------------------------------------------- +# Lattice plugin registry +# --------------------------------------------------------------------------- + +_lattice_plugins: dict[str, LatticePlugin] = {} +"""Maps every alias to its :class:`LatticePlugin` instance.""" + + +def register_lattice(plugin: LatticePlugin) -> None: + """Register a lattice plugin under all its aliases. + + Parameters + ---------- + plugin : LatticePlugin + The plugin instance to register. + + Raises + ------ + ValueError + If any alias is already registered to a *different* plugin. + """ + for alias in plugin.aliases: + existing = _lattice_plugins.get(alias) + if existing is not None and existing is not plugin: + raise ValueError( + f"Lattice alias {alias!r} is already registered to " + f"{existing.name!r}, cannot register {plugin.name!r}" + ) + _lattice_plugins[alias] = plugin + + +def get_lattice(name: str) -> LatticePlugin: + """Retrieve a registered lattice plugin by alias. + + Parameters + ---------- + name : str + A lattice name alias (e.g. ``"chain"``, ``"squarelattice"``). + + Returns + ------- + LatticePlugin + The registered plugin instance. + + Raises + ------ + KeyError + If no plugin is registered under *name*. + """ + if name not in _lattice_plugins: + _discover_lattices() + if name not in _lattice_plugins: + raise KeyError( + f"No lattice plugin registered for {name!r}. " + f"Available: {', '.join(sorted(set(p.name for p in _lattice_plugins.values()))) or '(none)'}" + ) + return _lattice_plugins[name] + + +def get_all_lattices() -> list[LatticePlugin]: + """Return all unique registered lattice plugins. + + Returns + ------- + list[LatticePlugin] + Unique plugin instances (deduplicated by identity). + """ + if not _lattice_plugins: + _discover_lattices() + seen: set[int] = set() + result: list[LatticePlugin] = [] + for plugin in _lattice_plugins.values(): + pid = id(plugin) + if pid not in seen: + seen.add(pid) + result.append(plugin) + return result + + +def _discover_lattices() -> None: + """Import all lattice modules to trigger auto-registration.""" + from . import chain_lattice, square_lattice, ladder, triangular_lattice # noqa: F401 + from . import honeycomb_lattice, kagome, orthorhombic, fc_ortho, pyrochlore # noqa: F401 + from . import wannier90 # noqa: F401 diff --git a/python/lattice/boost_output.py b/python/stdface/lattice/boost_output.py similarity index 99% rename from python/lattice/boost_output.py rename to python/stdface/lattice/boost_output.py index ba5a043..65f08c5 100644 --- a/python/lattice/boost_output.py +++ b/python/stdface/lattice/boost_output.py @@ -34,7 +34,7 @@ import numpy as np -from stdface_vals import StdIntList +from ..core.stdface_vals import StdIntList def write_boost_mag_field(fp: TextIO, StdI: StdIntList) -> None: diff --git a/python/lattice/chain_lattice.py b/python/stdface/lattice/chain_lattice.py similarity index 91% rename from python/lattice/chain_lattice.py rename to python/stdface/lattice/chain_lattice.py index f144394..94d54b2 100644 --- a/python/lattice/chain_lattice.py +++ b/python/stdface/lattice/chain_lattice.py @@ -19,8 +19,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType, SolverType -from param_check import ( +from ..core.stdface_vals import StdIntList, ModelType, SolverType +from ..core.param_check import ( exit_program, print_val_d, print_val_i, not_used_d, not_used_i, not_used_j, required_val_i, ) @@ -277,3 +277,37 @@ def chain_boost(StdI: StdIntList) -> None: StdI.list_6spin_pair[ipivot, :, 7] = [3, 5, 0, 1, 2, 4, 2] write_boost_6spin_pair(fp, StdI) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + + +class ChainPlugin(LatticePlugin): + """Plugin for the 1D chain lattice.""" + + @property + def name(self) -> str: + return "chain" + + @property + def aliases(self) -> list[str]: + return ["chain", "chainlattice"] + + @property + def ndim(self) -> int: + return 1 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the chain() function.""" + chain(StdI) + + def boost(self, StdI: StdIntList) -> None: + """Delegate to the chain_boost() function.""" + chain_boost(StdI) + + +register_lattice(ChainPlugin()) diff --git a/python/lattice/fc_ortho.py b/python/stdface/lattice/fc_ortho.py similarity index 89% rename from python/lattice/fc_ortho.py rename to python/stdface/lattice/fc_ortho.py index baedd99..01b34ff 100644 --- a/python/lattice/fc_ortho.py +++ b/python/stdface/lattice/fc_ortho.py @@ -16,8 +16,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType -from param_check import ( +from ..core.stdface_vals import StdIntList, ModelType +from ..core.param_check import ( print_val_d, print_val_i, not_used_j, not_used_d, not_used_i, ) @@ -183,3 +183,38 @@ def fc_ortho(StdI: StdIntList) -> None: StdI, iW, iL, iH, dW, dL, dH, si, sj, J, t, V) close_lattice_xsf(StdI) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + +_fc_ortho_setup = fc_ortho + + +class FCOrthoPlugin(LatticePlugin): + """Plugin for the 3D face-centered orthorhombic lattice.""" + + @property + def name(self) -> str: + return "fco" + + @property + def aliases(self) -> list[str]: + return [ + "face-centeredorthorhombic", "fcorthorhombic", "fco", + "face-centeredcubic", "fccubic", "fcc", + ] + + @property + def ndim(self) -> int: + return 3 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the fc_ortho() function.""" + _fc_ortho_setup(StdI) + + +register_lattice(FCOrthoPlugin()) diff --git a/python/lattice/geometry_output.py b/python/stdface/lattice/geometry_output.py similarity index 98% rename from python/lattice/geometry_output.py rename to python/stdface/lattice/geometry_output.py index 568a795..53caec3 100644 --- a/python/lattice/geometry_output.py +++ b/python/stdface/lattice/geometry_output.py @@ -25,7 +25,7 @@ import numpy as np -from stdface_vals import StdIntList, ModelType, SolverType +from ..core.stdface_vals import StdIntList, ModelType, SolverType def _cell_diff(Cell, iCell: int, jCell: int) -> list[int]: diff --git a/python/lattice/honeycomb_lattice.py b/python/stdface/lattice/honeycomb_lattice.py similarity index 92% rename from python/lattice/honeycomb_lattice.py rename to python/stdface/lattice/honeycomb_lattice.py index 826fb49..4596cb5 100644 --- a/python/lattice/honeycomb_lattice.py +++ b/python/stdface/lattice/honeycomb_lattice.py @@ -18,8 +18,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType -from param_check import ( +from ..core.stdface_vals import StdIntList, ModelType +from ..core.param_check import ( exit_program, print_val_d, print_val_i, not_used_j, not_used_d, not_used_i, ) @@ -268,3 +268,37 @@ def honeycomb_boost(StdI: StdIntList) -> None: StdI.list_6spin_pair[1, :, 3] = [2, 5, 0, 1, 3, 4, 3] write_boost_6spin_pair(fp, StdI) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + + +class HoneycombPlugin(LatticePlugin): + """Plugin for the 2D honeycomb lattice.""" + + @property + def name(self) -> str: + return "honeycomb" + + @property + def aliases(self) -> list[str]: + return ["honeycomb", "honeycomblattice"] + + @property + def ndim(self) -> int: + return 2 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the honeycomb() function.""" + honeycomb(StdI) + + def boost(self, StdI: StdIntList) -> None: + """Delegate to the honeycomb_boost() function.""" + honeycomb_boost(StdI) + + +register_lattice(HoneycombPlugin()) diff --git a/python/lattice/input_params.py b/python/stdface/lattice/input_params.py similarity index 99% rename from python/lattice/input_params.py rename to python/stdface/lattice/input_params.py index 655660d..2777d20 100644 --- a/python/lattice/input_params.py +++ b/python/stdface/lattice/input_params.py @@ -34,7 +34,7 @@ import numpy as np -from param_check import exit_program, SPIN_SUFFIXES +from ..core.param_check import exit_program, SPIN_SUFFIXES diff --git a/python/lattice/interaction_builder.py b/python/stdface/lattice/interaction_builder.py similarity index 97% rename from python/lattice/interaction_builder.py rename to python/stdface/lattice/interaction_builder.py index 2f67efe..ce28be6 100644 --- a/python/lattice/interaction_builder.py +++ b/python/stdface/lattice/interaction_builder.py @@ -50,7 +50,7 @@ from typing import TextIO -from stdface_vals import StdIntList, ModelType, SolverType, MethodType, ZERO_BODY_EPS, AMPLITUDE_EPS +from ..core.stdface_vals import StdIntList, ModelType, SolverType, MethodType, ZERO_BODY_EPS, AMPLITUDE_EPS def trans( @@ -375,6 +375,16 @@ def general_j( Si = 0.5 * Si2 Sj = 0.5 * Sj2 + # Precompute spin ladder factors for all needed (S, Sz) pairs + ladder_i: dict[int, float] = {} + for ispin in range(1, Si2 + 1): + Siz = Si - float(ispin) + ladder_i[ispin] = _spin_ladder_factor(Si, Siz) + ladder_j: dict[int, float] = {} + for jspin in range(1, Sj2 + 1): + Sjz = Sj - float(jspin) + ladder_j[jspin] = _spin_ladder_factor(Sj, Sjz) + for ispin in range(Si2 + 1): Siz = Si - float(ispin) for jspin in range(Sj2 + 1): @@ -388,8 +398,8 @@ def general_j( jsite, jspin, jsite, jspin) if ispin > 0 and jspin > 0 and use_ex: - fi = _spin_ladder_factor(Si, Siz) - fj = _spin_ladder_factor(Sj, Sjz) + fi = ladder_i[ispin] + fj = ladder_j[jspin] # (2) S_i^+ S_j^- + h.c. intr0 = 0.25 * (J[0, 0] + J[1, 1] + 1j * (J[0, 1] - J[1, 0])) * fi * fj @@ -411,7 +421,7 @@ def general_j( # (4) S_i^+ S_{jz} + h.c. if ispin > 0: - fi = _spin_ladder_factor(Si, Siz) + fi = ladder_i[ispin] intr0 = 0.5 * (J[0, 2] - 1j * J[1, 2]) * fi * Sjz intr(StdI, intr0, isite, ispin - 1, isite, ispin, @@ -422,7 +432,7 @@ def general_j( # (5) S_{iz} S_j^+ + h.c. if jspin > 0: - fj = _spin_ladder_factor(Sj, Sjz) + fj = ladder_j[jspin] intr0 = 0.5 * (J[2, 0] - 1j * J[2, 1]) * Siz * fj intr(StdI, intr0, isite, ispin, isite, ispin, diff --git a/python/lattice/kagome.py b/python/stdface/lattice/kagome.py similarity index 92% rename from python/lattice/kagome.py rename to python/stdface/lattice/kagome.py index 148dde6..ecfc6a4 100644 --- a/python/lattice/kagome.py +++ b/python/stdface/lattice/kagome.py @@ -18,8 +18,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType -from param_check import ( +from ..core.stdface_vals import StdIntList, ModelType +from ..core.param_check import ( exit_program, print_val_d, print_val_i, not_used_j, not_used_d, not_used_i, ) @@ -279,3 +279,40 @@ def kagome_boost(StdI: StdIntList) -> None: StdI.list_6spin_pair[3, :, 4] = [2, 4, 0, 1, 3, 5, 1] write_boost_6spin_pair(fp, StdI) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + +_kagome_setup = kagome +_kagome_boost = kagome_boost + + +class KagomePlugin(LatticePlugin): + """Plugin for the 2D kagome lattice.""" + + @property + def name(self) -> str: + return "kagome" + + @property + def aliases(self) -> list[str]: + return ["kagome", "kagomelattice"] + + @property + def ndim(self) -> int: + return 2 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the kagome() function.""" + _kagome_setup(StdI) + + def boost(self, StdI: StdIntList) -> None: + """Delegate to the kagome_boost() function.""" + _kagome_boost(StdI) + + +register_lattice(KagomePlugin()) diff --git a/python/lattice/ladder.py b/python/stdface/lattice/ladder.py similarity index 91% rename from python/lattice/ladder.py rename to python/stdface/lattice/ladder.py index abcc4e2..618761a 100644 --- a/python/lattice/ladder.py +++ b/python/stdface/lattice/ladder.py @@ -16,8 +16,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType -from param_check import ( +from ..core.stdface_vals import StdIntList, ModelType +from ..core.param_check import ( exit_program, print_val_d, print_val_i, not_used_j, not_used_d, not_used_i, required_val_i, ) @@ -280,3 +280,40 @@ def ladder_boost(StdI: StdIntList) -> None: StdI.list_6spin_pair[ipivot, :, 6] = [1, 2, 0, 3, 4, 5, 5] write_boost_6spin_pair(fp, StdI) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + +_ladder_setup = ladder +_ladder_boost = ladder_boost + + +class LadderPlugin(LatticePlugin): + """Plugin for the 1D ladder lattice.""" + + @property + def name(self) -> str: + return "ladder" + + @property + def aliases(self) -> list[str]: + return ["ladder", "ladderlattice"] + + @property + def ndim(self) -> int: + return 1 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the ladder() function.""" + _ladder_setup(StdI) + + def boost(self, StdI: StdIntList) -> None: + """Delegate to the ladder_boost() function.""" + _ladder_boost(StdI) + + +register_lattice(LadderPlugin()) diff --git a/python/lattice/orthorhombic.py b/python/stdface/lattice/orthorhombic.py similarity index 89% rename from python/lattice/orthorhombic.py rename to python/stdface/lattice/orthorhombic.py index 7f94020..490f032 100644 --- a/python/lattice/orthorhombic.py +++ b/python/stdface/lattice/orthorhombic.py @@ -16,8 +16,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType -from param_check import ( +from ..core.stdface_vals import StdIntList, ModelType +from ..core.param_check import ( print_val_d, print_val_c, print_val_i, not_used_j, not_used_d, not_used_i, ) @@ -181,3 +181,35 @@ def orthorhombic(StdI: StdIntList) -> None: StdI, iW, iL, iH, dW, dL, dH, si, sj, J, t, V) close_lattice_xsf(StdI) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + +_orthorhombic_setup = orthorhombic + + +class OrthorhombicPlugin(LatticePlugin): + """Plugin for the 3D simple orthorhombic lattice.""" + + @property + def name(self) -> str: + return "orthorhombic" + + @property + def aliases(self) -> list[str]: + return ["orthorhombic", "simpleorthorhombic", "cubic", "simplecubic"] + + @property + def ndim(self) -> int: + return 3 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the orthorhombic() function.""" + _orthorhombic_setup(StdI) + + +register_lattice(OrthorhombicPlugin()) diff --git a/python/lattice/pyrochlore.py b/python/stdface/lattice/pyrochlore.py similarity index 90% rename from python/lattice/pyrochlore.py rename to python/stdface/lattice/pyrochlore.py index d70367f..bb70809 100644 --- a/python/lattice/pyrochlore.py +++ b/python/stdface/lattice/pyrochlore.py @@ -16,8 +16,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType -from param_check import ( +from ..core.stdface_vals import StdIntList, ModelType +from ..core.param_check import ( print_val_d, print_val_i, not_used_j, not_used_d, not_used_i, ) @@ -196,3 +196,35 @@ def pyrochlore(StdI: StdIntList) -> None: StdI, iW, iL, iH, dW, dL, dH, si, sj, J, t, V) close_lattice_xsf(StdI) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + +_pyrochlore_setup = pyrochlore + + +class PyrochlorePlugin(LatticePlugin): + """Plugin for the 3D pyrochlore lattice.""" + + @property + def name(self) -> str: + return "pyrochlore" + + @property + def aliases(self) -> list[str]: + return ["pyrochlore"] + + @property + def ndim(self) -> int: + return 3 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the pyrochlore() function.""" + _pyrochlore_setup(StdI) + + +register_lattice(PyrochlorePlugin()) diff --git a/python/lattice/site_util.py b/python/stdface/lattice/site_util.py similarity index 94% rename from python/lattice/site_util.py rename to python/stdface/lattice/site_util.py index 77ff83c..73a0890 100644 --- a/python/lattice/site_util.py +++ b/python/stdface/lattice/site_util.py @@ -40,8 +40,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType, SolverType, NaN_i, AMPLITUDE_EPS -from param_check import exit_program, print_val_i +from ..core.stdface_vals import StdIntList, ModelType, SolverType, NaN_i, AMPLITUDE_EPS +from ..core.param_check import exit_program, print_val_i from .geometry_output import print_geometry, print_xsf @@ -67,11 +67,39 @@ def _cell_vector(Cell: np.ndarray, idx: int) -> list[int]: return [int(Cell[idx, 0]), int(Cell[idx, 1]), int(Cell[idx, 2])] +def _build_cell_map(StdI: StdIntList) -> dict[tuple[int, int, int], int]: + """Build or return a cached mapping from cell coordinates to index. + + The map is stored on ``StdI._cell_map`` and reused on subsequent + calls. It is invalidated whenever ``_enumerate_cells`` is called. + + Parameters + ---------- + StdI : StdIntList + Model parameter structure containing the ``Cell`` array and + ``NCell`` count. + + Returns + ------- + dict[tuple[int, int, int], int] + Mapping from ``(Cell[k,0], Cell[k,1], Cell[k,2])`` to ``k``. + """ + cell_map = getattr(StdI, '_cell_map', None) + if cell_map is not None: + return cell_map + cell_map = {} + for k in range(StdI.NCell): + key = (int(StdI.Cell[k, 0]), int(StdI.Cell[k, 1]), int(StdI.Cell[k, 2])) + if key not in cell_map: + cell_map[key] = k + StdI._cell_map = cell_map + return cell_map + + def _find_cell_index(StdI: StdIntList, cellV: list[int]) -> int: """Find the cell index whose coordinates match *cellV*. - Performs a linear search over ``StdI.Cell`` to find the row whose - three coordinates match *cellV*. + Uses a hash map for O(1) lookup instead of linear scan. Parameters ---------- @@ -87,10 +115,8 @@ def _find_cell_index(StdI: StdIntList, cellV: list[int]) -> int: Cell index ``k`` such that ``StdI.Cell[k]`` equals *cellV*. Returns 0 if no match is found (matching the original C behavior). """ - for k in range(StdI.NCell): - if _cell_vector(StdI.Cell, k) == cellV: - return k - return 0 + cell_map = _build_cell_map(StdI) + return cell_map.get((cellV[0], cellV[1], cellV[2]), 0) def _fold_to_cell( @@ -417,6 +443,7 @@ def _enumerate_cells(StdI: StdIntList) -> None: # Enumerate cells within the bounding box # Note: iteration order must be ic2 outermost, ic0 innermost (matching C code) StdI.Cell = np.zeros((StdI.NCell, 3), dtype=int) + StdI._cell_map = None # invalidate cell map cache jj_idx = 0 for ic2, ic1, ic0 in itertools.product( range(bound[2][0], bound[2][1] + 1), diff --git a/python/lattice/square_lattice.py b/python/stdface/lattice/square_lattice.py similarity index 89% rename from python/lattice/square_lattice.py rename to python/stdface/lattice/square_lattice.py index 8cb9709..7180f7c 100644 --- a/python/lattice/square_lattice.py +++ b/python/stdface/lattice/square_lattice.py @@ -16,8 +16,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType -from param_check import ( +from ..core.stdface_vals import StdIntList, ModelType +from ..core.param_check import ( print_val_d, print_val_i, not_used_j, not_used_d, not_used_i, ) @@ -169,3 +169,33 @@ def tetragonal(StdI: StdIntList) -> None: for dW, dL, si, sj, nn, J, t, V in _BONDS: add_neighbor_interaction( StdI, fp, iW, iL, dW, dL, si, sj, nn, J, t, V) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + + +class SquarePlugin(LatticePlugin): + """Plugin for the 2D tetragonal (square) lattice.""" + + @property + def name(self) -> str: + return "tetragonal" + + @property + def aliases(self) -> list[str]: + return ["tetragonal", "tetragonallattice", "square", "squarelattice"] + + @property + def ndim(self) -> int: + return 2 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the tetragonal() function.""" + tetragonal(StdI) + + +register_lattice(SquarePlugin()) diff --git a/python/lattice/triangular_lattice.py b/python/stdface/lattice/triangular_lattice.py similarity index 90% rename from python/lattice/triangular_lattice.py rename to python/stdface/lattice/triangular_lattice.py index 8cf086c..cfce8c6 100644 --- a/python/lattice/triangular_lattice.py +++ b/python/stdface/lattice/triangular_lattice.py @@ -18,8 +18,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType -from param_check import ( +from ..core.stdface_vals import StdIntList, ModelType +from ..core.param_check import ( print_val_d, print_val_i, not_used_j, not_used_d, not_used_i, ) @@ -190,3 +190,33 @@ def triangular(StdI: StdIntList) -> None: for dW, dL, si, sj, nn, J, t, V in _BONDS: add_neighbor_interaction( StdI, fp, iW, iL, dW, dL, si, sj, nn, J, t, V) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + + +class TriangularPlugin(LatticePlugin): + """Plugin for the 2D triangular lattice.""" + + @property + def name(self) -> str: + return "triangular" + + @property + def aliases(self) -> list[str]: + return ["triangular", "triangularlattice"] + + @property + def ndim(self) -> int: + return 2 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the triangular() function.""" + triangular(StdI) + + +register_lattice(TriangularPlugin()) diff --git a/python/lattice/wannier90.py b/python/stdface/lattice/wannier90.py similarity index 97% rename from python/lattice/wannier90.py rename to python/stdface/lattice/wannier90.py index ad0c5b0..26bf17a 100644 --- a/python/lattice/wannier90.py +++ b/python/stdface/lattice/wannier90.py @@ -26,8 +26,8 @@ import numpy as np -from stdface_vals import StdIntList, ModelType, SolverType, NaN_i, UNSET_STRING, AMPLITUDE_EPS -from param_check import exit_program, print_val_d, print_val_i, not_used_d +from ..core.stdface_vals import StdIntList, ModelType, SolverType, NaN_i, UNSET_STRING, AMPLITUDE_EPS +from ..core.param_check import exit_program, print_val_d, print_val_i, not_used_d from .geometry_output import print_geometry, print_xsf from .interaction_builder import ( malloc_interactions, mag_field, general_j, hubbard_local, hopping, coulomb, @@ -348,9 +348,8 @@ def _read_w90( Mat_tot[iWSC, iWan0 - 1, jWan0 - 1] = lam * (dtmp_re + 1j * dtmp_im) # Apply inversion symmetry and delete duplication - for jWSC in range(iWSC): - if np.all(indx_tot[iWSC] == -indx_tot[jWSC]): - Mat_tot[iWSC, :, :] = 0.0 + if iWSC > 0 and np.any(np.all(indx_tot[iWSC] == -indx_tot[:iWSC], axis=1)): + Mat_tot[iWSC, :, :] = 0.0 if np.all(indx_tot[iWSC] == 0): for iWan in range(StdI.NsiteUC): @@ -1316,3 +1315,35 @@ def wannier90(StdI: StdIntList) -> None: # Write wan2site.dat _write_wan2site(StdI) + + +# --------------------------------------------------------------------------- +# Lattice plugin registration +# --------------------------------------------------------------------------- + +from . import LatticePlugin, register_lattice + +_wannier90_setup = wannier90 + + +class Wannier90Plugin(LatticePlugin): + """Plugin for the Wannier90 interface.""" + + @property + def name(self) -> str: + return "wannier90" + + @property + def aliases(self) -> list[str]: + return ["wannier90"] + + @property + def ndim(self) -> int: + return 3 + + def setup(self, StdI: StdIntList) -> None: + """Delegate to the wannier90() function.""" + _wannier90_setup(StdI) + + +register_lattice(Wannier90Plugin()) diff --git a/python/stdface/plugin.py b/python/stdface/plugin.py new file mode 100644 index 0000000..67d7551 --- /dev/null +++ b/python/stdface/plugin.py @@ -0,0 +1,268 @@ +"""Solver plugin interface and registry. + +This module defines the :class:`SolverPlugin` abstract base class that every +solver plugin must implement, and a plugin registry for discovering and +retrieving plugins by name. + +New solver plugins should subclass :class:`SolverPlugin` and implement the +required abstract properties and methods. The :meth:`write` method provides +a template for the common output sequence (locspn → trans → interactions → +modpara → solver-specific → green → namelist); solvers with a different +sequence can override :meth:`write` entirely. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .core.stdface_vals import StdIntList + + +class SolverPlugin(ABC): + """Abstract base class for solver plugins. + + Each solver plugin provides: + + - A keyword table for parsing solver-specific input keywords. + - Field reset tables for initialising solver-specific fields on + :class:`StdIntList`. + - A :meth:`write` method to generate Expert-mode definition files. + - Optional hooks for post-lattice processing and field initialisation. + + Template Method + --------------- + The default :meth:`write` implementation calls a sequence of steps + that is common to most solvers:: + + write_locspn → write_trans → write_interactions → + check_and_write_modpara → write_solver_specific → + write_green → write_namelist + + Subclasses should override :meth:`write_solver_specific` for + solver-specific output (e.g. excitation files, variational parameters). + Solvers that need a completely different sequence (e.g. H-wave in + Wannier90 export mode) can override :meth:`write` directly. + + Attributes + ---------- + name : str + Canonical solver name (e.g. ``"HPhi"``). + keyword_table : dict[str, tuple] + Solver-specific keyword dispatch table. + reset_scalars : list[tuple[str, object]] + ``(field_name, sentinel_value)`` pairs for scalar field resets. + reset_arrays : list[tuple[str, object]] + ``(field_name, sentinel_value)`` pairs for array-fill field resets. + """ + + # ------------------------------------------------------------------ + # Abstract properties — must be implemented by every plugin + # ------------------------------------------------------------------ + + @property + @abstractmethod + def name(self) -> str: + """Canonical solver name.""" + + @property + @abstractmethod + def keyword_table(self) -> dict[str, tuple]: + """Solver-specific keyword dispatch table.""" + + @property + @abstractmethod + def reset_scalars(self) -> list[tuple[str, object]]: + """Scalar field reset table.""" + + @property + @abstractmethod + def reset_arrays(self) -> list[tuple[str, object]]: + """Array-fill field reset table.""" + + # ------------------------------------------------------------------ + # Template method for writing output files + # ------------------------------------------------------------------ + + def write(self, StdI: StdIntList) -> None: + """Write all Expert-mode definition files for this solver. + + This is a template method that calls the common output steps in + order. Override individual steps or this entire method as needed. + + Parameters + ---------- + StdI : StdIntList + The fully-populated parameter structure. + """ + self.write_locspn(StdI) + self.write_trans(StdI) + self.write_interactions(StdI) + self.check_and_write_modpara(StdI) + self.write_solver_specific(StdI) + self.check_output_mode(StdI) + self.write_green(StdI) + self.write_namelist(StdI) + + # ------------------------------------------------------------------ + # Default step implementations — override as needed + # ------------------------------------------------------------------ + + def write_locspn(self, StdI: StdIntList) -> None: + """Write locspn.def.""" + from .writer.common_writer import print_loc_spin + print_loc_spin(StdI) + + def write_trans(self, StdI: StdIntList) -> None: + """Write trans.def.""" + from .writer.common_writer import print_trans + print_trans(StdI) + + def write_interactions(self, StdI: StdIntList) -> None: + """Write interaction definition files.""" + from .writer.interaction_writer import print_interactions + print_interactions(StdI) + + def check_and_write_modpara(self, StdI: StdIntList) -> None: + """Check and write modpara.def.""" + from .writer.common_writer import check_mod_para, print_mod_para + check_mod_para(StdI) + print_mod_para(StdI) + + def write_solver_specific(self, StdI: StdIntList) -> None: + """Write solver-specific output files. + + Override this method to add solver-specific output (e.g. excitation + files for HPhi, variational parameter files for mVMC). + The default implementation does nothing. + + Parameters + ---------- + StdI : StdIntList + The fully-populated parameter structure. + """ + + def check_output_mode(self, StdI: StdIntList) -> None: + """Check and validate the output mode setting.""" + from .writer.common_writer import check_output_mode + check_output_mode(StdI) + + def write_green(self, StdI: StdIntList) -> None: + """Write Green's function definition files. + + Default writes both greenone.def and greentwo.def. + Override to write only a subset. + + Parameters + ---------- + StdI : StdIntList + The fully-populated parameter structure. + """ + from .writer.common_writer import print_1_green, print_2_green + print_1_green(StdI) + print_2_green(StdI) + + def write_namelist(self, StdI: StdIntList) -> None: + """Write namelist.def.""" + from .writer.common_writer import print_namelist + print_namelist(StdI) + + # ------------------------------------------------------------------ + # Optional lifecycle hooks + # ------------------------------------------------------------------ + + def post_lattice(self, StdI: StdIntList) -> None: + """Optional hook called after lattice construction. + + Override this to perform solver-specific post-lattice processing + (e.g. computing LargeValue for HPhi, running Boost). + The default implementation does nothing. + + Parameters + ---------- + StdI : StdIntList + The parameter structure (modified in place). + """ + + def init_fields(self, StdI: StdIntList) -> None: + """Optional hook to initialise solver-specific attributes on StdI. + + Override this to set up dynamic attributes that don't exist on the + base :class:`StdIntList` dataclass. The default implementation + does nothing. + + Parameters + ---------- + StdI : StdIntList + The parameter structure (modified in place). + """ + + +# --------------------------------------------------------------------------- +# Plugin registry +# --------------------------------------------------------------------------- + +_plugins: dict[str, SolverPlugin] = {} + + +def register(plugin: SolverPlugin) -> None: + """Register a solver plugin. + + Parameters + ---------- + plugin : SolverPlugin + The plugin instance to register. + + Raises + ------ + ValueError + If a plugin with the same name is already registered. + """ + if plugin.name in _plugins: + raise ValueError( + f"Solver plugin {plugin.name!r} is already registered" + ) + _plugins[plugin.name] = plugin + + +def get_plugin(name: str) -> SolverPlugin: + """Retrieve a registered solver plugin by name. + + Parameters + ---------- + name : str + The solver name (e.g. ``"HPhi"``). + + Returns + ------- + SolverPlugin + The registered plugin instance. + + Raises + ------ + KeyError + If no plugin with that name is registered. + """ + if name not in _plugins: + _discover_plugins() + if name not in _plugins: + raise KeyError( + f"No solver plugin registered for {name!r}. " + f"Available: {', '.join(_plugins) or '(none)'}" + ) + return _plugins[name] + + +def _discover_plugins() -> None: + """Load built-in plugins by importing the solvers package. + + This function is called lazily on first access to ensure plugins + are registered even if the solvers package hasn't been imported yet. + """ + try: + import stdface.solvers # noqa: F401 — triggers auto-registration + except ImportError: + # Solvers package not available (e.g., minimal install, missing + # dependencies, or corrupted package). Plugin lookup will fail + # with a helpful KeyError listing available plugins. + pass diff --git a/python/stdface/solvers/__init__.py b/python/stdface/solvers/__init__.py new file mode 100644 index 0000000..6e943ac --- /dev/null +++ b/python/stdface/solvers/__init__.py @@ -0,0 +1,7 @@ +"""Built-in solver plugins for StdFace. + +Importing this package auto-registers all built-in solver plugins. +""" +from __future__ import annotations + +from . import hphi, mvmc, uhf, hwave # noqa: F401 — auto-register plugins diff --git a/python/stdface/solvers/hphi/__init__.py b/python/stdface/solvers/hphi/__init__.py new file mode 100644 index 0000000..646ae6f --- /dev/null +++ b/python/stdface/solvers/hphi/__init__.py @@ -0,0 +1,9 @@ +"""HPhi solver plugin package. + +Importing this package auto-registers the HPhi plugin. +""" +from __future__ import annotations + +from ._plugin import HPhiPlugin # noqa: F401 — public API + +__all__ = ["HPhiPlugin"] diff --git a/python/stdface/solvers/hphi/_plugin.py b/python/stdface/solvers/hphi/_plugin.py new file mode 100644 index 0000000..dd878e6 --- /dev/null +++ b/python/stdface/solvers/hphi/_plugin.py @@ -0,0 +1,143 @@ +"""HPhi solver plugin. + +Encapsulates all HPhi-specific keyword parsing, field reset tables, +post-lattice hooks (LargeValue, Boost), and Expert-mode file writing. +""" +from __future__ import annotations + +from ...plugin import SolverPlugin, register +from ...core.stdface_vals import StdIntList, SolverType, MethodType, NaN_i, NaN_d +from ...core.keyword_parser import ( + store_with_check_dup_s, store_with_check_dup_sl, + store_with_check_dup_i, store_with_check_dup_d, +) +from .writer import large_value, print_calc_mod, print_excitation, print_pump + + +class HPhiPlugin(SolverPlugin): + """Plugin for the HPhi exact-diagonalisation / Lanczos solver.""" + + @property + def name(self) -> str: + return SolverType.HPhi + + @property + def keyword_table(self) -> dict[str, tuple]: + return _HPHI_KEYWORDS + + @property + def reset_scalars(self) -> list[tuple[str, object]]: + return _RESET_SCALARS + + @property + def reset_arrays(self) -> list[tuple[str, object]]: + return _RESET_ARRAYS + + def post_lattice(self, StdI: StdIntList) -> None: + """Compute LargeValue and optionally run Boost builder.""" + from ...lattice import get_lattice + from ...writer.common_writer import unsupported_system + + large_value(StdI) + + if StdI.lBoost == 1: + try: + lattice_plugin = get_lattice(StdI.lattice) + except KeyError: + unsupported_system(StdI.model, StdI.lattice) + else: + lattice_plugin.boost(StdI) + + def write_solver_specific(self, StdI: StdIntList) -> None: + """Write HPhi-specific files (excitation, pump, calcmod).""" + print_excitation(StdI) + if StdI.method == MethodType.TIME_EVOLUTION: + print_pump(StdI) + print_calc_mod(StdI) + + +# ----------------------------------------------------------------------- +# Keyword table +# ----------------------------------------------------------------------- + +_HPHI_KEYWORDS: dict[str, tuple] = { + "calcspec": (store_with_check_dup_sl, "CalcSpec"), + "exct": (store_with_check_dup_i, "exct"), + "eigenvecio": (store_with_check_dup_sl, "EigenVecIO"), + "expandcoef": (store_with_check_dup_i, "ExpandCoef"), + "expecinterval": (store_with_check_dup_i, "ExpecInterval"), + "cdatafilehead": (store_with_check_dup_s, "CDataFileHead"), + "dt": (store_with_check_dup_d, "dt"), + "flgtemp": (store_with_check_dup_i, "FlgTemp"), + "freq": (store_with_check_dup_d, "freq"), + "hamio": (store_with_check_dup_sl, "HamIO"), + "initialvectype": (store_with_check_dup_sl, "InitialVecType"), + "initial_iv": (store_with_check_dup_i, "initial_iv"), + "lanczoseps": (store_with_check_dup_i, "LanczosEps"), + "lanczostarget": (store_with_check_dup_i, "LanczosTarget"), + "lanczos_max": (store_with_check_dup_i, "Lanczos_max"), + "largevalue": (store_with_check_dup_d, "LargeValue"), + "method": (store_with_check_dup_sl, "method"), + "nomega": (store_with_check_dup_i, "Nomega"), + "numave": (store_with_check_dup_i, "NumAve"), + "nvec": (store_with_check_dup_i, "nvec"), + "omegamax": (store_with_check_dup_d, "OmegaMax"), + "omegamin": (store_with_check_dup_d, "OmegaMin"), + "omegaorg": (store_with_check_dup_d, "OmegaOrg"), + "omegaim": (store_with_check_dup_d, "OmegaIm"), + "outputexcitedvec": (store_with_check_dup_sl, "OutputExVec"), + "pumptype": (store_with_check_dup_sl, "PumpType"), + "restart": (store_with_check_dup_sl, "Restart"), + "spectrumqh": (store_with_check_dup_d, "SpectrumQ", 2, float), + "spectrumql": (store_with_check_dup_d, "SpectrumQ", 1, float), + "spectrumqw": (store_with_check_dup_d, "SpectrumQ", 0, float), + "spectrumtype": (store_with_check_dup_sl, "SpectrumType"), + "tdump": (store_with_check_dup_d, "tdump"), + "tshift": (store_with_check_dup_d, "tshift"), + "uquench": (store_with_check_dup_d, "Uquench"), + "vecpoth": (store_with_check_dup_d, "VecPot", 2, float), + "vecpotl": (store_with_check_dup_d, "VecPot", 1, float), + "vecpotw": (store_with_check_dup_d, "VecPot", 0, float), + "2s": (store_with_check_dup_i, "S2"), + "ngpu": (store_with_check_dup_i, "NGPU"), + "scalapack": (store_with_check_dup_i, "Scalapack"), +} + +# ----------------------------------------------------------------------- +# Reset tables +# ----------------------------------------------------------------------- + +_RESET_SCALARS: list[tuple[str, object]] = [ + ("LargeValue", NaN_d), + ("OmegaMax", NaN_d), + ("OmegaMin", NaN_d), + ("OmegaOrg", NaN_d), + ("OmegaIm", NaN_d), + ("Nomega", NaN_i), + ("FlgTemp", 1), + ("Lanczos_max", NaN_i), + ("initial_iv", NaN_i), + ("nvec", NaN_i), + ("exct", NaN_i), + ("LanczosEps", NaN_i), + ("LanczosTarget", NaN_i), + ("NumAve", NaN_i), + ("ExpecInterval", NaN_i), + ("dt", NaN_d), + ("tdump", NaN_d), + ("tshift", NaN_d), + ("freq", NaN_d), + ("Uquench", NaN_d), + ("ExpandCoef", NaN_i), + ("NGPU", NaN_i), + ("Scalapack", NaN_i), +] + +_RESET_ARRAYS: list[tuple[str, object]] = [ + ("SpectrumQ", NaN_d), + ("VecPot", NaN_d), +] + + +# Auto-register on import +register(HPhiPlugin()) diff --git a/python/writer/hphi_writer.py b/python/stdface/solvers/hphi/writer.py similarity index 91% rename from python/writer/hphi_writer.py rename to python/stdface/solvers/hphi/writer.py index 105d42c..e0a97b8 100644 --- a/python/writer/hphi_writer.py +++ b/python/stdface/solvers/hphi/writer.py @@ -35,9 +35,9 @@ import numpy as np -from stdface_vals import StdIntList, ModelType, MethodType, NaN_i, UNSET_STRING, AMPLITUDE_EPS -from param_check import exit_program, print_val_d, print_val_i -from .common_writer import _merge_duplicate_terms +from ...core.stdface_vals import StdIntList, ModelType, MethodType, NaN_i, UNSET_STRING, AMPLITUDE_EPS +from ...core.param_check import exit_program, print_val_d, print_val_i +from ...writer.common_writer import _merge_duplicate_terms # --------------------------------------------------------------------------- # String → integer dispatch tables for calcmod.def @@ -715,38 +715,40 @@ def _write_excitation_file( Imaginary parts of Fourier coefficients, length ``nsite``. """ if StdI.SpectrumBody == 1: - with open("single.def", "w") as fp: - fp.write("=============================================\n") - if StdI.model == ModelType.KONDO: - fp.write(f"NSingle {StdI.nsite // 2 * NumOp}\n") - else: - fp.write(f"NSingle {StdI.nsite * NumOp}\n") - fp.write("=============================================\n") - fp.write("============== Single Excitation ============\n") - fp.write("=============================================\n") - if StdI.model == ModelType.KONDO: - for isite in range(StdI.nsite // 2, StdI.nsite): - fp.write(f"{isite} {spin[0][0]} 0 " + lines = ["=============================================\n"] + if StdI.model == ModelType.KONDO: + lines.append(f"NSingle {StdI.nsite // 2 * NumOp}\n") + else: + lines.append(f"NSingle {StdI.nsite * NumOp}\n") + lines.append("=============================================\n") + lines.append("============== Single Excitation ============\n") + lines.append("=============================================\n") + if StdI.model == ModelType.KONDO: + for isite in range(StdI.nsite // 2, StdI.nsite): + lines.append(f"{isite} {spin[0][0]} 0 " f"{fourier_r[isite] * coef[0]:25.15f} " f"{fourier_i[isite] * coef[0]:25.15f}\n") - else: - for isite in range(StdI.nsite): - fp.write(f"{isite} {spin[0][0]} 0 " + else: + for isite in range(StdI.nsite): + lines.append(f"{isite} {spin[0][0]} 0 " f"{fourier_r[isite] * coef[0]:25.15f} " f"{fourier_i[isite] * coef[0]:25.15f}\n") + with open("single.def", "w") as fp: + fp.write("".join(lines)) print(" single.def is written.\n") else: - with open("pair.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"NPair {StdI.nsite * NumOp}\n") - fp.write("=============================================\n") - fp.write("=============== Pair Excitation =============\n") - fp.write("=============================================\n") - for isite in range(StdI.nsite): - for ispin in range(NumOp): - fp.write(f"{isite} {spin[ispin][0]} {isite} {spin[ispin][1]} 1 " + lines = ["=============================================\n", + f"NPair {StdI.nsite * NumOp}\n", + "=============================================\n", + "=============== Pair Excitation =============\n", + "=============================================\n"] + for isite in range(StdI.nsite): + for ispin in range(NumOp): + lines.append(f"{isite} {spin[ispin][0]} {isite} {spin[ispin][1]} 1 " f"{fourier_r[isite] * coef[ispin]:25.15f} " f"{fourier_i[isite] * coef[ispin]:25.15f}\n") + with open("pair.def", "w") as fp: + fp.write("".join(lines)) print(" pair.def is written.\n") @@ -1022,13 +1024,14 @@ def vector_potential(StdI: StdIntList) -> None: # Write potential.dat for one-body pump # ------------------------------------------------------------------ if StdI.PumpBody == 1: - with open("potential.dat", "w") as fp: - fp.write("# Time A_W A_L A_H E_W E_L E_H\n") - for it in range(StdI.Lanczos_max): - time = StdI.dt * float(it) - fp.write(f"{time:f} " + lines = ["# Time A_W A_L A_H E_W E_L E_H\n"] + for it in range(StdI.Lanczos_max): + time = StdI.dt * float(it) + lines.append(f"{time:f} " f"{StdI.At[it][0]:f} {StdI.At[it][1]:f} {StdI.At[it][2]:f} " f"{Et[it][0]:f} {Et[it][1]:f} {Et[it][2]:f}\n") + with open("potential.dat", "w") as fp: + fp.write("".join(lines)) def print_pump(StdI: StdIntList) -> None: @@ -1057,44 +1060,46 @@ def print_pump(StdI: StdIntList) -> None: - ``Uquench`` -- quench interaction strength. """ if StdI.PumpBody == 1: - with open("teone.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"AllTimeStep {StdI.Lanczos_max}\n") - fp.write("=============================================\n") - fp.write("========= OneBody Time Evolution ==========\n") - fp.write("=============================================\n") - - for it in range(StdI.Lanczos_max): - npump0 = _merge_duplicate_terms( - StdI.pumpindx[it], StdI.pump[it], StdI.npump[it]) - - fp.write(f"{StdI.dt * float(it):f} {npump0}\n") - - for ipump in range(StdI.npump[it]): - val = StdI.pump[it][ipump] - if abs(val) <= AMPLITUDE_EPS: - continue - i0, s0, i1, s1 = StdI.pumpindx[it][ipump] - fp.write( - f"{i0:5d} {s0:5d} {i1:5d} {s1:5d} " - f"{val.real:25.15f} {val.imag:25.15f}\n" - ) + lines = ["=============================================\n", + f"AllTimeStep {StdI.Lanczos_max}\n", + "=============================================\n", + "========= OneBody Time Evolution ==========\n", + "=============================================\n"] + for it in range(StdI.Lanczos_max): + npump0 = _merge_duplicate_terms( + StdI.pumpindx[it], StdI.pump[it], StdI.npump[it]) + + lines.append(f"{StdI.dt * float(it):f} {npump0}\n") + + for ipump in range(StdI.npump[it]): + val = StdI.pump[it][ipump] + if abs(val) <= AMPLITUDE_EPS: + continue + i0, s0, i1, s1 = StdI.pumpindx[it][ipump] + lines.append( + f"{i0:5d} {s0:5d} {i1:5d} {s1:5d} " + f"{val.real:25.15f} {val.imag:25.15f}\n" + ) + + with open("teone.def", "w") as fp: + fp.write("".join(lines)) print(" teone.def is written.\n") else: - with open("tetwo.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"AllTimeStep {StdI.Lanczos_max}\n") - fp.write("=============================================\n") - fp.write("========== TwoBody Time Evolution ===========\n") - fp.write("=============================================\n") - - for it in range(StdI.Lanczos_max): - fp.write(f"{StdI.dt * float(it):f} {StdI.nsite}\n") - for isite in range(StdI.nsite): - fp.write(f"{isite:5d} {0:5d} {isite:5d} {0:5d} " + lines = ["=============================================\n", + f"AllTimeStep {StdI.Lanczos_max}\n", + "=============================================\n", + "========== TwoBody Time Evolution ===========\n", + "=============================================\n"] + + for it in range(StdI.Lanczos_max): + lines.append(f"{StdI.dt * float(it):f} {StdI.nsite}\n") + for isite in range(StdI.nsite): + lines.append(f"{isite:5d} {0:5d} {isite:5d} {0:5d} " f"{isite:5d} {1:5d} {isite:5d} {1:5d} " f"{StdI.Uquench:25.15f} {0.0:25.15f}\n") + with open("tetwo.def", "w") as fp: + fp.write("".join(lines)) print(" tetwo.def is written.\n") diff --git a/python/stdface/solvers/hwave/__init__.py b/python/stdface/solvers/hwave/__init__.py new file mode 100644 index 0000000..964177a --- /dev/null +++ b/python/stdface/solvers/hwave/__init__.py @@ -0,0 +1,9 @@ +"""H-wave solver plugin package. + +Importing this package auto-registers the H-wave plugin. +""" +from __future__ import annotations + +from ._plugin import HWavePlugin # noqa: F401 — public API + +__all__ = ["HWavePlugin"] diff --git a/python/stdface/solvers/hwave/_plugin.py b/python/stdface/solvers/hwave/_plugin.py new file mode 100644 index 0000000..0a2754a --- /dev/null +++ b/python/stdface/solvers/hwave/_plugin.py @@ -0,0 +1,114 @@ +"""H-wave solver plugin. + +Encapsulates all H-wave-specific keyword parsing, field reset tables, +and Expert-mode file writing. +""" +from __future__ import annotations + +from ...plugin import SolverPlugin, register +from ...core.stdface_vals import StdIntList, SolverType, NaN_i, NaN_d +from ...core.keyword_parser import ( + store_with_check_dup_i, store_with_check_dup_d, store_with_check_dup_sl, + _grid3x3_keywords, +) +from .export_wannier90 import export_geometry, export_interaction + + +class HWavePlugin(SolverPlugin): + """Plugin for the H-wave solver.""" + + @property + def name(self) -> str: + return SolverType.HWAVE + + @property + def keyword_table(self) -> dict[str, tuple]: + return _HWAVE_KEYWORDS + + @property + def reset_scalars(self) -> list[tuple[str, object]]: + return _RESET_SCALARS + + @property + def reset_arrays(self) -> list[tuple[str, object]]: + return _RESET_ARRAYS + + def write(self, StdI: StdIntList) -> None: + """Write H-wave output files. + + Overrides the template method entirely because H-wave has two + completely different output modes (uhfr vs wannier90 export). + """ + from ...writer.common_writer import ( + print_trans, print_1_green, check_output_mode, check_mod_para, + ) + from ...writer.interaction_writer import print_interactions + + if StdI.calcmode == "uhfr": + print_trans(StdI) + print_interactions(StdI) + check_mod_para(StdI) + check_output_mode(StdI) + print_1_green(StdI) + else: + export_geometry(StdI) + export_interaction(StdI) + + +# ----------------------------------------------------------------------- +# Keyword table +# ----------------------------------------------------------------------- + +_BOXSUB_KEYWORDS: dict[str, tuple] = _grid3x3_keywords( + "{a}{c}sub", "boxsub", store_with_check_dup_i, int +) + +_UHF_BASE_KEYWORDS: dict[str, tuple] = { + "iteration_max": (store_with_check_dup_i, "Iteration_max"), + "rndseed": (store_with_check_dup_i, "RndSeed"), + "nmptrans": (store_with_check_dup_i, "NMPTrans"), + **_BOXSUB_KEYWORDS, + "hsub": (store_with_check_dup_i, "Hsub"), + "lsub": (store_with_check_dup_i, "Lsub"), + "wsub": (store_with_check_dup_i, "Wsub"), + "eps": (store_with_check_dup_i, "eps"), + "epsslater": (store_with_check_dup_i, "eps_slater"), + "mix": (store_with_check_dup_d, "mix"), +} + +_HWAVE_KEYWORDS: dict[str, tuple] = { + **_UHF_BASE_KEYWORDS, + "calcmode": (store_with_check_dup_sl, "calcmode"), + "fileprefix": (store_with_check_dup_sl, "fileprefix"), + "exportall": (store_with_check_dup_i, "export_all"), + "lattice_gp": (store_with_check_dup_i, "lattice_gp"), +} + +# ----------------------------------------------------------------------- +# Reset tables +# ----------------------------------------------------------------------- + +_UHF_BASE_SCALARS: list[tuple[str, object]] = [ + ("NMPTrans", NaN_i), + ("RndSeed", NaN_i), + ("mix", NaN_d), + ("eps", NaN_i), + ("eps_slater", NaN_i), + ("Iteration_max", NaN_i), + ("Hsub", NaN_i), + ("Lsub", NaN_i), + ("Wsub", NaN_i), +] + +_RESET_SCALARS: list[tuple[str, object]] = _UHF_BASE_SCALARS + [ + ("export_all", NaN_i), + ("lattice_gp", NaN_i), +] + +_RESET_ARRAYS: list[tuple[str, object]] = [ + ("boxsub", NaN_i), +] + + +# Auto-register on import +register(HWavePlugin()) diff --git a/python/writer/export_wannier90.py b/python/stdface/solvers/hwave/export_wannier90.py similarity index 99% rename from python/writer/export_wannier90.py rename to python/stdface/solvers/hwave/export_wannier90.py index a298585..2e06c61 100644 --- a/python/writer/export_wannier90.py +++ b/python/stdface/solvers/hwave/export_wannier90.py @@ -27,9 +27,9 @@ import numpy as np -from stdface_vals import StdIntList, NaN_i, UNSET_STRING -from param_check import exit_program -from lattice.site_util import _cell_vector +from ...core.stdface_vals import StdIntList, NaN_i, UNSET_STRING +from ...core.param_check import exit_program +from ...lattice.site_util import _cell_vector # ----------------------------------------------------------------------- # Module-level constants diff --git a/python/stdface/solvers/mvmc/__init__.py b/python/stdface/solvers/mvmc/__init__.py new file mode 100644 index 0000000..39ffee6 --- /dev/null +++ b/python/stdface/solvers/mvmc/__init__.py @@ -0,0 +1,9 @@ +"""mVMC solver plugin package. + +Importing this package auto-registers the mVMC plugin. +""" +from __future__ import annotations + +from ._plugin import MVMCPlugin # noqa: F401 — public API + +__all__ = ["MVMCPlugin"] diff --git a/python/stdface/solvers/mvmc/_plugin.py b/python/stdface/solvers/mvmc/_plugin.py new file mode 100644 index 0000000..9fa7916 --- /dev/null +++ b/python/stdface/solvers/mvmc/_plugin.py @@ -0,0 +1,127 @@ +"""mVMC solver plugin. + +Encapsulates all mVMC-specific keyword parsing, field reset tables, +and Expert-mode file writing. +""" +from __future__ import annotations + +from ...plugin import SolverPlugin, register +from ...core.stdface_vals import StdIntList, SolverType, NaN_i, NaN_d +from ...core.keyword_parser import ( + store_with_check_dup_s, store_with_check_dup_i, store_with_check_dup_d, + _grid3x3_keywords, +) +from ...core.param_check import print_val_i +from .variational import generate_orb, proj, print_jastrow +from .writer import print_orb, print_orb_para, print_gutzwiller + + +class MVMCPlugin(SolverPlugin): + """Plugin for the mVMC variational Monte Carlo solver.""" + + @property + def name(self) -> str: + return SolverType.mVMC + + @property + def keyword_table(self) -> dict[str, tuple]: + return _MVMC_KEYWORDS + + @property + def reset_scalars(self) -> list[tuple[str, object]]: + return _RESET_SCALARS + + @property + def reset_arrays(self) -> list[tuple[str, object]]: + return _RESET_ARRAYS + + def write_solver_specific(self, StdI: StdIntList) -> None: + """Write mVMC-specific variational parameter files.""" + if StdI.lGC == 0 and (StdI.Sz2 == 0 or StdI.Sz2 == NaN_i): + StdI.ComplexType = print_val_i("ComplexType", StdI.ComplexType, 0) + else: + StdI.ComplexType = print_val_i("ComplexType", StdI.ComplexType, 1) + + generate_orb(StdI) + proj(StdI) + print_jastrow(StdI) + if StdI.lGC == 1 or (StdI.Sz2 != 0 and StdI.Sz2 != NaN_i): + print_orb_para(StdI) + print_gutzwiller(StdI) + print_orb(StdI) + + +# ----------------------------------------------------------------------- +# Keyword table +# ----------------------------------------------------------------------- + +_BOXSUB_KEYWORDS: dict[str, tuple] = _grid3x3_keywords( + "{a}{c}sub", "boxsub", store_with_check_dup_i, int +) + +_MVMC_KEYWORDS: dict[str, tuple] = { + **_BOXSUB_KEYWORDS, + "complextype": (store_with_check_dup_i, "ComplexType"), + "cparafilehead": (store_with_check_dup_s, "CParaFileHead"), + "dsroptredcut": (store_with_check_dup_d, "DSROptRedCut"), + "dsroptstadel": (store_with_check_dup_d, "DSROptStaDel"), + "dsroptstepdt": (store_with_check_dup_d, "DSROptStepDt"), + "hsub": (store_with_check_dup_i, "Hsub"), + "lsub": (store_with_check_dup_i, "Lsub"), + "nvmccalmode": (store_with_check_dup_i, "NVMCCalMode"), + "ndataidxstart": (store_with_check_dup_i, "NDataIdxStart"), + "ndataqtysmp": (store_with_check_dup_i, "NDataQtySmp"), + "nlanczosmode": (store_with_check_dup_i, "NLanczosMode"), + "nmptrans": (store_with_check_dup_i, "NMPTrans"), + "nspgaussleg": (store_with_check_dup_i, "NSPGaussLeg"), + "nsplitsize": (store_with_check_dup_i, "NSplitSize"), + "nspstot": (store_with_check_dup_i, "NSPStot"), + "nsroptitrsmp": (store_with_check_dup_i, "NSROptItrSmp"), + "nsroptitrstep": (store_with_check_dup_i, "NSROptItrStep"), + "nstore": (store_with_check_dup_i, "NStore"), + "nsrcg": (store_with_check_dup_i, "NSRCG"), + "nvmcinterval": (store_with_check_dup_i, "NVMCInterval"), + "nvmcsample": (store_with_check_dup_i, "NVMCSample"), + "nvmcwarmup": (store_with_check_dup_i, "NVMCWarmUp"), + "rndseed": (store_with_check_dup_i, "RndSeed"), + "wsub": (store_with_check_dup_i, "Wsub"), +} + +# ----------------------------------------------------------------------- +# Reset tables +# ----------------------------------------------------------------------- + +_RESET_SCALARS: list[tuple[str, object]] = [ + ("NVMCCalMode", NaN_i), + ("NLanczosMode", NaN_i), + ("NDataIdxStart", NaN_i), + ("NDataQtySmp", NaN_i), + ("NSPGaussLeg", NaN_i), + ("NSPStot", NaN_i), + ("NMPTrans", NaN_i), + ("NSROptItrStep", NaN_i), + ("NSROptItrSmp", NaN_i), + ("DSROptRedCut", NaN_d), + ("DSROptStaDel", NaN_d), + ("DSROptStepDt", NaN_d), + ("NVMCWarmUp", NaN_i), + ("NVMCInterval", NaN_i), + ("NVMCSample", NaN_i), + ("NExUpdatePath", NaN_i), + ("RndSeed", NaN_i), + ("NSplitSize", NaN_i), + ("NStore", NaN_i), + ("NSRCG", NaN_i), + ("ComplexType", NaN_i), + ("Hsub", NaN_i), + ("Lsub", NaN_i), + ("Wsub", NaN_i), +] + +_RESET_ARRAYS: list[tuple[str, object]] = [ + ("boxsub", NaN_i), +] + + +# Auto-register on import +register(MVMCPlugin()) diff --git a/python/writer/mvmc_variational.py b/python/stdface/solvers/mvmc/variational.py similarity index 84% rename from python/writer/mvmc_variational.py rename to python/stdface/solvers/mvmc/variational.py index 1882233..9e7b5fc 100644 --- a/python/writer/mvmc_variational.py +++ b/python/stdface/solvers/mvmc/variational.py @@ -25,9 +25,9 @@ import numpy as np -from stdface_vals import StdIntList, ModelType, NaN_i -from param_check import exit_program -from lattice.site_util import ( +from ...core.stdface_vals import StdIntList, ModelType, NaN_i +from ...core.param_check import exit_program +from ...lattice.site_util import ( _cell_vector, _fold_to_cell, _fold_site, _find_cell_index, _validate_box_params, _det_and_cofactor, find_site, ) @@ -161,17 +161,20 @@ def proj(StdI: StdIntList) -> None: StdI.NSym += 1 with open("qptransidx.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"NQPTrans {StdI.NSym:10d}\n") - fp.write("=============================================\n") - fp.write("======== TrIdx_TrWeight_and_TrIdx_i_xi ======\n") - fp.write("=============================================\n") + lines = [ + "=============================================\n", + f"NQPTrans {StdI.NSym:10d}\n", + "=============================================\n", + "======== TrIdx_TrWeight_and_TrIdx_i_xi ======\n", + "=============================================\n", + ] for iSym in range(StdI.NSym): - fp.write(f"{iSym} {1.0:10.5f}\n") + lines.append(f"{iSym} {1.0:10.5f}\n") for iSym in range(StdI.NSym): for jsite in range(StdI.nsite): a = _parity_sign(Anti[iSym][jsite]) - fp.write(f"{iSym:5d} {jsite:5d} {Sym[iSym][jsite]:5d} {a:5d}\n") + lines.append(f"{iSym:5d} {jsite:5d} {Sym[iSym][jsite]:5d} {a:5d}\n") + fp.write("".join(lines)) print(" qptransidx.def is written.") @@ -276,16 +279,23 @@ def generate_orb(StdI: StdIntList) -> None: StdI.AntiOrb = np.zeros((StdI.nsite, StdI.nsite), dtype=int) CellDone = np.zeros((StdI.NCell, StdI.NCell), dtype=int) + # Pre-compute cell vectors and sector list outside loops + cell_vectors = [_cell_vector(StdI.Cell, k) for k in range(StdI.NCell)] + sectors = [(0, 0)] + if StdI.model == ModelType.KONDO: + half = StdI.nsite // 2 + sectors += [(half, 0), (0, half), (half, half)] + iOrb = 0 for iCell in range(StdI.NCell): - iCV = _cell_vector(StdI.Cell, iCell) + iCV = cell_vectors[iCell] nBox, iCellV = _fold_site_sub(StdI, iCV) nBox, iCellV = _fold_site(StdI, iCellV) iCell2 = _find_cell_index(StdI, iCellV) for jCell in range(StdI.NCell): - jCV = _cell_vector(StdI.Cell, jCell) + jCV = cell_vectors[jCell] jCellV = [jc + v - ic for jc, v, ic in zip(jCV, iCellV, iCV)] nBox, jCellV = _fold_site(StdI, jCellV) @@ -297,12 +307,6 @@ def generate_orb(StdI: StdIntList) -> None: anti_val = _parity_sign( _anti_period_dot(StdI.AntiPeriod, nBox_d)) - # Build list of (i_offset, j_offset) sector pairs - sectors = [(0, 0)] - if StdI.model == ModelType.KONDO: - half = StdI.nsite // 2 - sectors += [(half, 0), (0, half), (half, half)] - for isite in range(StdI.NsiteUC): for jsite in range(StdI.NsiteUC): for i_off, j_off in sectors: @@ -347,17 +351,24 @@ def _jastrow_momentum_projected( Jastrow : numpy.ndarray Final renumbered Jastrow index matrix. """ - # (1) Copy Orbital index - for isite in range(StdI.nsite): - for jsite in range(StdI.nsite): - Jastrow[isite, jsite] = StdI.Orb[isite, jsite] + # (1) Copy Orbital index (vectorised) + Jastrow[:, :] = StdI.Orb[:StdI.nsite, :StdI.nsite] + + # (2) Symmetrize — build reverse map in one pass, then process + # each orbital in ascending order to preserve overwrite semantics. + from collections import defaultdict + nsite = StdI.nsite + orb_positions: dict[int, list[tuple[int, int]]] = defaultdict(list) + for r in range(nsite): + for c in range(nsite): + v = int(Jastrow[r, c]) + if 0 <= v < StdI.NOrb: + orb_positions[v].append((r, c)) - # (2) Symmetrize for iorb in range(StdI.NOrb): - for isite in range(StdI.nsite): - for jsite in range(StdI.nsite): - if Jastrow[isite, jsite] == iorb: - Jastrow[jsite, isite] = Jastrow[isite, jsite] + for r, c in orb_positions.get(iorb, ()): + if Jastrow[r, c] == iorb: + Jastrow[c, r] = iorb # (3) Exclude local-spin sites and renumber NJastrow = 0 if StdI.model == ModelType.HUBBARD else -1 @@ -371,8 +382,7 @@ def _jastrow_momentum_projected( if Jastrow[isite, jsite] >= 0: iJastrow = Jastrow[isite, jsite] NJastrow -= 1 - mask = (Jastrow == iJastrow) - Jastrow[mask] = NJastrow + Jastrow[Jastrow == iJastrow] = NJastrow NJastrow = -NJastrow Jastrow = -1 - Jastrow @@ -412,8 +422,11 @@ def _jastrow_global_optimization( Jastrow[:half, :] = 0 NJastrow += 1 + # Pre-compute cell vectors to avoid repeated _cell_vector calls + cell_vectors = [_cell_vector(StdI.Cell, k) for k in range(StdI.NCell)] + for dCell in range(StdI.NCell): - dCV = _cell_vector(StdI.Cell, dCell) + dCV = cell_vectors[dCell] isite, jsite, Cphase, dR_arr = find_site( StdI, 0, 0, 0, -dCV[0], -dCV[1], -dCV[2], 0, 0) if StdI.model == ModelType.KONDO: @@ -422,16 +435,17 @@ def _jastrow_global_optimization( if iCell_j < dCell: continue reversal = 1 if iCell_j == dCell else 0 + dCV_is_zero = (dCV == [0, 0, 0]) for isiteUC in range(StdI.NsiteUC): for jsiteUC in range(StdI.NsiteUC): if reversal == 1 and jsiteUC > isiteUC: continue - if isiteUC == jsiteUC and dCV == [0, 0, 0]: + if isiteUC == jsiteUC and dCV_is_zero: continue for iCell_idx in range(StdI.NCell): - iCV = _cell_vector(StdI.Cell, iCell_idx) + iCV = cell_vectors[iCell_idx] i_s, j_s, _, _ = find_site( StdI, iCV[0], iCV[1], iCV[2], @@ -465,21 +479,24 @@ def print_jastrow(StdI: StdIntList) -> None: NJastrow = _jastrow_global_optimization(StdI, Jastrow) with open("jastrowidx.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"NJastrowIdx {NJastrow:10d}\n") - fp.write(f"ComplexType {0:10d}\n") - fp.write("=============================================\n") - fp.write("=============================================\n") + lines = [ + "=============================================\n", + f"NJastrowIdx {NJastrow:10d}\n", + f"ComplexType {0:10d}\n", + "=============================================\n", + "=============================================\n", + ] for isite in range(StdI.nsite): for jsite in range(StdI.nsite): if isite == jsite: continue - fp.write(f"{isite:5d} {jsite:5d} {Jastrow[isite, jsite]:5d}\n") + lines.append(f"{isite:5d} {jsite:5d} {Jastrow[isite, jsite]:5d}\n") for iJastrow in range(NJastrow): if StdI.model == ModelType.HUBBARD or iJastrow > 0: - fp.write(f"{iJastrow:5d} {1:5d}\n") + lines.append(f"{iJastrow:5d} {1:5d}\n") else: - fp.write(f"{iJastrow:5d} {0:5d}\n") + lines.append(f"{iJastrow:5d} {0:5d}\n") + fp.write("".join(lines)) print(" jastrowidx.def is written.") diff --git a/python/writer/mvmc_writer.py b/python/stdface/solvers/mvmc/writer.py similarity index 72% rename from python/writer/mvmc_writer.py rename to python/stdface/solvers/mvmc/writer.py index a1f4a80..fa6c84e 100644 --- a/python/writer/mvmc_writer.py +++ b/python/stdface/solvers/mvmc/writer.py @@ -28,7 +28,7 @@ from __future__ import annotations -from stdface_vals import StdIntList, ModelType, NaN_i +from ...core.stdface_vals import StdIntList, ModelType, NaN_i def _has_anti_period(StdI: StdIntList) -> bool: @@ -70,26 +70,29 @@ def print_orb(StdI: StdIntList) -> None: Translated from the C function ``PrintOrb()`` in ``StdFace_main.c``. """ with open("orbitalidx.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"NOrbitalIdx {StdI.NOrb:10d}\n") - fp.write(f"ComplexType {StdI.ComplexType:10d}\n") - fp.write("=============================================\n") - fp.write("=============================================\n") + lines = [ + "=============================================\n", + f"NOrbitalIdx {StdI.NOrb:10d}\n", + f"ComplexType {StdI.ComplexType:10d}\n", + "=============================================\n", + "=============================================\n", + ] has_anti = _has_anti_period(StdI) for isite in range(StdI.nsite): for jsite in range(StdI.nsite): if has_anti: - fp.write(f"{isite:5d} {jsite:5d} " - f"{StdI.Orb[isite][jsite]:5d} " - f"{StdI.AntiOrb[isite][jsite]:5d}\n") + lines.append(f"{isite:5d} {jsite:5d} " + f"{StdI.Orb[isite][jsite]:5d} " + f"{StdI.AntiOrb[isite][jsite]:5d}\n") else: - fp.write(f"{isite:5d} {jsite:5d} " - f"{StdI.Orb[isite][jsite]:5d}\n") + lines.append(f"{isite:5d} {jsite:5d} " + f"{StdI.Orb[isite][jsite]:5d}\n") for iOrb in range(StdI.NOrb): - fp.write(f"{iOrb:5d} {1:5d}\n") + lines.append(f"{iOrb:5d} {1:5d}\n") + fp.write("".join(lines)) print(" orbitalidx.def is written.") @@ -133,40 +136,49 @@ def _compute_parallel_orbitals( NOrbGC : int Number of unique parallel orbital indices. """ - # (1) Copy - OrbGC = [[0] * nsite for _ in range(nsite)] - reverse = [[0] * nsite for _ in range(nsite)] - for isite in range(nsite): - for jsite in range(nsite): - OrbGC[isite][jsite] = int(Orb[isite][jsite]) - reverse[isite][jsite] = int(AntiOrb[isite][jsite]) + import numpy as np + + # (1) Copy into numpy arrays + OrbGC = np.asarray(Orb, dtype=int).copy() + reverse = np.asarray(AntiOrb, dtype=int).copy() + + # (2) Symmetrise — process each orbital in ascending order. + # Build a reverse map (value -> list of positions) in a single + # pass, then iterate orbitals in order. This replaces NOrb + # full-matrix scans with one scan + targeted updates. + from collections import defaultdict + orb_positions: dict[int, list[tuple[int, int]]] = defaultdict(list) + for r in range(nsite): + for c in range(nsite): + v = int(OrbGC[r, c]) + if 0 <= v < NOrb: + orb_positions[v].append((r, c)) - # (2) Symmetrise for iorb in range(NOrb): - for isite in range(nsite): - for jsite in range(nsite): - if OrbGC[isite][jsite] == iorb: - OrbGC[jsite][isite] = OrbGC[isite][jsite] - reverse[jsite][isite] = -reverse[isite][jsite] - - # (3) Renumber -- lower triangle (isite > jsite) + for r, c in orb_positions.get(iorb, ()): + if OrbGC[r, c] == iorb: + OrbGC[c, r] = iorb + reverse[c, r] = -reverse[r, c] + + # (3) Renumber — lower triangle (isite > jsite). + # Replace each newly-seen positive orbital value with a negative + # temporary across the entire matrix (vectorised). NOrbGC = 0 for isite in range(nsite): for jsite in range(isite): - if OrbGC[isite][jsite] >= 0: - iOrbGC = OrbGC[isite][jsite] + if OrbGC[isite, jsite] >= 0: + iOrbGC = OrbGC[isite, jsite] NOrbGC -= 1 - for isite1 in range(nsite): - for jsite1 in range(nsite): - if OrbGC[isite1][jsite1] == iOrbGC: - OrbGC[isite1][jsite1] = NOrbGC + OrbGC[OrbGC == iOrbGC] = NOrbGC NOrbGC = -NOrbGC - for isite in range(nsite): - for jsite in range(nsite): - OrbGC[isite][jsite] = -1 - OrbGC[isite][jsite] + OrbGC = -1 - OrbGC + + # Convert back to list-of-lists for downstream compatibility + OrbGC_list = OrbGC.tolist() + reverse_list = reverse.tolist() - return OrbGC, reverse, NOrbGC + return OrbGC_list, reverse_list, NOrbGC def _write_orbitalidxpara( @@ -192,22 +204,23 @@ def _write_orbitalidxpara( Number of parallel orbital indices. """ with open("orbitalidxpara.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"NOrbitalIdx {NOrbGC:10d}\n") - fp.write(f"ComplexType {ComplexType:10d}\n") - fp.write("=============================================\n") - fp.write("=============================================\n") + lines = [ + "=============================================\n", + f"NOrbitalIdx {NOrbGC:10d}\n", + f"ComplexType {ComplexType:10d}\n", + "=============================================\n", + "=============================================\n", + ] for isite in range(nsite): - for jsite in range(nsite): - if isite >= jsite: - continue - fp.write(f"{isite:5d} {jsite:5d} " - f"{OrbGC[isite][jsite]:5d} " - f"{reverse[isite][jsite]:5d}\n") + for jsite in range(isite + 1, nsite): + lines.append(f"{isite:5d} {jsite:5d} " + f"{OrbGC[isite][jsite]:5d} " + f"{reverse[isite][jsite]:5d}\n") for iOrbGC in range(NOrbGC): - fp.write(f"{iOrbGC:5d} {1:5d}\n") + lines.append(f"{iOrbGC:5d} {1:5d}\n") + fp.write("".join(lines)) def _write_orbitalidxgen( @@ -234,40 +247,41 @@ def _write_orbitalidxgen( has_anti = _has_anti_period(StdI) with open("orbitalidxgen.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"NOrbitalIdx {StdI.NOrb + 2 * NOrbGC:10d}\n") - fp.write(f"ComplexType {StdI.ComplexType:10d}\n") - fp.write("=============================================\n") - fp.write("=============================================\n") + lines = [ + "=============================================\n", + f"NOrbitalIdx {StdI.NOrb + 2 * NOrbGC:10d}\n", + f"ComplexType {StdI.ComplexType:10d}\n", + "=============================================\n", + "=============================================\n", + ] # -- anti-parallel section -- for isite in range(nsite): for jsite in range(nsite): if has_anti: - fp.write(f"{isite:5d} 0 {jsite:5d} 1 " - f"{StdI.Orb[isite][jsite]:5d} " - f"{StdI.AntiOrb[isite][jsite]:5d}\n") + lines.append(f"{isite:5d} 0 {jsite:5d} 1 " + f"{StdI.Orb[isite][jsite]:5d} " + f"{StdI.AntiOrb[isite][jsite]:5d}\n") else: - fp.write(f"{isite:5d} 0 {jsite:5d} 1 " - f"{StdI.Orb[isite][jsite]:5d} {1:5d}\n") + lines.append(f"{isite:5d} 0 {jsite:5d} 1 " + f"{StdI.Orb[isite][jsite]:5d} {1:5d}\n") # -- parallel section (upper triangle) -- for isite in range(nsite): - for jsite in range(nsite): - if isite >= jsite: - continue - fp.write(f"{isite:5d} 0 {jsite:5d} 0 " - f"{OrbGC[isite][jsite] + StdI.NOrb:5d} " - f"{reverse[isite][jsite]:5d}\n") - fp.write(f"{isite:5d} 1 {jsite:5d} 1 " - f"{OrbGC[isite][jsite] + StdI.NOrb + NOrbGC:5d} " - f"{reverse[isite][jsite]:5d}\n") + for jsite in range(isite + 1, nsite): + lines.append(f"{isite:5d} 0 {jsite:5d} 0 " + f"{OrbGC[isite][jsite] + StdI.NOrb:5d} " + f"{reverse[isite][jsite]:5d}\n") + lines.append(f"{isite:5d} 1 {jsite:5d} 1 " + f"{OrbGC[isite][jsite] + StdI.NOrb + NOrbGC:5d} " + f"{reverse[isite][jsite]:5d}\n") for iOrbGC in range(StdI.NOrb): - fp.write(f"{iOrbGC:5d} {1:5d}\n") + lines.append(f"{iOrbGC:5d} {1:5d}\n") for iOrbGC in range(NOrbGC * 2): - fp.write(f"{iOrbGC + StdI.NOrb:5d} {1:5d}\n") + lines.append(f"{iOrbGC + StdI.NOrb:5d} {1:5d}\n") + fp.write("".join(lines)) def print_orb_para(StdI: StdIntList) -> None: @@ -416,18 +430,21 @@ def _write_gutzwiller_file( Per-site Gutzwiller index array. """ with open("gutzwilleridx.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"NGutzwillerIdx {NGutzwiller:10d}\n") - fp.write(f"ComplexType {0:10d}\n") - fp.write("=============================================\n") - fp.write("=============================================\n") + lines = [ + "=============================================\n", + f"NGutzwillerIdx {NGutzwiller:10d}\n", + f"ComplexType {0:10d}\n", + "=============================================\n", + "=============================================\n", + ] for isite in range(StdI.nsite): - fp.write(f"{isite:5d} {Gutz[isite]:5d}\n") + lines.append(f"{isite:5d} {Gutz[isite]:5d}\n") for iGutz in range(NGutzwiller): flag = int(StdI.model == ModelType.HUBBARD or iGutz > 0) - fp.write(f"{iGutz:5d} {flag:5d}\n") + lines.append(f"{iGutz:5d} {flag:5d}\n") + fp.write("".join(lines)) def print_gutzwiller(StdI: StdIntList) -> None: diff --git a/python/stdface/solvers/uhf/__init__.py b/python/stdface/solvers/uhf/__init__.py new file mode 100644 index 0000000..f325ed4 --- /dev/null +++ b/python/stdface/solvers/uhf/__init__.py @@ -0,0 +1,9 @@ +"""UHF solver plugin package. + +Importing this package auto-registers the UHF plugin. +""" +from __future__ import annotations + +from ._plugin import UHFPlugin # noqa: F401 — public API + +__all__ = ["UHFPlugin"] diff --git a/python/stdface/solvers/uhf/_plugin.py b/python/stdface/solvers/uhf/_plugin.py new file mode 100644 index 0000000..e72a360 --- /dev/null +++ b/python/stdface/solvers/uhf/_plugin.py @@ -0,0 +1,84 @@ +"""UHF solver plugin. + +Encapsulates all UHF-specific keyword parsing, field reset tables, +and Expert-mode file writing. +""" +from __future__ import annotations + +from ...plugin import SolverPlugin, register +from ...core.stdface_vals import StdIntList, SolverType, NaN_i, NaN_d +from ...core.keyword_parser import ( + store_with_check_dup_i, store_with_check_dup_d, + _grid3x3_keywords, +) + + +class UHFPlugin(SolverPlugin): + """Plugin for the UHF (unrestricted Hartree-Fock) solver.""" + + @property + def name(self) -> str: + return SolverType.UHF + + @property + def keyword_table(self) -> dict[str, tuple]: + return _UHF_KEYWORDS + + @property + def reset_scalars(self) -> list[tuple[str, object]]: + return _RESET_SCALARS + + @property + def reset_arrays(self) -> list[tuple[str, object]]: + return _RESET_ARRAYS + + def write_green(self, StdI: StdIntList) -> None: + """Write only greenone.def (UHF does not use greentwo).""" + from ...writer.common_writer import print_1_green + print_1_green(StdI) + + +# ----------------------------------------------------------------------- +# Keyword table +# ----------------------------------------------------------------------- + +_BOXSUB_KEYWORDS: dict[str, tuple] = _grid3x3_keywords( + "{a}{c}sub", "boxsub", store_with_check_dup_i, int +) + +_UHF_KEYWORDS: dict[str, tuple] = { + "iteration_max": (store_with_check_dup_i, "Iteration_max"), + "rndseed": (store_with_check_dup_i, "RndSeed"), + "nmptrans": (store_with_check_dup_i, "NMPTrans"), + **_BOXSUB_KEYWORDS, + "hsub": (store_with_check_dup_i, "Hsub"), + "lsub": (store_with_check_dup_i, "Lsub"), + "wsub": (store_with_check_dup_i, "Wsub"), + "eps": (store_with_check_dup_i, "eps"), + "epsslater": (store_with_check_dup_i, "eps_slater"), + "mix": (store_with_check_dup_d, "mix"), +} + +# ----------------------------------------------------------------------- +# Reset tables +# ----------------------------------------------------------------------- + +_RESET_SCALARS: list[tuple[str, object]] = [ + ("NMPTrans", NaN_i), + ("RndSeed", NaN_i), + ("mix", NaN_d), + ("eps", NaN_i), + ("eps_slater", NaN_i), + ("Iteration_max", NaN_i), + ("Hsub", NaN_i), + ("Lsub", NaN_i), + ("Wsub", NaN_i), +] + +_RESET_ARRAYS: list[tuple[str, object]] = [ + ("boxsub", NaN_i), +] + + +# Auto-register on import +register(UHFPlugin()) diff --git a/python/stdface/writer/__init__.py b/python/stdface/writer/__init__.py new file mode 100644 index 0000000..fc257ba --- /dev/null +++ b/python/stdface/writer/__init__.py @@ -0,0 +1,8 @@ +"""Output file generation subpackage. + +This package contains shared modules for writing Expert-mode definition files. +Solver-specific writers have moved to ``stdface.solvers``. +""" +from __future__ import annotations + +from .interaction_writer import print_interactions # noqa: F401 diff --git a/python/writer/common_writer.py b/python/stdface/writer/common_writer.py similarity index 87% rename from python/writer/common_writer.py rename to python/stdface/writer/common_writer.py index 2fd8e2f..ce0fec8 100644 --- a/python/writer/common_writer.py +++ b/python/stdface/writer/common_writer.py @@ -44,15 +44,16 @@ from __future__ import annotations from collections.abc import Callable +from itertools import product from typing import NamedTuple import numpy as np -from stdface_vals import ( +from ..core.stdface_vals import ( StdIntList, ModelType, SolverType, MethodType, NaN_i, UNSET_STRING, AMPLITUDE_EPS, ) -from param_check import exit_program, print_val_i, print_val_d, required_val_i, not_used_i +from ..core.param_check import exit_program, print_val_i, print_val_d, required_val_i, not_used_i # --------------------------------------------------------------------------- @@ -145,14 +146,15 @@ def print_loc_spin(StdI: StdIntList) -> None: """ nlocspin = int(np.count_nonzero(StdI.locspinflag[:StdI.nsite])) + lines = ["================================ \n", + f"NlocalSpin {nlocspin:5d} \n", + "================================ \n", + "========i_1LocSpn_0IteElc ====== \n", + "================================ \n"] + for isite in range(StdI.nsite): + lines.append(f"{isite:5d} {StdI.locspinflag[isite]:5d}\n") with open("locspn.def", "w") as fp: - fp.write("================================ \n") - fp.write(f"NlocalSpin {nlocspin:5d} \n") - fp.write("================================ \n") - fp.write("========i_1LocSpn_0IteElc ====== \n") - fp.write("================================ \n") - for isite in range(StdI.nsite): - fp.write(f"{isite:5d} {StdI.locspinflag[isite]:5d}\n") + fp.write("".join(lines)) print(" locspn.def is written.") @@ -178,20 +180,21 @@ def print_trans(StdI: StdIntList) -> None: ntrans0 = _merge_duplicate_terms(StdI.transindx, StdI.trans, StdI.ntrans) # --- write file --- + lines = ["======================== \n", + f"NTransfer {ntrans0:7d} \n", + "======================== \n", + "========i_j_s_tijs====== \n", + "======================== \n"] + for ktrans in range(StdI.ntrans): + val = StdI.trans[ktrans] + if abs(val) > AMPLITUDE_EPS: + i0, s0, i1, s1 = StdI.transindx[ktrans] + lines.append( + f"{i0:5d} {s0:5d} {i1:5d} {s1:5d} " + f"{val.real:25.15f} {val.imag:25.15f}\n" + ) with open("trans.def", "w") as fp: - fp.write("======================== \n") - fp.write(f"NTransfer {ntrans0:7d} \n") - fp.write("======================== \n") - fp.write("========i_j_s_tijs====== \n") - fp.write("======================== \n") - for ktrans in range(StdI.ntrans): - val = StdI.trans[ktrans] - if abs(val) > AMPLITUDE_EPS: - i0, s0, i1, s1 = StdI.transindx[ktrans] - fp.write( - f"{i0:5d} {s0:5d} {i1:5d} {s1:5d} " - f"{val.real:25.15f} {val.imag:25.15f}\n" - ) + fp.write("".join(lines)) print(" trans.def is written.") @@ -513,6 +516,11 @@ def __init__( self.is_kondo = is_kondo self.is_mvmc = is_mvmc + # Pre-compute lookup tables to avoid per-call overhead in tight loops + self._spin_max = [1 if locspinflag[s] == 0 else locspinflag[s] + for s in range(nsite)] + self._is_local_spin = [locspinflag[s] != 0 for s in range(nsite)] + # ------------------------------------------------------------------ # Low-level helpers # ------------------------------------------------------------------ @@ -533,8 +541,7 @@ def spin_max(self, site: int) -> int: int Maximum spin index (inclusive upper bound). """ - flag = self.locspinflag[site] - return 1 if flag == 0 else flag + return self._spin_max[site] def skip_local_spin_pair(self, site_a: int, site_b: int) -> bool: """Return True if a pair of distinct local-spin sites should be skipped. @@ -552,8 +559,8 @@ def skip_local_spin_pair(self, site_a: int, site_b: int) -> bool: ``True`` if both sites are local-spin and distinct. """ return (site_a != site_b - and self.locspinflag[site_a] != 0 - and self.locspinflag[site_b] != 0) + and self._is_local_spin[site_a] + and self._is_local_spin[site_b]) def kondo_site(self, isite: int) -> int: """Map a Kondo unit-cell index to the physical site index. @@ -589,17 +596,16 @@ def green1_corr(self) -> list[tuple[int, int, int, int]]: """ indices: list[tuple[int, int, int, int]] = [] xkondo = 2 if self.is_kondo else 1 + is_local = self._is_local_spin for isite in range(self.NsiteUC * xkondo): isite2 = self.kondo_site(isite) - - for ispin in range(self.spin_max(isite2) + 1): + for ispin in range(self._spin_max[isite2] + 1): for jsite in range(self.nsite): - for jspin in range(self.spin_max(jsite) + 1): - if self.skip_local_spin_pair(isite2, jsite): - continue - if ispin == jspin: - indices.append((isite2, ispin, jsite, jspin)) + if is_local[isite2] and is_local[jsite] and isite2 != jsite: + continue + if ispin <= self._spin_max[jsite]: + indices.append((isite2, ispin, jsite, ispin)) return indices @@ -614,16 +620,15 @@ def green1_raw(self) -> list[tuple[int, int, int, int]]: list of tuple[int, int, int, int] Index tuples ``(isite, ispin, jsite, jspin)``. """ - indices: list[tuple[int, int, int, int]] = [] - - for isite in range(self.nsite): - for ispin in range(self.spin_max(isite) + 1): - for jsite in range(self.nsite): - for jspin in range(self.spin_max(jsite) + 1): - if self.skip_local_spin_pair(isite, jsite): - continue - indices.append((isite, ispin, jsite, jspin)) - + site_spins = [(s, sp) for s in range(self.nsite) + for sp in range(self._spin_max[s] + 1)] + is_local = self._is_local_spin + indices: list[tuple[int, int, int, int]] = [ + (isite, ispin, jsite, jspin) + for isite, ispin in site_spins + for jsite, jspin in site_spins + if not (is_local[isite] and is_local[jsite] and isite != jsite) + ] return indices # ------------------------------------------------------------------ @@ -645,35 +650,39 @@ def green2_corr(self) -> list[tuple[int, int, int, int, int, int, int, int]]: """ indices: list[tuple[int, int, int, int, int, int, int, int]] = [] xkondo = 2 if self.is_kondo else 1 + is_mvmc = self.is_mvmc + spin_max = self._spin_max for site1 in range(self.NsiteUC * xkondo): site1k = self.kondo_site(site1) - S1Max = self.spin_max(site1k) + S1Max = spin_max[site1k] - for spin1 in range(S1Max + 1): - for spin2 in range(S1Max + 1): - for site3 in range(self.nsite): - S3Max = self.spin_max(site3) + for site3 in range(self.nsite): + S3Max = spin_max[site3] + for spin1 in range(S1Max + 1): + for spin2 in range(S1Max + 1): + # spin4 = spin1 - spin2 + spin3 for spin3 in range(S3Max + 1): - for spin4 in range(S3Max + 1): - if spin1 - spin2 + spin3 - spin4 == 0: - if self.is_mvmc and ( - spin1 != spin2 or spin3 != spin4 - ): - indices.append(( - site1k, spin1, - site3, spin4, - site3, spin3, - site1k, spin2, - )) - else: - indices.append(( - site1k, spin1, - site1k, spin2, - site3, spin3, - site3, spin4, - )) + spin4 = spin1 - spin2 + spin3 + if spin4 < 0 or spin4 > S3Max: + continue + if is_mvmc and ( + spin1 != spin2 or spin3 != spin4 + ): + indices.append(( + site1k, spin1, + site3, spin4, + site3, spin3, + site1k, spin2, + )) + else: + indices.append(( + site1k, spin1, + site1k, spin2, + site3, spin3, + site3, spin4, + )) return indices @@ -688,28 +697,19 @@ def green2_raw(self) -> list[tuple[int, int, int, int, int, int, int, int]]: list of tuple[int, int, int, int, int, int, int, int] Index tuples of 8 elements each. """ - indices: list[tuple[int, int, int, int, int, int, int, int]] = [] + site_spins = [(s, sp) for s in range(self.nsite) + for sp in range(self._spin_max[s] + 1)] + is_local = self._is_local_spin - for site1 in range(self.nsite): - for spin1 in range(self.spin_max(site1) + 1): - for site2 in range(self.nsite): - if self.skip_local_spin_pair(site1, site2): - continue - - for spin2 in range(self.spin_max(site2) + 1): - for site3 in range(self.nsite): - for spin3 in range(self.spin_max(site3) + 1): - for site4 in range(self.nsite): - if self.skip_local_spin_pair(site3, site4): - continue - - for spin4 in range(self.spin_max(site4) + 1): - indices.append(( - site1, spin1, - site2, spin2, - site3, spin3, - site4, spin4, - )) + # Precompute valid site pairs (no local-spin pair skip) + indices: list[tuple[int, int, int, int, int, int, int, int]] = [] + for (s1, sp1), (s2, sp2) in product(site_spins, repeat=2): + if is_local[s1] and is_local[s2] and s1 != s2: + continue + for (s3, sp3), (s4, sp4) in product(site_spins, repeat=2): + if is_local[s3] and is_local[s4] and s3 != s4: + continue + indices.append((s1, sp1, s2, sp2, s3, sp3, s4, sp4)) return indices @@ -747,14 +747,16 @@ def print_1_green(StdI: StdIntList) -> None: ngreen = len(greenindx) with open("greenone.def", "w") as fp: - fp.write("===============================\n") - fp.write(f"NCisAjs {ngreen:10d}\n") - fp.write("===============================\n") - fp.write("======== Green functions ======\n") - fp.write("===============================\n") - for row in greenindx: - i0, s0, i1, s1 = row - fp.write(f"{i0:5d} {s0:5d} {i1:5d} {s1:5d}\n") + lines = [ + "===============================\n", + f"NCisAjs {ngreen:10d}\n", + "===============================\n", + "======== Green functions ======\n", + "===============================\n", + ] + for i0, s0, i1, s1 in greenindx: + lines.append(f"{i0:5d} {s0:5d} {i1:5d} {s1:5d}\n") + fp.write("".join(lines)) print(" greenone.def is written.") @@ -793,17 +795,19 @@ def print_2_green(StdI: StdIntList) -> None: ngreen = len(greenindx) with open("greentwo.def", "w") as fp: - fp.write("=============================================\n") - fp.write(f"NCisAjsCktAltDC {ngreen:10d}\n") - fp.write("=============================================\n") - fp.write("======== Green functions for Sq AND Nq ======\n") - fp.write("=============================================\n") - for row in greenindx: - i0, s0, i1, s1, i2, s2, i3, s3 = row - fp.write( + lines = [ + "=============================================\n", + f"NCisAjsCktAltDC {ngreen:10d}\n", + "=============================================\n", + "======== Green functions for Sq AND Nq ======\n", + "=============================================\n", + ] + for i0, s0, i1, s1, i2, s2, i3, s3 in greenindx: + lines.append( f"{i0:5d} {s0:5d} {i1:5d} {s1:5d} " f"{i2:5d} {s2:5d} {i3:5d} {s3:5d}\n" ) + fp.write("".join(lines)) print(" greentwo.def is written.") diff --git a/python/writer/interaction_writer.py b/python/stdface/writer/interaction_writer.py similarity index 92% rename from python/writer/interaction_writer.py rename to python/stdface/writer/interaction_writer.py index 79d668d..be32f74 100644 --- a/python/writer/interaction_writer.py +++ b/python/stdface/writer/interaction_writer.py @@ -26,7 +26,7 @@ from typing import NamedTuple -from stdface_vals import StdIntList, AMPLITUDE_EPS +from ..core.stdface_vals import StdIntList, AMPLITUDE_EPS # --------------------------------------------------------------------------- @@ -134,16 +134,17 @@ def _write_interaction_file( Number of site indices per term (1 or 2). """ nintr0 = _count_nonzero(nterms, coeff) + lines = ["=============================================\n", + f"{count_label} {nintr0:10d}\n", + "=============================================\n", + f"{banner}\n", + "=============================================\n"] + for k in range(nterms): + if abs(coeff[k]) > AMPLITUDE_EPS: + idx_str = " ".join(f"{indx[k][i]:5d}" for i in range(n_indices)) + lines.append(f"{idx_str} {coeff[k]:25.15f}\n") with open(filename, "w") as fp: - fp.write("=============================================\n") - fp.write(f"{count_label} {nintr0:10d}\n") - fp.write("=============================================\n") - fp.write(f"{banner}\n") - fp.write("=============================================\n") - for k in range(nterms): - if abs(coeff[k]) > AMPLITUDE_EPS: - idx_str = " ".join(f"{indx[k][i]:5d}" for i in range(n_indices)) - fp.write(f"{idx_str} {coeff[k]:25.15f}\n") + fp.write("".join(lines)) print(f" {filename} is written.") @@ -472,26 +473,27 @@ def _write_interall(StdI: StdIntList) -> None: StdI.Lintr = 1 if StdI.Lintr == 1: - with open("interall.def", "w") as fp: - fp.write("====================== \n") - fp.write(f"NInterAll {nintr0:7d} \n") - fp.write("====================== \n") - fp.write("========zInterAll===== \n") - fp.write("====================== \n") - - if StdI.lBoost == 0: - for kintr in range(StdI.nintr): - val = StdI.intr[kintr] - if abs(val) > AMPLITUDE_EPS: - i0, s0, i1, s1, i2, s2, i3, s3 = StdI.intrindx[kintr] - fp.write( - f"{i0:5d} {s0:5d} " - f"{i1:5d} {s1:5d} " - f"{i2:5d} {s2:5d} " - f"{i3:5d} {s3:5d} " - f"{val.real:25.15f} {val.imag:25.15f}\n" - ) + lines = ["====================== \n", + f"NInterAll {nintr0:7d} \n", + "====================== \n", + "========zInterAll===== \n", + "====================== \n"] + + if StdI.lBoost == 0: + for kintr in range(StdI.nintr): + val = StdI.intr[kintr] + if abs(val) > AMPLITUDE_EPS: + i0, s0, i1, s1, i2, s2, i3, s3 = StdI.intrindx[kintr] + lines.append( + f"{i0:5d} {s0:5d} " + f"{i1:5d} {s1:5d} " + f"{i2:5d} {s2:5d} " + f"{i3:5d} {s3:5d} " + f"{val.real:25.15f} {val.imag:25.15f}\n" + ) + with open("interall.def", "w") as fp: + fp.write("".join(lines)) print(" interall.def is written.") diff --git a/python/writer/__init__.py b/python/writer/__init__.py deleted file mode 100644 index c84891e..0000000 --- a/python/writer/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Output file generation subpackage. - -This package contains modules for writing Expert-mode definition files -(``.def``) for all supported solvers (HPhi, mVMC, UHF, H-wave). -""" -from __future__ import annotations - -from .solver_writer import get_solver_writer -from .interaction_writer import print_interactions diff --git a/python/writer/solver_writer.py b/python/writer/solver_writer.py deleted file mode 100644 index 776969b..0000000 --- a/python/writer/solver_writer.py +++ /dev/null @@ -1,252 +0,0 @@ -"""Solver-specific output writer classes. - -This module provides a class hierarchy for writing Expert-mode definition -files for each supported solver. It replaces the ``if solver == ...`` -dispatch block that was previously in ``stdface_main.stdface_main()``. - -Classes -------- -SolverWriter - Abstract base class defining the writer interface. -HPhiWriter - Writes definition files for the HPhi solver. -MVMCWriter - Writes definition files for the mVMC solver. -UHFWriter - Writes definition files for the UHF solver. -HWaveWriter - Writes definition files for the H-wave solver. - -Functions ---------- -get_solver_writer - Factory function returning the correct writer for a solver name. -""" -from __future__ import annotations - -from abc import ABC, abstractmethod - -from stdface_vals import StdIntList, SolverType, MethodType, NaN_i -from param_check import print_val_i -from .mvmc_variational import ( - generate_orb, - proj, - print_jastrow, -) -from .hphi_writer import ( - print_calc_mod, - print_excitation, - print_pump, -) -from .mvmc_writer import ( - print_orb, - print_orb_para, - print_gutzwiller, -) -from .common_writer import ( - print_loc_spin, - print_trans, - print_namelist, - print_mod_para, - print_1_green, - print_2_green, - check_output_mode, - check_mod_para, -) -from .interaction_writer import print_interactions -from .export_wannier90 import export_geometry, export_interaction - - -class SolverWriter(ABC): - """Abstract base class for solver-specific output writers. - - Each concrete subclass implements :meth:`write` to produce the set of - Expert-mode definition files required by a particular solver. - - Parameters - ---------- - name : str - Human-readable solver name (e.g. ``"HPhi"``). - """ - - def __init__(self, name: str) -> None: - self.name = name - - @abstractmethod - def write(self, StdI: StdIntList) -> None: - """Write all Expert-mode definition files for this solver. - - Parameters - ---------- - StdI : StdIntList - The fully-populated parameter structure. - """ - - -class HPhiWriter(SolverWriter): - """Writer for the HPhi exact-diagonalisation / Lanczos solver. - - Produces: locspn, trans, interactions, modpara, single/pair (excitation), - teone/tetwo (time-evolution pump), calcmod, greenone, greentwo, namelist. - """ - - def __init__(self) -> None: - super().__init__(SolverType.HPhi) - - def write(self, StdI: StdIntList) -> None: - """Write all HPhi definition files. - - Parameters - ---------- - StdI : StdIntList - The fully-populated parameter structure. - """ - print_loc_spin(StdI) - print_trans(StdI) - print_interactions(StdI) - check_mod_para(StdI) - print_mod_para(StdI) - print_excitation(StdI) - if StdI.method == MethodType.TIME_EVOLUTION: - print_pump(StdI) - print_calc_mod(StdI) - check_output_mode(StdI) - print_1_green(StdI) - print_2_green(StdI) - print_namelist(StdI) - - -class MVMCWriter(SolverWriter): - """Writer for the mVMC variational Monte Carlo solver. - - Produces: locspn, trans, interactions, modpara, orbitalidx, - orbitalidxpara/gen, gutzwilleridx, greenone, greentwo, namelist, - plus Jastrow and projection files. - """ - - def __init__(self) -> None: - super().__init__(SolverType.mVMC) - - def write(self, StdI: StdIntList) -> None: - """Write all mVMC definition files. - - Parameters - ---------- - StdI : StdIntList - The fully-populated parameter structure. - """ - print_loc_spin(StdI) - print_trans(StdI) - print_interactions(StdI) - check_mod_para(StdI) - print_mod_para(StdI) - - if StdI.lGC == 0 and (StdI.Sz2 == 0 or StdI.Sz2 == NaN_i): - StdI.ComplexType = print_val_i("ComplexType", StdI.ComplexType, 0) - else: - StdI.ComplexType = print_val_i("ComplexType", StdI.ComplexType, 1) - - generate_orb(StdI) - proj(StdI) - print_jastrow(StdI) - if StdI.lGC == 1 or (StdI.Sz2 != 0 and StdI.Sz2 != NaN_i): - print_orb_para(StdI) - print_gutzwiller(StdI) - print_orb(StdI) - check_output_mode(StdI) - print_1_green(StdI) - print_2_green(StdI) - print_namelist(StdI) - - -class UHFWriter(SolverWriter): - """Writer for the UHF (unrestricted Hartree--Fock) solver. - - Produces: locspn, trans, interactions, modpara, greenone, namelist. - """ - - def __init__(self) -> None: - super().__init__(SolverType.UHF) - - def write(self, StdI: StdIntList) -> None: - """Write all UHF definition files. - - Parameters - ---------- - StdI : StdIntList - The fully-populated parameter structure. - """ - print_loc_spin(StdI) - print_trans(StdI) - print_interactions(StdI) - check_mod_para(StdI) - print_mod_para(StdI) - check_output_mode(StdI) - print_1_green(StdI) - print_namelist(StdI) - - -class HWaveWriter(SolverWriter): - """Writer for the H-wave solver. - - In ``uhfr`` calc-mode, writes: trans, interactions, modpara, greenone. - Otherwise, exports Wannier90 geometry and interaction files. - """ - - def __init__(self) -> None: - super().__init__(SolverType.HWAVE) - - def write(self, StdI: StdIntList) -> None: - """Write all H-wave definition files. - - Parameters - ---------- - StdI : StdIntList - The fully-populated parameter structure. - """ - if StdI.calcmode == "uhfr": - print_trans(StdI) - print_interactions(StdI) - check_mod_para(StdI) - check_output_mode(StdI) - print_1_green(StdI) - else: - export_geometry(StdI) - export_interaction(StdI) - - -# Registry of solver writers -_SOLVER_WRITERS: dict[SolverType, type[SolverWriter]] = { - SolverType.HPhi: HPhiWriter, - SolverType.mVMC: MVMCWriter, - SolverType.UHF: UHFWriter, - SolverType.HWAVE: HWaveWriter, -} - - -def get_solver_writer(solver: str) -> SolverWriter: - """Return the appropriate solver writer instance. - - Parameters - ---------- - solver : str - Solver name. Must be one of ``"HPhi"``, ``"mVMC"``, ``"UHF"``, - or ``"HWAVE"``. - - Returns - ------- - SolverWriter - An instance of the concrete writer for *solver*. - - Raises - ------ - ValueError - If *solver* is not a recognised solver name. - """ - cls = _SOLVER_WRITERS.get(solver) - if cls is None: - raise ValueError( - f"Unknown solver {solver!r}. " - f"Expected one of: {', '.join(_SOLVER_WRITERS)}" - ) - return cls() diff --git a/stdface b/stdface new file mode 100755 index 0000000..cdb7f85 --- /dev/null +++ b/stdface @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +# StdFace entry point +# Usage: +# ./stdface stan.in +# ./stdface stan.in --solver mVMC +# ./stdface -v +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PYTHONPATH="${SCRIPT_DIR}/python" exec python3 -m stdface "$@" diff --git a/test/unit/conftest.py b/test/unit/conftest.py index ce034cc..ab8a2ef 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -1,7 +1,7 @@ """Pytest configuration for StdFace unit tests. -Adds the python/ directory to sys.path so that translated Python modules -can be imported directly by name (e.g., ``from stdface_vals import StdIntList``). +Adds the python/ directory to sys.path so that the ``stdface`` package +can be imported (e.g., ``from stdface.core.stdface_vals import StdIntList``). """ from __future__ import annotations diff --git a/test/unit/test_boost_output.py b/test/unit/test_boost_output.py index 06c02f2..5899d4d 100644 --- a/test/unit/test_boost_output.py +++ b/test/unit/test_boost_output.py @@ -11,8 +11,8 @@ import numpy as np import pytest -from stdface_vals import StdIntList -from lattice.boost_output import ( +from stdface.core.stdface_vals import StdIntList +from stdface.lattice.boost_output import ( write_boost_mag_field, write_boost_j_full, write_boost_j_symmetric, diff --git a/test/unit/test_chain_lattice.py b/test/unit/test_chain_lattice.py index a2a8c15..7b0fc75 100644 --- a/test/unit/test_chain_lattice.py +++ b/test/unit/test_chain_lattice.py @@ -10,8 +10,8 @@ import numpy as np import pytest -from stdface_vals import StdIntList -from lattice import chain_lattice as cl +from stdface.core.stdface_vals import StdIntList +from stdface.lattice import chain_lattice as cl # --------------------------------------------------------------------------- diff --git a/test/unit/test_common_writer.py b/test/unit/test_common_writer.py index 7b2d220..0216976 100644 --- a/test/unit/test_common_writer.py +++ b/test/unit/test_common_writer.py @@ -10,8 +10,8 @@ import pytest -from stdface_vals import StdIntList, ModelType, SolverType -from writer.common_writer import ( +from stdface.core.stdface_vals import StdIntList, ModelType, SolverType +from stdface.writer.common_writer import ( print_loc_spin, print_trans, print_namelist, @@ -41,7 +41,7 @@ GreenFunctionIndices, _merge_duplicate_terms, ) -from writer.interaction_writer import print_interactions +from stdface.writer.interaction_writer import print_interactions # Sentinel values matching what _reset_vals sets at runtime NaN_i = 2147483647 diff --git a/test/unit/test_export_wannier90.py b/test/unit/test_export_wannier90.py index f47b1ea..d0593bc 100644 --- a/test/unit/test_export_wannier90.py +++ b/test/unit/test_export_wannier90.py @@ -12,8 +12,8 @@ import numpy as np import pytest -from stdface_vals import StdIntList -from writer import export_wannier90 as ew +from stdface.core.stdface_vals import StdIntList +from stdface.solvers.hwave import export_wannier90 as ew # --------------------------------------------------------------------------- diff --git a/test/unit/test_geometry_output.py b/test/unit/test_geometry_output.py index 5e8c7e6..2d2ccb2 100644 --- a/test/unit/test_geometry_output.py +++ b/test/unit/test_geometry_output.py @@ -11,8 +11,8 @@ import numpy as np import pytest -from stdface_vals import StdIntList -from lattice.geometry_output import print_xsf, print_geometry, _cell_diff +from stdface.core.stdface_vals import StdIntList +from stdface.lattice.geometry_output import print_xsf, print_geometry, _cell_diff def _make_stdi( @@ -301,10 +301,10 @@ class TestBackwardCompatibility: def test_import_from_stdface_model_util(self): """Test that both functions are re-exported.""" - from stdface_model_util import ( + from stdface.core.stdface_model_util import ( print_xsf as pxsf, print_geometry as pg, ) - from lattice.geometry_output import print_xsf, print_geometry + from stdface.lattice.geometry_output import print_xsf, print_geometry assert pxsf is print_xsf assert pg is print_geometry diff --git a/test/unit/test_hphi_writer.py b/test/unit/test_hphi_writer.py index 3eaa677..88e5f1d 100644 --- a/test/unit/test_hphi_writer.py +++ b/test/unit/test_hphi_writer.py @@ -11,9 +11,9 @@ import pytest -from stdface_vals import StdIntList, MethodType, ModelType -from stdface_vals import UNSET_STRING -from writer.hphi_writer import ( +from stdface.core.stdface_vals import StdIntList, MethodType, ModelType +from stdface.core.stdface_vals import UNSET_STRING +from stdface.solvers.hphi.writer import ( large_value, print_calc_mod, print_excitation, diff --git a/test/unit/test_input_params.py b/test/unit/test_input_params.py index 3d60286..bc13c6c 100644 --- a/test/unit/test_input_params.py +++ b/test/unit/test_input_params.py @@ -10,7 +10,7 @@ import numpy as np import pytest -from lattice.input_params import ( +from stdface.lattice.input_params import ( input_spin_nn, input_spin, input_coulomb_v, @@ -214,13 +214,13 @@ class TestBackwardCompatibility: def test_import_from_stdface_model_util(self): """Test that all 4 functions are re-exported.""" - from stdface_model_util import ( + from stdface.core.stdface_model_util import ( input_spin_nn as isn, input_spin as isp, input_coulomb_v as icv, input_hopp as ih, ) - from lattice.input_params import ( + from stdface.lattice.input_params import ( input_spin_nn, input_spin, input_coulomb_v, diff --git a/test/unit/test_interaction_builder.py b/test/unit/test_interaction_builder.py index f8330a5..94583ff 100644 --- a/test/unit/test_interaction_builder.py +++ b/test/unit/test_interaction_builder.py @@ -10,8 +10,8 @@ import numpy as np import pytest -from stdface_vals import StdIntList -from lattice.interaction_builder import ( +from stdface.core.stdface_vals import StdIntList +from stdface.lattice.interaction_builder import ( trans, hopping, hubbard_local, @@ -379,7 +379,7 @@ class TestBackwardCompatibility: def test_import_from_stdface_model_util(self): """Test that all 8 functions are re-exported.""" - from stdface_model_util import ( + from stdface.core.stdface_model_util import ( trans as t, hopping as h, hubbard_local as hl, @@ -389,7 +389,7 @@ def test_import_from_stdface_model_util(self): coulomb as c, malloc_interactions as mi, ) - from lattice.interaction_builder import ( + from stdface.lattice.interaction_builder import ( trans, hopping, hubbard_local, diff --git a/test/unit/test_interaction_writer.py b/test/unit/test_interaction_writer.py index 32b6f3b..bd7200a 100644 --- a/test/unit/test_interaction_writer.py +++ b/test/unit/test_interaction_writer.py @@ -8,8 +8,8 @@ import os import tempfile -from stdface_vals import StdIntList -from writer.interaction_writer import ( +from stdface.core.stdface_vals import StdIntList +from stdface.writer.interaction_writer import ( print_interactions, _merge_1idx, _merge_2idx, diff --git a/test/unit/test_keyword_parser.py b/test/unit/test_keyword_parser.py index ba8342f..ad74a65 100644 --- a/test/unit/test_keyword_parser.py +++ b/test/unit/test_keyword_parser.py @@ -10,7 +10,7 @@ import pytest -from keyword_parser import ( +from stdface.core.keyword_parser import ( trim_space_quote, store_with_check_dup_s, store_with_check_dup_sl, @@ -32,7 +32,7 @@ _grid3x3_keywords, NaN_i, ) -from stdface_vals import StdIntList, SolverType +from stdface.core.stdface_vals import StdIntList, SolverType import numpy as np diff --git a/test/unit/test_lattice_dispatch.py b/test/unit/test_lattice_dispatch.py index 56d94b9..88079c5 100644 --- a/test/unit/test_lattice_dispatch.py +++ b/test/unit/test_lattice_dispatch.py @@ -1,126 +1,143 @@ -"""Unit tests for lattice dispatch tables. +"""Unit tests for lattice dispatch tables and plugin registry. -Tests for the ``LATTICE_DISPATCH`` and ``BOOST_DISPATCH`` dict tables -in ``stdface_main``. +Tests for the lattice plugin registry and backward-compatible +``LATTICE_DISPATCH`` and ``BOOST_DISPATCH`` proxies in ``stdface_main``. """ from __future__ import annotations import pytest -from stdface_main import LATTICE_DISPATCH, BOOST_DISPATCH +from stdface.core.stdface_main import LATTICE_DISPATCH, BOOST_DISPATCH +from stdface.lattice import get_lattice, get_all_lattices, LatticePlugin -from lattice import chain_lattice -from lattice import square_lattice -from lattice import ladder -from lattice import triangular_lattice -from lattice import honeycomb_lattice -from lattice import kagome -from lattice import orthorhombic -from lattice import fc_ortho -from lattice import pyrochlore -from lattice import wannier90 as wannier90_mod +class TestLatticeRegistry: + """Tests for the lattice plugin registry.""" -class TestLatticeDispatch: - """Tests for the LATTICE_DISPATCH table.""" + def test_get_chain(self): + plugin = get_lattice("chain") + assert plugin.name == "chain" + assert plugin.ndim == 1 - def test_chain_aliases(self): - """Test that chain aliases all resolve to chain_lattice.chain.""" - assert LATTICE_DISPATCH["chain"] is chain_lattice.chain - assert LATTICE_DISPATCH["chainlattice"] is chain_lattice.chain + def test_get_chain_alias(self): + assert get_lattice("chain") is get_lattice("chainlattice") + + def test_get_square(self): + plugin = get_lattice("tetragonal") + assert plugin.name == "tetragonal" + assert plugin.ndim == 2 def test_square_aliases(self): - """Test that square/tetragonal aliases resolve correctly.""" - assert LATTICE_DISPATCH["tetragonal"] is square_lattice.tetragonal - assert LATTICE_DISPATCH["tetragonallattice"] is square_lattice.tetragonal - assert LATTICE_DISPATCH["square"] is square_lattice.tetragonal - assert LATTICE_DISPATCH["squarelattice"] is square_lattice.tetragonal - - def test_ladder_aliases(self): - """Test that ladder aliases resolve correctly.""" - assert LATTICE_DISPATCH["ladder"] is ladder.ladder - assert LATTICE_DISPATCH["ladderlattice"] is ladder.ladder - - def test_triangular_aliases(self): - """Test that triangular aliases resolve correctly.""" - assert LATTICE_DISPATCH["triangular"] is triangular_lattice.triangular - assert LATTICE_DISPATCH["triangularlattice"] is triangular_lattice.triangular - - def test_honeycomb_aliases(self): - """Test that honeycomb aliases resolve correctly.""" - assert LATTICE_DISPATCH["honeycomb"] is honeycomb_lattice.honeycomb - assert LATTICE_DISPATCH["honeycomblattice"] is honeycomb_lattice.honeycomb - - def test_kagome_aliases(self): - """Test that kagome aliases resolve correctly.""" - assert LATTICE_DISPATCH["kagome"] is kagome.kagome - assert LATTICE_DISPATCH["kagomelattice"] is kagome.kagome + p = get_lattice("tetragonal") + assert get_lattice("tetragonallattice") is p + assert get_lattice("square") is p + assert get_lattice("squarelattice") is p + + def test_get_triangular(self): + plugin = get_lattice("triangular") + assert plugin.name == "triangular" + assert plugin.ndim == 2 + + def test_get_honeycomb(self): + plugin = get_lattice("honeycomb") + assert plugin.name == "honeycomb" + assert plugin.ndim == 2 + + def test_get_kagome(self): + plugin = get_lattice("kagome") + assert plugin.name == "kagome" + assert plugin.ndim == 2 + + def test_get_ladder(self): + plugin = get_lattice("ladder") + assert plugin.name == "ladder" + assert plugin.ndim == 1 + + def test_get_orthorhombic(self): + plugin = get_lattice("orthorhombic") + assert plugin.name == "orthorhombic" + assert plugin.ndim == 3 def test_orthorhombic_aliases(self): - """Test that orthorhombic/cubic aliases resolve correctly.""" - assert LATTICE_DISPATCH["orthorhombic"] is orthorhombic.orthorhombic - assert LATTICE_DISPATCH["simpleorthorhombic"] is orthorhombic.orthorhombic - assert LATTICE_DISPATCH["cubic"] is orthorhombic.orthorhombic - assert LATTICE_DISPATCH["simplecubic"] is orthorhombic.orthorhombic + p = get_lattice("orthorhombic") + assert get_lattice("simpleorthorhombic") is p + assert get_lattice("cubic") is p + assert get_lattice("simplecubic") is p + + def test_get_fco(self): + plugin = get_lattice("fco") + assert plugin.name == "fco" + assert plugin.ndim == 3 def test_fco_aliases(self): - """Test that face-centered orthorhombic aliases resolve correctly.""" - assert LATTICE_DISPATCH["face-centeredorthorhombic"] is fc_ortho.fc_ortho - assert LATTICE_DISPATCH["fcorthorhombic"] is fc_ortho.fc_ortho - assert LATTICE_DISPATCH["fco"] is fc_ortho.fc_ortho - assert LATTICE_DISPATCH["face-centeredcubic"] is fc_ortho.fc_ortho - assert LATTICE_DISPATCH["fccubic"] is fc_ortho.fc_ortho - assert LATTICE_DISPATCH["fcc"] is fc_ortho.fc_ortho - - def test_pyrochlore(self): - """Test that pyrochlore resolves correctly.""" - assert LATTICE_DISPATCH["pyrochlore"] is pyrochlore.pyrochlore - - def test_wannier90(self): - """Test that wannier90 resolves correctly.""" - assert LATTICE_DISPATCH["wannier90"] is wannier90_mod.wannier90 - - def test_unknown_returns_none(self): - """Test that unknown lattice returns None via .get().""" + p = get_lattice("fco") + assert get_lattice("face-centeredorthorhombic") is p + assert get_lattice("fcorthorhombic") is p + assert get_lattice("face-centeredcubic") is p + assert get_lattice("fccubic") is p + assert get_lattice("fcc") is p + + def test_get_pyrochlore(self): + plugin = get_lattice("pyrochlore") + assert plugin.name == "pyrochlore" + assert plugin.ndim == 3 + + def test_get_wannier90(self): + plugin = get_lattice("wannier90") + assert plugin.name == "wannier90" + assert plugin.ndim == 3 + + def test_unknown_raises(self): + with pytest.raises(KeyError): + get_lattice("nosuchlattice") + + def test_all_plugins_are_lattice_plugins(self): + for plugin in get_all_lattices(): + assert isinstance(plugin, LatticePlugin) + + def test_all_have_setup(self): + for plugin in get_all_lattices(): + assert callable(plugin.setup) + + def test_all_have_boost(self): + for plugin in get_all_lattices(): + assert callable(plugin.boost) + + def test_get_all_lattices_count(self): + """There should be 10 unique lattice plugins.""" + assert len(get_all_lattices()) == 10 + + +class TestLatticeDispatchProxy: + """Tests for backward-compatible LATTICE_DISPATCH proxy.""" + + def test_chain_callable(self): + assert callable(LATTICE_DISPATCH["chain"]) + + def test_contains(self): + assert "chain" in LATTICE_DISPATCH + assert "nosuchlattice" not in LATTICE_DISPATCH + + def test_get_returns_none_for_unknown(self): assert LATTICE_DISPATCH.get("nosuchlattice") is None def test_all_entries_callable(self): - """Test that every entry in the dispatch table is callable.""" for name, func in LATTICE_DISPATCH.items(): assert callable(func), f"LATTICE_DISPATCH[{name!r}] is not callable" -class TestBoostDispatch: - """Tests for the BOOST_DISPATCH table.""" - - def test_chain_boost(self): - """Test chain boost alias.""" - assert BOOST_DISPATCH["chain"] is chain_lattice.chain_boost - assert BOOST_DISPATCH["chainlattice"] is chain_lattice.chain_boost - - def test_honeycomb_boost(self): - """Test honeycomb boost alias.""" - assert BOOST_DISPATCH["honeycomb"] is honeycomb_lattice.honeycomb_boost - assert BOOST_DISPATCH["honeycomblattice"] is honeycomb_lattice.honeycomb_boost - - def test_kagome_boost(self): - """Test kagome boost alias.""" - assert BOOST_DISPATCH["kagome"] is kagome.kagome_boost - assert BOOST_DISPATCH["kagomelattice"] is kagome.kagome_boost +class TestBoostDispatchProxy: + """Tests for backward-compatible BOOST_DISPATCH proxy.""" - def test_ladder_boost(self): - """Test ladder boost alias.""" - assert BOOST_DISPATCH["ladder"] is ladder.ladder_boost - assert BOOST_DISPATCH["ladderlattice"] is ladder.ladder_boost + def test_chain_boost_callable(self): + assert callable(BOOST_DISPATCH["chain"]) def test_unsupported_lattice_not_in_boost(self): - """Test that lattices without boost are not in BOOST_DISPATCH.""" assert "square" not in BOOST_DISPATCH assert "triangular" not in BOOST_DISPATCH assert "pyrochlore" not in BOOST_DISPATCH assert "wannier90" not in BOOST_DISPATCH def test_all_entries_callable(self): - """Test that every entry in the boost table is callable.""" for name, func in BOOST_DISPATCH.items(): assert callable(func), f"BOOST_DISPATCH[{name!r}] is not callable" diff --git a/test/unit/test_lattices.py b/test/unit/test_lattices.py index cde52c5..87ce269 100644 --- a/test/unit/test_lattices.py +++ b/test/unit/test_lattices.py @@ -16,15 +16,15 @@ # Each entry is (module_name, list_of_expected_functions). LATTICE_MODULES = [ - ("lattice.square_lattice", ["tetragonal"]), - ("lattice.ladder", ["ladder", "ladder_boost"]), - ("lattice.triangular_lattice", ["triangular"]), - ("lattice.honeycomb_lattice", ["honeycomb", "honeycomb_boost"]), - ("lattice.kagome", ["kagome", "kagome_boost"]), - ("lattice.orthorhombic", ["orthorhombic"]), - ("lattice.fc_ortho", ["fc_ortho"]), - ("lattice.pyrochlore", ["pyrochlore"]), - ("lattice.chain_lattice", ["chain", "chain_boost"]), + ("stdface.lattice.square_lattice", ["tetragonal"]), + ("stdface.lattice.ladder", ["ladder", "ladder_boost"]), + ("stdface.lattice.triangular_lattice", ["triangular"]), + ("stdface.lattice.honeycomb_lattice", ["honeycomb", "honeycomb_boost"]), + ("stdface.lattice.kagome", ["kagome", "kagome_boost"]), + ("stdface.lattice.orthorhombic", ["orthorhombic"]), + ("stdface.lattice.fc_ortho", ["fc_ortho"]), + ("stdface.lattice.pyrochlore", ["pyrochlore"]), + ("stdface.lattice.chain_lattice", ["chain", "chain_boost"]), ] # Flat list of (module_name, function_name) for parametrized tests. diff --git a/test/unit/test_model_method_dispatch.py b/test/unit/test_model_method_dispatch.py index d96e990..99eabb0 100644 --- a/test/unit/test_model_method_dispatch.py +++ b/test/unit/test_model_method_dispatch.py @@ -7,7 +7,7 @@ import pytest -from stdface_main import MODEL_ALIASES, MODEL_ALIASES_HPHI_BOOST, METHOD_ALIASES +from stdface.core.stdface_main import MODEL_ALIASES, MODEL_ALIASES_HPHI_BOOST, METHOD_ALIASES class TestModelAliases: diff --git a/test/unit/test_mvmc_variational.py b/test/unit/test_mvmc_variational.py index 3e18f8e..777a69e 100644 --- a/test/unit/test_mvmc_variational.py +++ b/test/unit/test_mvmc_variational.py @@ -14,8 +14,8 @@ import numpy as np import pytest -from stdface_vals import StdIntList -from writer.mvmc_variational import ( +from stdface.core.stdface_vals import StdIntList +from stdface.solvers.mvmc.variational import ( _anti_period_dot, _parity_sign, _check_commensurate, diff --git a/test/unit/test_mvmc_writer.py b/test/unit/test_mvmc_writer.py index 93216d4..67dd837 100644 --- a/test/unit/test_mvmc_writer.py +++ b/test/unit/test_mvmc_writer.py @@ -10,10 +10,10 @@ import pytest -from stdface_vals import StdIntList +from stdface.core.stdface_vals import StdIntList import numpy as np -from writer.mvmc_writer import ( +from stdface.solvers.mvmc.writer import ( print_orb, print_orb_para, print_gutzwiller, diff --git a/test/unit/test_param_check.py b/test/unit/test_param_check.py index f4ccca2..ddc413c 100644 --- a/test/unit/test_param_check.py +++ b/test/unit/test_param_check.py @@ -10,7 +10,7 @@ import numpy as np import pytest -from param_check import ( +from stdface.core.param_check import ( exit_program, print_val_d, print_val_dd, @@ -21,7 +21,7 @@ not_used_i, required_val_i, ) -from stdface_vals import NaN_i, NaN_d +from stdface.core.stdface_vals import NaN_i, NaN_d class TestExitProgram: @@ -198,11 +198,11 @@ def test_prints_value(self, capsys): class TestBackwardCompatibility: - """Test that functions are still importable from stdface_model_util.""" + """Test that functions are importable from stdface.core.stdface_model_util.""" def test_import_from_stdface_model_util(self): """Test that all extracted functions are re-exported.""" - from stdface_model_util import ( + from stdface.core.stdface_model_util import ( exit_program as ep, print_val_d as pvd, print_val_dd as pvdd, @@ -213,8 +213,7 @@ def test_import_from_stdface_model_util(self): not_used_i as nui, required_val_i as rvi, ) - # Verify they are the same objects as param_check - from param_check import ( + from stdface.core.param_check import ( exit_program, print_val_d, print_val_dd, diff --git a/test/unit/test_reset_vals_dispatch.py b/test/unit/test_reset_vals_dispatch.py index 0a4cc8a..d5fd8ea 100644 --- a/test/unit/test_reset_vals_dispatch.py +++ b/test/unit/test_reset_vals_dispatch.py @@ -13,7 +13,7 @@ import numpy as np import pytest -from stdface_main import ( +from stdface.core.stdface_main import ( _COMMON_RESET_SCALARS, _COMMON_RESET_ARRAYS, _SOLVER_RESET_SCALARS, @@ -26,7 +26,7 @@ NaN_d, NaN_c, ) -from stdface_vals import StdIntList, SolverType +from stdface.core.stdface_vals import StdIntList, SolverType def _make_stdi(solver: str) -> StdIntList: diff --git a/test/unit/test_site_util.py b/test/unit/test_site_util.py index d5a4415..294d1d4 100644 --- a/test/unit/test_site_util.py +++ b/test/unit/test_site_util.py @@ -12,8 +12,8 @@ import numpy as np import pytest -from stdface_vals import StdIntList -from lattice.site_util import ( +from stdface.core.stdface_vals import StdIntList +from stdface.lattice.site_util import ( _cell_vector, _fold_to_cell, _fold_site, _find_cell_index, _write_gnuplot_header, _write_gnuplot_bond, @@ -24,7 +24,7 @@ init_site, find_site, set_label, set_local_spin_flags, ) -from stdface_vals import ModelType, SolverType, NaN_i +from stdface.core.stdface_vals import ModelType, SolverType, NaN_i def _make_stdi_chain(L: int = 4) -> StdIntList: @@ -764,17 +764,17 @@ def test_fp_none_no_error(self): class TestBackwardCompatibility: - """Test that functions are still importable from stdface_model_util.""" + """Test that functions are importable from stdface.core.stdface_model_util.""" def test_import_from_stdface_model_util(self): """Test that all 4 functions are re-exported.""" - from stdface_model_util import ( + from stdface.core.stdface_model_util import ( _fold_site as fs, init_site as iis, find_site as fis, set_label as sl, ) - from lattice.site_util import _fold_site, init_site, find_site, set_label + from stdface.lattice.site_util import _fold_site, init_site, find_site, set_label assert fs is _fold_site assert iis is init_site assert fis is find_site diff --git a/test/unit/test_solver_writer.py b/test/unit/test_solver_writer.py index ca959a0..9dbe7c4 100644 --- a/test/unit/test_solver_writer.py +++ b/test/unit/test_solver_writer.py @@ -1,7 +1,6 @@ -"""Unit tests for solver_writer module. +"""Unit tests for solver plugin system. -Tests for the SolverWriter class hierarchy and the ``get_solver_writer`` -factory function. +Tests for the SolverPlugin classes and the plugin registry. """ from __future__ import annotations @@ -11,15 +10,12 @@ import pytest -from writer.solver_writer import ( - SolverWriter, - HPhiWriter, - MVMCWriter, - UHFWriter, - HWaveWriter, - get_solver_writer, -) -from stdface_vals import StdIntList +from stdface.plugin import SolverPlugin, get_plugin +from stdface.solvers.hphi import HPhiPlugin +from stdface.solvers.mvmc import MVMCPlugin +from stdface.solvers.uhf import UHFPlugin +from stdface.solvers.hwave import HWavePlugin +from stdface.core.stdface_vals import StdIntList # Sentinel values matching the C code NaN_i = 2147483647 @@ -103,70 +99,70 @@ def _make_stdi_for_hphi(nsite: int = 4) -> StdIntList: # ===================================================================== -# Tests for get_solver_writer factory +# Tests for get_plugin registry # ===================================================================== -class TestGetSolverWriter: - """Tests for the get_solver_writer factory function.""" +class TestGetPlugin: + """Tests for the get_plugin registry function.""" - def test_returns_hphi_writer(self): - """Test that 'HPhi' returns an HPhiWriter.""" - writer = get_solver_writer("HPhi") - assert isinstance(writer, HPhiWriter) - assert writer.name == "HPhi" + def test_returns_hphi_plugin(self): + """Test that 'HPhi' returns an HPhiPlugin.""" + plugin = get_plugin("HPhi") + assert isinstance(plugin, HPhiPlugin) + assert plugin.name == "HPhi" - def test_returns_mvmc_writer(self): - """Test that 'mVMC' returns an MVMCWriter.""" - writer = get_solver_writer("mVMC") - assert isinstance(writer, MVMCWriter) - assert writer.name == "mVMC" + def test_returns_mvmc_plugin(self): + """Test that 'mVMC' returns an MVMCPlugin.""" + plugin = get_plugin("mVMC") + assert isinstance(plugin, MVMCPlugin) + assert plugin.name == "mVMC" - def test_returns_uhf_writer(self): - """Test that 'UHF' returns a UHFWriter.""" - writer = get_solver_writer("UHF") - assert isinstance(writer, UHFWriter) - assert writer.name == "UHF" + def test_returns_uhf_plugin(self): + """Test that 'UHF' returns a UHFPlugin.""" + plugin = get_plugin("UHF") + assert isinstance(plugin, UHFPlugin) + assert plugin.name == "UHF" - def test_returns_hwave_writer(self): - """Test that 'HWAVE' returns an HWaveWriter.""" - writer = get_solver_writer("HWAVE") - assert isinstance(writer, HWaveWriter) - assert writer.name == "HWAVE" + def test_returns_hwave_plugin(self): + """Test that 'HWAVE' returns an HWavePlugin.""" + plugin = get_plugin("HWAVE") + assert isinstance(plugin, HWavePlugin) + assert plugin.name == "HWAVE" def test_raises_on_unknown_solver(self): - """Test that an unknown solver raises ValueError.""" - with pytest.raises(ValueError, match="Unknown solver"): - get_solver_writer("unknown") + """Test that an unknown solver raises KeyError.""" + with pytest.raises(KeyError, match="No solver plugin"): + get_plugin("unknown") # ===================================================================== -# Tests for SolverWriter subclasses +# Tests for SolverPlugin subclasses # ===================================================================== -class TestSolverWriterIsAbstract: +class TestSolverPluginIsAbstract: """Tests for the abstract base class.""" def test_cannot_instantiate_directly(self): - """Test that SolverWriter cannot be instantiated directly.""" + """Test that SolverPlugin cannot be instantiated directly.""" with pytest.raises(TypeError): - SolverWriter("test") + SolverPlugin() -class TestHPhiWriter: - """Tests for the HPhiWriter class.""" +class TestHPhiPlugin: + """Tests for the HPhiPlugin class.""" def test_writes_namelist_def(self): - """Test that HPhiWriter creates namelist.def.""" + """Test that HPhiPlugin creates namelist.def.""" StdI = _make_stdi_for_hphi(nsite=4) - writer = HPhiWriter() + plugin = get_plugin("HPhi") with tempfile.TemporaryDirectory() as tmpdir: orig = os.getcwd() os.chdir(tmpdir) try: - writer.write(StdI) + plugin.write(StdI) assert os.path.exists("namelist.def") assert os.path.exists("calcmod.def") assert os.path.exists("modpara.def") @@ -175,50 +171,50 @@ def test_writes_namelist_def(self): os.chdir(orig) def test_writes_trans_def(self): - """Test that HPhiWriter creates trans.def.""" + """Test that HPhiPlugin creates trans.def.""" StdI = _make_stdi_for_hphi(nsite=4) - writer = HPhiWriter() + plugin = get_plugin("HPhi") with tempfile.TemporaryDirectory() as tmpdir: orig = os.getcwd() os.chdir(tmpdir) try: - writer.write(StdI) + plugin.write(StdI) assert os.path.exists("trans.def") finally: os.chdir(orig) def test_writes_green_files(self): - """Test that HPhiWriter creates greenone.def and greentwo.def.""" + """Test that HPhiPlugin creates greenone.def and greentwo.def.""" StdI = _make_stdi_for_hphi(nsite=4) - writer = HPhiWriter() + plugin = get_plugin("HPhi") with tempfile.TemporaryDirectory() as tmpdir: orig = os.getcwd() os.chdir(tmpdir) try: - writer.write(StdI) + plugin.write(StdI) assert os.path.exists("greenone.def") assert os.path.exists("greentwo.def") finally: os.chdir(orig) -class TestUHFWriter: - """Tests for the UHFWriter class.""" +class TestUHFPlugin: + """Tests for the UHFPlugin class.""" def test_writes_expected_files(self): - """Test that UHFWriter creates the expected set of files.""" + """Test that UHFPlugin creates the expected set of files.""" StdI = _make_stdi_for_hphi(nsite=4) StdI.solver = "UHF" StdI.outputmode = "****" - writer = UHFWriter() + plugin = get_plugin("UHF") with tempfile.TemporaryDirectory() as tmpdir: orig = os.getcwd() os.chdir(tmpdir) try: - writer.write(StdI) + plugin.write(StdI) assert os.path.exists("locspn.def") assert os.path.exists("trans.def") assert os.path.exists("modpara.def") @@ -228,8 +224,8 @@ def test_writes_expected_files(self): os.chdir(orig) -class TestHWaveWriter: - """Tests for the HWaveWriter class.""" +class TestHWavePlugin: + """Tests for the HWavePlugin class.""" def test_uhfr_mode_writes_trans(self): """Test that HWAVE in uhfr mode writes trans.def.""" @@ -237,13 +233,13 @@ def test_uhfr_mode_writes_trans(self): StdI.solver = "HWAVE" StdI.calcmode = "uhfr" StdI.outputmode = "****" - writer = HWaveWriter() + plugin = get_plugin("HWAVE") with tempfile.TemporaryDirectory() as tmpdir: orig = os.getcwd() os.chdir(tmpdir) try: - writer.write(StdI) + plugin.write(StdI) assert os.path.exists("trans.def") assert os.path.exists("greenone.def") finally: diff --git a/test/unit/test_stdface_main_helpers.py b/test/unit/test_stdface_main_helpers.py index eeaecd7..85c8b4c 100644 --- a/test/unit/test_stdface_main_helpers.py +++ b/test/unit/test_stdface_main_helpers.py @@ -8,8 +8,8 @@ import pytest -from stdface_vals import StdIntList, ModelType, SolverType, MethodType, NaN_d -from stdface_main import _parse_input_file, _resolve_model_and_method +from stdface.core.stdface_vals import StdIntList, ModelType, SolverType, MethodType +from stdface.core.stdface_main import _parse_input_file, _resolve_model_and_method class TestParseInputFile: diff --git a/test/unit/test_stdface_model_util.py b/test/unit/test_stdface_model_util.py index 80c0ade..7a7daf2 100644 --- a/test/unit/test_stdface_model_util.py +++ b/test/unit/test_stdface_model_util.py @@ -10,8 +10,8 @@ import numpy as np import pytest -from stdface_vals import StdIntList -import stdface_model_util as smu +from stdface.core.stdface_vals import StdIntList +import stdface.core.stdface_model_util as smu # --------------------------------------------------------------------------- diff --git a/test/unit/test_stdface_vals.py b/test/unit/test_stdface_vals.py index 23c4b92..d7b59db 100644 --- a/test/unit/test_stdface_vals.py +++ b/test/unit/test_stdface_vals.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from stdface_vals import ( +from stdface.core.stdface_vals import ( StdIntList, ModelType, SolverType, MethodType, NaN_i, NaN_d, NaN_c, UNSET_STRING, AMPLITUDE_EPS, ZERO_BODY_EPS, diff --git a/test/unit/test_version.py b/test/unit/test_version.py index 6b6ae59..b0a9f06 100644 --- a/test/unit/test_version.py +++ b/test/unit/test_version.py @@ -4,7 +4,7 @@ """ from __future__ import annotations -import version +import stdface.core.version as version class TestVersionConstants: diff --git a/test/unit/test_wannier90.py b/test/unit/test_wannier90.py index 6601d4b..2215408 100644 --- a/test/unit/test_wannier90.py +++ b/test/unit/test_wannier90.py @@ -11,8 +11,8 @@ import numpy as np import pytest -from stdface_vals import StdIntList -from lattice import wannier90 as w90 +from stdface.core.stdface_vals import StdIntList +from stdface.lattice import wannier90 as w90 # --------------------------------------------------------------------------- @@ -227,7 +227,7 @@ class TestModuleStructure: def test_import(self): """wannier90 module should import without error.""" - from lattice import wannier90 + from stdface.lattice import wannier90 assert wannier90 is not None def test_wannier90_function_exists(self): @@ -750,7 +750,7 @@ def test_none_returns_notcorrect(self): def test_unset_string_returns_notcorrect(self): """UNSET_STRING sentinel should map to NOTCORRECT.""" - from stdface_vals import UNSET_STRING + from stdface.core.stdface_vals import UNSET_STRING assert w90._parse_double_counting_mode(UNSET_STRING) == w90._DCMode.NOTCORRECT def test_hartree_returns_hartree(self): @@ -777,7 +777,7 @@ def test_return_type_is_dcmode(self): def test_dc_mode_map_keys(self): """_DC_MODE_MAP should contain exactly the expected keys.""" - from stdface_vals import UNSET_STRING + from stdface.core.stdface_vals import UNSET_STRING expected_keys = {"none", UNSET_STRING, "hartree", "hartree_u", "full"} assert set(w90._DC_MODE_MAP.keys()) == expected_keys