Skip to content

Commit cf1b912

Browse files
committed
Rename variables and remove redundant permute and unsqueeze methods
1 parent 8739d4a commit cf1b912

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/libra_py/dynamics/ldr_torch/compute.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,12 @@ def build_compound_overlap(self):
120120
# b = j * N + m
121121
s_elec_4d = self.s_elec.view(s, N, s, N) # (i, n, j, m)
122122

123-
s_nucl_4d = self.s_nucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
123+
s_nucl_4d = self.s_nucl[None, :, None, :] # (1, n, 1, m)
124124

125125
S_4d = s_elec_4d * s_nucl_4d
126126

127127
# Reshape back to (ndim, ndim) with compound indices
128-
self.S = S_4d.permute(0, 1, 2, 3).reshape(ndim, ndim)
128+
self.S = S_4d.reshape(ndim, ndim)
129129

130130
def build_compound_hamiltonian(self):
131131
"""
@@ -134,8 +134,8 @@ def build_compound_hamiltonian(self):
134134
N, s, ndim = self.ngrids, self.nstates, self.ndim
135135
scheme = self.hamiltonian_scheme
136136
s_elec_4d = self.s_elec.view(s, N, s, N) # (s, N, s, N)
137-
T_4d = self.t_nucl.unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
138-
S_4d = self.s_nucl.unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
137+
T_4d = self.t_nucl[None, :, None, :] # (1, N, 1, N)
138+
S_4d = self.s_nucl[None, :, None, :] # (1, N, 1, N)
139139

140140
if scheme == 'as_is': # For showing the original non-Hermitian form, not intended to use
141141
E_j_4d = self.E[None, None, :, :] # (1, 1, s, N)
@@ -147,8 +147,8 @@ def build_compound_hamiltonian(self):
147147
bracket_4d = T_4d + E_avg_4d * S_4d
148148
elif scheme == 'diagonal':
149149
# Build Kronecker deltas for electronic and nuclear indices
150-
delta_ij = torch.eye(s, device=self.device).unsqueeze(1).unsqueeze(3) # (s, 1, s, 1)
151-
delta_nm = torch.eye(N, device=self.device).unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
150+
delta_ij = torch.eye(s, device=self.device)[:, None, :, None] # (s, 1, s, 1)
151+
delta_nm = torch.eye(N, device=self.device)[None, :, None, :] # (1, N, 1, N)
152152
delta_4d = delta_ij * delta_nm
153153

154154
E_j_4d = self.E[None, None, :, :] # (1, 1, s, N)
@@ -371,7 +371,7 @@ def compute_average_pos(self):
371371
q_nucl = self.s_nucl * q_med
372372
Q_4d = q_nucl[None, :, None, :]
373373
Q_4d_compound = s_elec_4d * Q_4d
374-
Q_compound = Q_4d_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
374+
Q_compound = Q_4d_compound.reshape(ndim, ndim)
375375

376376
numer = torch.vdot(C_vec, Q_compound @ C_vec).real
377377
avg_q.append(numer / denom)

0 commit comments

Comments
 (0)