@@ -49,6 +49,7 @@ def __init__(self, params):
4949 self .alpha = torch .tensor (params .get ("alpha" , [18.0 ]), dtype = torch .float64 , device = self .device )
5050 self .qgrid = torch .tensor (params .get ("qgrid" , [[- 10 + i * 0.1 ] for i in range (int ((10 - (- 10 )) / 0.1 ) + 1 )] ), dtype = torch .float64 , device = self .device ) #(N, D)
5151 self .ngrids = len (self .qgrid ) # N
52+ self .ndof = self .qgrid .shape [1 ]
5253 self .nstates = params .get ("nstates" , 2 )
5354 self .istate = params .get ("istate" , 0 )
5455 self .elec_ampl = params .get ("elec_ampl" , torch .tensor ([1.0 + 0.j ]* self .ngrids , dtype = torch .cdouble ))
@@ -82,6 +83,7 @@ def __init__(self, params):
8283 self .kinetic_energy = []
8384 self .potential_energy = []
8485 self .total_energy = []
86+ self .average_pos = []
8587 self .population_right = []
8688 self .denmat = []
8789 self .norm = []
@@ -97,7 +99,7 @@ def chi_overlap(self):
9799 self .Snucl = torch .exp (exponent )
98100
99101 def chi_kinetic (self ):
100- r """
102+ """
101103 Compute nuclear kinetic energy matrix Tnucl[i,j] = <g(x; qgrid[i]) | T | g(x; qgrid[j])>,
102104 with T = \sum_{\n u} -0.5* m_ν^{-1} \partial^{2}/\partial x_{\n u}^2.
103105 """
@@ -106,7 +108,7 @@ def chi_kinetic(self):
106108 tau_sum = torch .sum (tau , dim = 2 ) # (N, N)
107109
108110 self .Tnucl = self .Snucl * tau_sum # (N, N)
109-
111+
110112 def build_compound_overlap (self ):
111113 """
112114 Build the compound nuclear-electronic overlap matrix self.S (ndim, ndim)
@@ -263,6 +265,8 @@ def save_results(self, step):
263265 self .potential_energy .append (self .compute_potential_energy ())
264266 if "total_energy" in self .properties_to_save :
265267 self .total_energy .append (self .compute_total_energy ())
268+ if "average_pos" in self .properties_to_save :
269+ self .average_pos .append (self .compute_average_pos ())
266270 if "C_save" in self .properties_to_save :
267271 self .C_save .append (self .Ccurr )
268272
@@ -308,9 +312,8 @@ def compute_kinetic_energy(self):
308312
309313 # Rebuild compound kinetic matrix: T4D * Selec4D
310314 Selec4D = self .Selec .view (s , N , s , N )
311- T4D = self .Tnucl .unsqueeze (0 ).unsqueeze (2 ) # (1, n, 1, m)
312- T4D_compound = Selec4D * T4D
313- T_compound = T4D_compound .permute (0 , 1 , 2 , 3 ).reshape (ndim , ndim )
315+ T4D = self .Tnucl [None , :, None , :]
316+ T_compound = (Selec4D * T4D ).reshape (ndim , ndim )
314317
315318 Cvec = self .Ccurr
316319
@@ -327,11 +330,10 @@ def compute_potential_energy(self):
327330 N , s , ndim = self .ngrids , self .nstates , self .ndim
328331
329332 Selec4D = self .Selec .view (s , N , s , N )
330- S4D = self .Snucl . unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, n, 1, m)
333+ S4D = self .Snucl [ None , :, None , :]
331334 Ej4D = self .E [None , None , :, :] # (1,1,j,m)
332335
333- V4D_compound = Selec4D * (Ej4D * S4D )
334- V_compound = V4D_compound .permute (0 , 1 , 2 , 3 ).reshape (ndim , ndim )
336+ V_compound = (Selec4D * (Ej4D * S4D )).reshape (ndim , ndim )
335337
336338 Cvec = self .Ccurr
337339
@@ -351,7 +353,31 @@ def compute_total_energy(self):
351353 denom = torch .vdot (Cvec , self .S @ Cvec ).real
352354
353355 return numer / denom
356+
357+ def compute_average_pos (self ):
358+ """
359+ Compute average position as <q_i> = \sum_i C^+ Q C / C^+ S C for a single step.
360+ """
361+ N , s , ndim = self .ngrids , self .nstates , self .ndim
354362
363+ Cvec = self .Ccurr
364+
365+ denom = torch .vdot (Cvec , self .S @ Cvec ).real
366+ Selec4D = self .Selec .view (s , N , s , N )
367+
368+ avg_q = []
369+ for idof in range (self .ndof ):
370+ q_med = 0.5 * (self .qgrid [:, None , idof ] + self .qgrid [None ,:,idof ])
371+ Qnucl = self .Snucl * q_med
372+ Q4D = Qnucl [None , :, None , :]
373+ Q4D_compound = Selec4D * Q4D
374+ Q_compound = Q4D_compound .permute (0 , 1 , 2 , 3 ).reshape (ndim , ndim )
375+
376+ numer = torch .vdot (Cvec , Q_compound @ Cvec ).real
377+ avg_q .append (numer / denom )
378+
379+ return avg_q
380+
355381 def save (self ):
356382 torch .save ( {"q0" :self .q0 ,
357383 "p0" :self .p0 ,
@@ -376,6 +402,7 @@ def save(self):
376402 "kinetic_energy" :self .kinetic_energy ,
377403 "potential_energy" :self .potential_energy ,
378404 "total_energy" :self .total_energy ,
405+ "average_pos" :self .average_pos ,
379406 "population_right" :self .population_right ,
380407 "denmat" :self .denmat ,
381408 "norm" :self .norm
0 commit comments