Skip to content

Commit ae935ae

Browse files
committed
Added the LDR solver with PyTorch
1 parent c80872b commit ae935ae

File tree

3 files changed

+374
-0
lines changed

3 files changed

+374
-0
lines changed

src/libra_py/dynamics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
__all__ = ["bohmian",
1111
"exact",
1212
"exact_torch",
13+
"ldr_torch",
1314
"heom",
1415
"qtag",
1516
"tsh",
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# ***********************************************************
2+
# * Copyright (C) 2025 Alexey V. Akimov
3+
# * This file is distributed under the terms of the
4+
# * GNU General Public License as published by the
5+
# * Free Software Foundation; either version 3 of the
6+
# * License, or (at your option) any later version.
7+
# * http://www.gnu.org/copyleft/gpl.txt
8+
# ***********************************************************/
9+
10+
__all__ = ["compute",
11+
]
Lines changed: 362 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,362 @@
1+
# *********************************************************************************
2+
# * Copyright (C) 2025 Alexey V. Akimov
3+
# *
4+
# * This file is distributed under the terms of the GNU General Public License
5+
# * as published by the Free Software Foundation, either version 3 of
6+
# * the License, or (at your option) any later version.
7+
# * See the file LICENSE in the root directory of this distribution
8+
# * or <http://www.gnu.org/licenses/>.
9+
# ***********************************************************************************
10+
"""
11+
.. module:: compute
12+
:platform: Unix, Windows
13+
:synopsis: This module implements functions for doing local diabatic representation (LDR) dynamics with PyTorch
14+
List of functions:
15+
* sech # temporary here
16+
* Martens_model # temporary here
17+
* gaussian_wavepacket
18+
List of classes:
19+
* ldr_solver
20+
21+
.. moduleauthor:: Alexey V. Akimov, Daeho Han
22+
23+
"""
24+
25+
__author__ = "Alexey V. Akimov"
26+
__copyright__ = "Copyright 2025 Alexey V. Akimov"
27+
__credits__ = ["Alexey V. Akimov"]
28+
__license__ = "GNU-3"
29+
__version__ = "1.0"
30+
__maintainer__ = "Alexey V. Akimov"
31+
__email__ = "alexvakimov@gmail.com"
32+
__url__ = "https://github.com/Quantum-Dynamics-Hub/libra-code"
33+
34+
35+
36+
import torch
37+
38+
39+
class ldr_solver:
40+
def __init__(self, params):
41+
self.prefix = params.get("prefix", "ldr-solution")
42+
self.device = params.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
43+
self.hbar = 1.0
44+
self.Hamiltonian_scheme = "symmetrized"
45+
self.q0 = torch.tensor(params.get("q0", [0.0]), dtype=torch.float64, device=self.device)
46+
self.p0 = torch.tensor(params.get("p0", [0.0]), dtype=torch.float64, device=self.device)
47+
self.k = torch.tensor(params.get("k", [0.001]), dtype=torch.float64, device=self.device)
48+
self.mass = torch.tensor(params.get("mass", [2000.0]), dtype=torch.float64, device=self.device)
49+
self.alpha = torch.tensor(params.get("alpha", [18.0]), dtype=torch.float64, device=self.device)
50+
self.qgrid = torch.tensor(params.get("qgrid", [[-10 + i * 0.1] for i in range(int((10 - (-10)) / 0.1) + 1)] ), dtype=torch.float64, device=self.device) #(N, D)
51+
self.ngrids = len(self.qgrid) # N
52+
self.nstates = params.get("nstates", 2)
53+
self.istate = params.get("istate", 0)
54+
55+
self.save_every_n_steps = params.get("save_every_n_steps", 1)
56+
self.properties_to_save = params.get("properties_to_save", ["time", "population_right"])
57+
self.dt = params.get("dt", 0.01)
58+
self.nsteps = params.get("nsteps", 500)
59+
self.ndim = self.nstates * self.ngrids
60+
61+
self.E = params.get("E", torch.zeros(self.nstates, self.ngrids, device=self.device) )
62+
63+
Selec_default = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
64+
for i in range(self.nstates):
65+
start, end = i * self.ngrids, (i + 1) * self.ngrids
66+
Selec_default[start:end, start:end] = torch.eye(self.ngrids, device=self.device)
67+
self.Selec = params.get("Selec", Selec_default )
68+
69+
# Computed with LDR methods
70+
self.C0 = torch.zeros(self.ndim, dtype=torch.cdouble, device=self.device)
71+
self.Ccurr = torch.zeros(self.ndim, dtype=torch.cdouble, device=self.device)
72+
73+
self.Snucl = torch.eye(self.ngrids, dtype=torch.cdouble, device=self.device)
74+
self.Tnucl = torch.zeros(self.ngrids, self.ngrids, dtype=torch.cdouble, device=self.device)
75+
76+
self.S, self.H = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device), torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
77+
self.U = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
78+
79+
self.time = []
80+
self.kinetic_energy = []
81+
self.potential_energy = []
82+
self.total_energy = []
83+
self.population_right = []
84+
self.norm = []
85+
self.C_save = []
86+
87+
def chi_overlap(self):
88+
"""
89+
Compute nuclear overlap matrix Snucl[i, j] for the mesh qmesh.
90+
"""
91+
delta = self.qgrid[:, None, :] - self.qgrid[None, :, :] # (N, N, D)
92+
exponent = -0.5 * torch.sum(self.alpha * delta**2, dim=2) # (N, N)
93+
self.Snucl = torch.exp(exponent)
94+
95+
def chi_kinetic(self):
96+
r"""
97+
Compute nuclear kinetic energy matrix Tnucl[i,j] = <g(x; qgrid[i]) | T | g(x; qgrid[j])>,
98+
with T = Σ_ν -½ m_ν^{-1} ∂²/∂x_ν².
99+
"""
100+
delta = self.qgrid[:, None, :] - self.qgrid[None, :, :] # (N, N, D)
101+
tau = self.alpha / (2.0 * self.mass) * (1.0 - self.alpha * delta**2) # (N, N, D)
102+
tau_sum = torch.sum(tau, dim=2) # (N, N)
103+
104+
self.Tnucl = self.Snucl * tau_sum # (N, N)
105+
106+
def build_compound_overlap(self):
107+
"""
108+
Build the compound nuclear-electronic overlap matrix self.S (ndim, ndim)
109+
"""
110+
N, s, ndim = self.ngrids, self.nstates, self.ndim
111+
112+
# Reshape Selec[a, b] -> (i, n, j, m) with:
113+
# a = i * N + n
114+
# b = j * N + m
115+
Selec4D = self.Selec.view(s, N, s, N) # (i, n, j, m)
116+
117+
Snucl4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
118+
119+
S4D = Selec4D * Snucl4D
120+
121+
# Reshape back to (ndim, ndim) with compound indices
122+
self.S = S4D.permute(0, 1, 2, 3).reshape(ndim, ndim)
123+
124+
def build_compound_hamiltonian(self):
125+
"""
126+
Build the compound nuclear-electronic Hamiltonian self.H (ndim, ndim) using different schemes.
127+
"""
128+
N, s, ndim = self.ngrids, self.nstates, self.ndim
129+
scheme = self.Hamiltonian_scheme
130+
Selec4D = self.Selec.view(s, N, s, N) # (s, N, s, N)
131+
T4D = self.Tnucl.unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
132+
S4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
133+
134+
if scheme == 'as_is':
135+
Ej4D = self.E[None, None, :, :] # (1, 1, s, N)
136+
bracket4D = T4D + Ej4D * S4D
137+
elif scheme == 'symmetrized':
138+
Ei4D = self.E[:, :, None, None] # (s, N, 1, 1)
139+
Ej4D = self.E[None, None, :, :] # (1, 1, s, N)
140+
Eavg4D = 0.5 * (Ei4D + Ej4D) # (s, N, s, N)
141+
bracket4D = T4D + Eavg4D * S4D
142+
elif scheme == 'diagonal':
143+
# Build Kronecker deltas for electronic and nuclear indices
144+
delta_ij = torch.eye(s, device=self.device).unsqueeze(1).unsqueeze(3) # (s, 1, s, 1)
145+
delta_nm = torch.eye(N, device=self.device).unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
146+
delta4D = delta_ij * delta_nm
147+
148+
Ej4D = self.E[None, None, :, :] # (1, 1, s, N)
149+
bracket4D = T4D + Ej4D * S4D * delta4D
150+
151+
else:
152+
raise ValueError(f"Unknown Hamiltonian scheme: {scheme}")
153+
154+
H4D = Selec4D * bracket4D
155+
self.H = H4D.reshape(ndim, ndim)
156+
157+
def compute_propagator(self):
158+
"""
159+
Compute the exponential propagator matrix U = exp(-i H dt) in the non-orthogonal basis
160+
using the Lowdin orthonormalization.
161+
162+
"""
163+
S = self.S
164+
H = self.H
165+
dt = self.dt
166+
167+
evals_S, evecs_S = torch.linalg.eigh(S)
168+
169+
S_half = (evecs_S @ torch.diag(evals_S.sqrt().to(dtype=torch.cdouble)) @ evecs_S.T).to(dtype=torch.cdouble)
170+
S_invhalf = (evecs_S @ torch.diag((1.0 / evals_S).sqrt().to(dtype=torch.cdouble)) @ evecs_S.T).to(dtype=torch.cdouble)
171+
172+
H_ortho = S_invhalf @ H @ S_invhalf
173+
174+
evals_H, evecs_H = torch.linalg.eigh(H_ortho)
175+
176+
exp_diag = torch.diag(torch.exp(-1j * evals_H * dt))
177+
U_ortho = evecs_H @ exp_diag @ evecs_H.conj().T
178+
179+
self.U = S_invhalf @ U_ortho @ S_half
180+
181+
182+
def initialize_C(self):
183+
"""
184+
Initialize coefficient vector self.C0 at t=0, assuming:
185+
- electronic state self.istate
186+
- nuclear wavefunction is a Gaussian centered at self.q0 and self.p0, i.e.
187+
chi0 = exp( alpha0 * (qgrid point - self.q0)**2 + i * self.p0 * (qgrid point - self.q0) )
188+
alpha0 = 0.5/s_q **2; s_q = (1/(self.k * self.mass)) **0.25
189+
Sets:
190+
self.C0 : complex-valued coefficient vector of shape (ndim)
191+
"""
192+
N, ist = self.ngrids, self.istate
193+
194+
s_q = (1.0/(self.k*self.mass)) ** 0.25
195+
alpha0 = 1/(2*s_q**2)
196+
197+
# Compute Gaussian nuclear wavefunction at each grid point
198+
for n in range(N):
199+
index = ist * N + n
200+
201+
qn = self.qgrid[n] # (D,)
202+
delta = qn - self.q0 # (D,)
203+
exponent = -torch.dot(alpha0, delta**2) + .1j * torch.dot(self.p0, delta)
204+
205+
self.C0[index] = torch.exp(exponent)
206+
207+
# Normalize
208+
overlap = torch.matmul(self.S, self.C0)
209+
norm = torch.sqrt(torch.vdot(self.C0, overlap))
210+
211+
self.C0 /= norm
212+
213+
def propagate(self):
214+
"""
215+
Propagate coefficient.
216+
"""
217+
# Initialize first step with normalized initial wavefunction
218+
self.Ccurr = self.C0.clone()
219+
220+
print(F"step = 0")
221+
self.save_results(0)
222+
223+
for step in range(1, self.nsteps):
224+
Cvec = self.Ccurr.clone()
225+
self.Ccurr = self.U @ Cvec
226+
227+
if step % self.save_every_n_steps == 0:
228+
print(F"step = {step}")
229+
self.save_results(step)
230+
231+
def save_results(self, step):
232+
if "time" in self.properties_to_save:
233+
self.time.append(step*self.dt)
234+
if "norm" in self.properties_to_save:
235+
overlap = torch.matmul(self.S, self.Ccurr)
236+
self.norm.append(torch.sqrt(torch.vdot(self.Ccurr, overlap)))
237+
if "population_right" in self.properties_to_save:
238+
self.population_right.append(self.compute_populations())
239+
if "kinetic_energy" in self.properties_to_save:
240+
self.kinetic_energy.append(self.compute_kinetic_energy())
241+
if "potential_energy" in self.properties_to_save:
242+
self.potential_energy.append(self.compute_potential_energy())
243+
if "total_energy" in self.properties_to_save:
244+
self.total_energy.append(self.compute_total_energy())
245+
if "C_save" in self.properties_to_save:
246+
self.C_save.append(self.Ccurr)
247+
248+
def compute_populations(self):
249+
"""
250+
Compute electronic state population for a single step.
251+
"""
252+
N, s = self.ngrids, self.nstates
253+
Cvec = self.Ccurr
254+
255+
# Compute SC once: shape (ndim,)
256+
SC = self.S @ Cvec
257+
258+
C_blocks = Cvec.view(s, N)
259+
SC_blocks = SC.view(s, N)
260+
261+
# Compute P[i] = sum_j <C_j|S_{ji}|C_i> = Re[ sum_N (C_j*) * SC_j ]
262+
P = torch.sum(C_blocks.conj() * SC_blocks, dim=1).real
263+
264+
return P
265+
266+
def compute_kinetic_energy(self):
267+
"""
268+
Compute nuclear kinetic energy as <C|T|C>/<C|S|C> for a single step.
269+
"""
270+
N, s, ndim = self.ngrids, self.nstates, self.ndim
271+
272+
# Rebuild compound kinetic matrix: T4D * Selec4D
273+
Selec4D = self.Selec.view(s, N, s, N)
274+
T4D = self.Tnucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
275+
T4D_compound = Selec4D * T4D
276+
T_compound = T4D_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
277+
278+
Cvec = self.Ccurr
279+
280+
numer = torch.vdot(Cvec, T_compound @ Cvec).real
281+
denom = torch.vdot(Cvec, self.S @ Cvec).real
282+
283+
return numer / denom
284+
285+
286+
def compute_potential_energy(self):
287+
"""
288+
Compute potential energy as <C|V|C>/<C|S|C> for a single step.
289+
"""
290+
N, s, ndim = self.ngrids, self.nstates, self.ndim
291+
292+
Selec4D = self.Selec.view(s, N, s, N)
293+
S4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
294+
Ej4D = self.E[None, None, :, :] # (1,1,j,m)
295+
296+
V4D_compound = Selec4D * (Ej4D * S4D)
297+
V_compound = V4D_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
298+
299+
Cvec = self.Ccurr
300+
301+
numer = torch.vdot(Cvec, V_compound @ Cvec).real
302+
denom = torch.vdot(Cvec, self.S @ Cvec).real
303+
304+
return numer / denom
305+
306+
307+
def compute_total_energy(self):
308+
"""
309+
Compute total energy as <C|H|C>/<C|S|C> for a single step.
310+
"""
311+
Cvec = self.Ccurr
312+
313+
numer = torch.vdot(Cvec, self.H @ Cvec).real
314+
denom = torch.vdot(Cvec, self.S @ Cvec).real
315+
316+
return numer / denom
317+
318+
def save(self):
319+
torch.save( {"q0":self.q0,
320+
"p0":self.p0,
321+
"k":self.k,
322+
"mass":self.mass,
323+
"alpha":self.alpha,
324+
"qgrid":self.qgrid,
325+
"nstates":self.nstates,
326+
"istate":self.istate,
327+
"Snucl":self.Snucl,
328+
"Tnucl":self.Tnucl,
329+
"E":self.E,
330+
"Selec":self.Selec,
331+
"S":self.S,
332+
"H":self.H,
333+
"U":self.U,
334+
"C_save":self.C_save,
335+
"save_every_n_steps":self.save_every_n_steps,
336+
"Hamiltonian_scheme": self.Hamiltonian_scheme,
337+
"dt":self.dt, "nsteps":self.nsteps,
338+
"time":self.time,
339+
"kinetic_energy":self.kinetic_energy,
340+
"potential_energy":self.potential_energy,
341+
"total_energy":self.total_energy,
342+
"population_right":self.population_right,
343+
"norm":self.norm
344+
}, F"{self.prefix}.pt" )
345+
346+
def buildSH(self):
347+
self.chi_overlap()
348+
self.chi_kinetic()
349+
self.build_compound_overlap()
350+
self.build_compound_hamiltonian()
351+
352+
def solve(self):
353+
print("Building overlap and Hamiltonian matrices")
354+
self.buildSH()
355+
print("Computing the time propagator")
356+
self.compute_propagator()
357+
print("Initializing Coefficients")
358+
self.initialize_C()
359+
print("Propagating Coefficients")
360+
self.propagate()
361+
self.save()
362+

0 commit comments

Comments
 (0)