Skip to content

Commit c3ad2f7

Browse files
committed
Remove the nested structure of the interface function by taking the electronic information as model_params as well
1 parent 93de946 commit c3ad2f7

File tree

1 file changed

+30
-24
lines changed
  • src/libra_py/workflows/nbra

1 file changed

+30
-24
lines changed

src/libra_py/workflows/nbra/rpi.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,32 @@
3434
import libra_py.data_read as data_read
3535
import util.libutil as comn
3636
import libra_py.dynamics.tsh.compute as tsh_dynamics
37+
38+
class tmp:
39+
pass
40+
41+
def compute_model(q, params, full_id):
42+
"""
43+
This function serves as an interface function for a serial patch dynamics calculation.
44+
"""
45+
46+
timestep = params["timestep"]
47+
nst = params["nstates"]
48+
E = params["E"]
49+
NAC = params["NAC"]
50+
Hvib = params["Hvib"]
51+
St = params["St"]
52+
53+
obj = tmp()
54+
55+
obj.ham_adi = data_conv.nparray2CMATRIX( np.diag(E[timestep, : ]) )
56+
obj.nac_adi = data_conv.nparray2CMATRIX( NAC[timestep, :, :] )
57+
obj.hvib_adi = data_conv.nparray2CMATRIX( Hvib[timestep, :, :] )
58+
obj.basis_transform = CMATRIX(nst,nst); obj.basis_transform.identity() #basis_transform
59+
obj.time_overlap_adi = data_conv.nparray2CMATRIX( St[timestep, :, :] )
60+
61+
return obj
62+
3763

3864
def run_patch_dyn_serial(rpi_params, ibatch, ipatch, istate):
3965
"""
@@ -107,7 +133,9 @@ def run_patch_dyn_serial(rpi_params, ibatch, ipatch, istate):
107133

108134
NSTEPS = fstep - istep # The patch duration
109135

110-
# Read the vibronic Hamiltonian and time overlap files in a selected basis
136+
# Set the NBRA model by reading the vibronic Hamiltonian and time overlap data
137+
model_params = {"timestep":0, "icond":0, "model0":0, "nstates":nstates}
138+
111139
E = []
112140
for step in range(istep, fstep):
113141
energy_filename = F"{path_to_save_Hvibs}/Hvib_{basis_type}_{step}_re.npz"
@@ -133,29 +161,7 @@ def run_patch_dyn_serial(rpi_params, ibatch, ipatch, istate):
133161
NAC = np.array(NAC)
134162
Hvib = np.array(Hvib)
135163

136-
# The interface function
137-
class tmp:
138-
pass
139-
140-
def compute_model(q, params, full_id):
141-
"""
142-
This function serves as an interface function for a serial patch dynamics calculation.
143-
"""
144-
145-
timestep = params["timestep"]
146-
nst = params["nstates"]
147-
obj = tmp()
148-
149-
obj.ham_adi = data_conv.nparray2CMATRIX( np.diag(E[timestep, : ]) )
150-
obj.nac_adi = data_conv.nparray2CMATRIX( NAC[timestep, :, :] )
151-
obj.hvib_adi = data_conv.nparray2CMATRIX( Hvib[timestep, :, :] )
152-
obj.basis_transform = CMATRIX(nst,nst); obj.basis_transform.identity() #basis_transform
153-
obj.time_overlap_adi = data_conv.nparray2CMATRIX( St[timestep, :, :] )
154-
155-
return obj
156-
157-
# Set the NBRA model
158-
model_params = {"timestep":0, "icond":0, "model0":0, "nstates":nstates}
164+
model_params.update({"E": E, "St": St, "NAC": NAC, "Hvib": Hvib})
159165

160166
# Setting the coherent Ehrenfest propagation. Define the argument-dependent part first.
161167
dyn_general = {"nsteps":NSTEPS, "nstates":nstates, "dt":dt, "nfiles": NSTEPS, "which_adi_states":range(nstates), "which_dia_states":range(nstates)}

0 commit comments

Comments
 (0)