@@ -120,12 +120,12 @@ def build_compound_overlap(self):
120120 # b = j * N + m
121121 s_elec_4d = self .s_elec .view (s , N , s , N ) # (i, n, j, m)
122122
123- s_nucl_4d = self .s_nucl . unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, n, 1, m)
123+ s_nucl_4d = self .s_nucl [ None , :, None , :] # (1, n, 1, m)
124124
125125 S_4d = s_elec_4d * s_nucl_4d
126126
127127 # Reshape back to (ndim, ndim) with compound indices
128- self .S = S_4d .permute ( 0 , 1 , 2 , 3 ). reshape (ndim , ndim )
128+ self .S = S_4d .reshape (ndim , ndim )
129129
130130 def build_compound_hamiltonian (self ):
131131 """
@@ -134,8 +134,8 @@ def build_compound_hamiltonian(self):
134134 N , s , ndim = self .ngrids , self .nstates , self .ndim
135135 scheme = self .hamiltonian_scheme
136136 s_elec_4d = self .s_elec .view (s , N , s , N ) # (s, N, s, N)
137- T_4d = self .t_nucl . unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, N, 1, N)
138- S_4d = self .s_nucl . unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, N, 1, N)
137+ T_4d = self .t_nucl [ None , :, None , :] # (1, N, 1, N)
138+ S_4d = self .s_nucl [ None , :, None , :] # (1, N, 1, N)
139139
140140 if scheme == 'as_is' : # For showing the original non-Hermitian form, not intended to use
141141 E_j_4d = self .E [None , None , :, :] # (1, 1, s, N)
@@ -147,8 +147,8 @@ def build_compound_hamiltonian(self):
147147 bracket_4d = T_4d + E_avg_4d * S_4d
148148 elif scheme == 'diagonal' :
149149 # Build Kronecker deltas for electronic and nuclear indices
150- delta_ij = torch .eye (s , device = self .device ). unsqueeze ( 1 ). unsqueeze ( 3 ) # (s, 1, s, 1)
151- delta_nm = torch .eye (N , device = self .device ). unsqueeze ( 0 ). unsqueeze ( 2 ) # (1, N, 1, N)
150+ delta_ij = torch .eye (s , device = self .device )[:, None , :, None ] # (s, 1, s, 1)
151+ delta_nm = torch .eye (N , device = self .device )[ None , :, None , :] # (1, N, 1, N)
152152 delta_4d = delta_ij * delta_nm
153153
154154 E_j_4d = self .E [None , None , :, :] # (1, 1, s, N)
@@ -371,7 +371,7 @@ def compute_average_pos(self):
371371 q_nucl = self .s_nucl * q_med
372372 Q_4d = q_nucl [None , :, None , :]
373373 Q_4d_compound = s_elec_4d * Q_4d
374- Q_compound = Q_4d_compound .permute ( 0 , 1 , 2 , 3 ). reshape (ndim , ndim )
374+ Q_compound = Q_4d_compound .reshape (ndim , ndim )
375375
376376 numer = torch .vdot (C_vec , Q_compound @ C_vec ).real
377377 avg_q .append (numer / denom )
0 commit comments