Skip to content

Commit ac533e1

Browse files
committed
Add average position method and trim others
1 parent 44da4b1 commit ac533e1

File tree

1 file changed

+35
-8
lines changed

1 file changed

+35
-8
lines changed

src/libra_py/dynamics/ldr_torch/compute.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, params):
4949
self.alpha = torch.tensor(params.get("alpha", [18.0]), dtype=torch.float64, device=self.device)
5050
self.qgrid = torch.tensor(params.get("qgrid", [[-10 + i * 0.1] for i in range(int((10 - (-10)) / 0.1) + 1)] ), dtype=torch.float64, device=self.device) #(N, D)
5151
self.ngrids = len(self.qgrid) # N
52+
self.ndof = self.qgrid.shape[1]
5253
self.nstates = params.get("nstates", 2)
5354
self.istate = params.get("istate", 0)
5455
self.elec_ampl = params.get("elec_ampl", torch.tensor([1.0+0.j]*self.ngrids, dtype=torch.cdouble))
@@ -82,6 +83,7 @@ def __init__(self, params):
8283
self.kinetic_energy = []
8384
self.potential_energy = []
8485
self.total_energy = []
86+
self.average_pos = []
8587
self.population_right = []
8688
self.denmat = []
8789
self.norm = []
@@ -97,7 +99,7 @@ def chi_overlap(self):
9799
self.Snucl = torch.exp(exponent)
98100

99101
def chi_kinetic(self):
100-
r"""
102+
"""
101103
Compute nuclear kinetic energy matrix Tnucl[i,j] = <g(x; qgrid[i]) | T | g(x; qgrid[j])>,
102104
with T = \sum_{\nu} -0.5* m_ν^{-1} \partial^{2}/\partial x_{\nu}^2.
103105
"""
@@ -106,7 +108,7 @@ def chi_kinetic(self):
106108
tau_sum = torch.sum(tau, dim=2) # (N, N)
107109

108110
self.Tnucl = self.Snucl * tau_sum # (N, N)
109-
111+
110112
def build_compound_overlap(self):
111113
"""
112114
Build the compound nuclear-electronic overlap matrix self.S (ndim, ndim)
@@ -263,6 +265,8 @@ def save_results(self, step):
263265
self.potential_energy.append(self.compute_potential_energy())
264266
if "total_energy" in self.properties_to_save:
265267
self.total_energy.append(self.compute_total_energy())
268+
if "average_pos" in self.properties_to_save:
269+
self.average_pos.append(self.compute_average_pos())
266270
if "C_save" in self.properties_to_save:
267271
self.C_save.append(self.Ccurr)
268272

@@ -308,9 +312,8 @@ def compute_kinetic_energy(self):
308312

309313
# Rebuild compound kinetic matrix: T4D * Selec4D
310314
Selec4D = self.Selec.view(s, N, s, N)
311-
T4D = self.Tnucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
312-
T4D_compound = Selec4D * T4D
313-
T_compound = T4D_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
315+
T4D = self.Tnucl[None, :, None, :]
316+
T_compound = (Selec4D * T4D).reshape(ndim, ndim)
314317

315318
Cvec = self.Ccurr
316319

@@ -327,11 +330,10 @@ def compute_potential_energy(self):
327330
N, s, ndim = self.ngrids, self.nstates, self.ndim
328331

329332
Selec4D = self.Selec.view(s, N, s, N)
330-
S4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
333+
S4D = self.Snucl[None, :, None, :]
331334
Ej4D = self.E[None, None, :, :] # (1,1,j,m)
332335

333-
V4D_compound = Selec4D * (Ej4D * S4D)
334-
V_compound = V4D_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
336+
V_compound = (Selec4D * (Ej4D * S4D)).reshape(ndim, ndim)
335337

336338
Cvec = self.Ccurr
337339

@@ -351,7 +353,31 @@ def compute_total_energy(self):
351353
denom = torch.vdot(Cvec, self.S @ Cvec).real
352354

353355
return numer / denom
356+
357+
def compute_average_pos(self):
358+
"""
359+
Compute average position as <q_i> = \sum_i C^+ Q C / C^+ S C for a single step.
360+
"""
361+
N, s, ndim = self.ngrids, self.nstates, self.ndim
354362

363+
Cvec = self.Ccurr
364+
365+
denom = torch.vdot(Cvec, self.S @ Cvec).real
366+
Selec4D = self.Selec.view(s, N, s, N)
367+
368+
avg_q = []
369+
for idof in range(self.ndof):
370+
q_med = 0.5 * (self.qgrid[:, None, idof] + self.qgrid[None,:,idof])
371+
Qnucl = self.Snucl * q_med
372+
Q4D = Qnucl[None, :, None, :]
373+
Q4D_compound = Selec4D * Q4D
374+
Q_compound = Q4D_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
375+
376+
numer = torch.vdot(Cvec, Q_compound @ Cvec).real
377+
avg_q.append(numer / denom)
378+
379+
return avg_q
380+
355381
def save(self):
356382
torch.save( {"q0":self.q0,
357383
"p0":self.p0,
@@ -376,6 +402,7 @@ def save(self):
376402
"kinetic_energy":self.kinetic_energy,
377403
"potential_energy":self.potential_energy,
378404
"total_energy":self.total_energy,
405+
"average_pos":self.average_pos,
379406
"population_right":self.population_right,
380407
"denmat":self.denmat,
381408
"norm":self.norm

0 commit comments

Comments
 (0)