Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/latte_interface/bomd/input.in
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ Threshold= 1.0E-7
#CoordsFile= coords_2955.pdb
CoordsFile= coords_100.pdb
#CoordsFile= water_128.xyz
GraphThreshold= 0.01
GraphThreshold= 0.0001
MaxDeg= 500 #Max graph degree
Rcut= 3.0 #Radius cutoff
PartitionType= Metis
#PartitionType= Regular
#PartitionType= MinCut
NumParts= 1
SCFTol= 1.0E-4
NumParts= 4
SCFTol= 1.0E-7
Overlap= True
ElectronicTemperature= 1000
MuCalculationType= #FromParts, Dynamical, None
Expand Down
125 changes: 100 additions & 25 deletions examples/latte_interface/bomd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
import torch
import numpy as np
import gc
from copy import deepcopy

torch.set_default_dtype(torch.float64)

from sedacs.driver.init import init, available_device
from sedacs.graph_partition import get_coreHaloIndices, graph_partition
from sedacs.driver.graph_kernel_byparts import get_kernel_byParts, apply_kernel_byParts, rankN_update_byParts
from sedacs.driver.graph_adaptive_scf import get_adaptiveSCFDM
from sedacs.driver.graph_adaptive_sp_energy_forces import get_adaptive_sp_energy_forces
from sedacs.file_io import read_latte_tbparams
Expand Down Expand Up @@ -61,6 +64,8 @@ def main(args):
Temperature = args.temp
# If we want to run shadow md
shadow_md = args.shadow_md
# If we want to use kernel
use_kernel = args.use_kernel
# Initialize periodic table
pt = PeriodicTable()
# Get the atomic symbols for each atom in the system
Expand All @@ -72,6 +77,7 @@ def main(args):
# Get the Hubbard U values for each atom in the system
Hubbard_U = [latte_tbparams[symbol]["HubbardU"] for symbol in sy.symbols]
Hubbard_U = np.array(Hubbard_U)[sy.types]
sy.hubbard_u = Hubbard_U
# Get the atomic masses for each atom in the system
Mnuc = [pt.mass[pt.get_atomic_number(symbol)] for symbol in sy.symbols]
Mnuc = np.array(Mnuc)[sy.types]
Expand All @@ -84,15 +90,15 @@ def main(args):
)
# Read the coordinates as tensors
coords = torch.tensor(sy.coords)

# Perform a graph-adaptive calculation of the charges with SCF cycles
graphDH, sy.charges, mu, parts, subSysOnRank = get_adaptiveSCFDM(
graphDH, sy.charges, mu, parts, partsCoreHalo, subSysOnRank = get_adaptiveSCFDM(
sdc, eng, comm, rank, numranks, sy, hindex, graphNL, mu
)
#breakpoint()
# Perform a single-point graph-adaptive calculation of the energy and forces
graphDH, charges, EPOT, FTOT, mu, parts, subSysOnRank = (
graphDH, charges, EPOT, FTOT, mu, parts, partsCoreHalo, subSysOnRank = (
get_adaptive_sp_energy_forces(
sdc, eng, comm, rank, numranks, sy, hindex, graphDH, mu
sdc, eng, comm, rank, numranks, sy, parts, partsCoreHalo, hindex, graphDH, mu
)
)
# Convert the charges to a tensor
Expand Down Expand Up @@ -134,6 +140,7 @@ def main(args):
# Record unwrapped coordsinates
unwrap_coords = coords.clone().detach().double()

renew = 0
# MAIN MD LOOP {dR2(0)/dt2: V(0)->V(1/2); dn2(0)/dt2: n(0)->n(1); V(1/2): R(0)->R(1); dR2(1)/dt2: V(1/2)->V(1)}
for MD_step in range(MD_Iter):
# Calculate kinetic energy from particle velocities
Expand All @@ -154,7 +161,7 @@ def main(args):
# Here we record the time, temperature, and charges. Note that the last term, q, would be constant if not solving exact charges during MD
with torch.no_grad():
Energy_dat.write(
f"{Time/1000:<16.8f} {ETOT:<16.16f} {Temperature:<16.8f} {EKIN.item():<16.16f} {EPOT.item():<16.16f}\n"
f"{Time/1000:<16.8f} {ETOT:<16.16f} {Temperature:<16.8f} {EKIN.item():<16.16f} {EPOT.item():<16.16f} {torch.sum(q).item():<16.16f} {torch.sum(n_0).item():<16.16f} {mu:<16.16f}\n"
)

# Here we dump the MD trajectory
Expand All @@ -173,12 +180,50 @@ def main(args):
comm.Allreduce(coords, coords_sum, op=MPI.SUM)
coords = coords_sum / numranks

# Caculate the residual between q[n] and n
Res = q - n_0 # Note that n_0 is n from the previous step
# if use_kernel:
if use_kernel:
if MD_step == 0 or renew == 1:
get_kernel_byParts(sdc, rank, numranks, parts, partsCoreHalo, sy, mu)
syk = deepcopy(sy)
syk.subSy_list = deepcopy(sy.subSy_list)
for i, subSy in enumerate(syk.subSy_list):
subSy.ker = deepcopy(sy.subSy_list[i].ker)
partsk = deepcopy(parts)
partsCoreHalok = deepcopy(partsCoreHalo)
#breakpoint()
#KK0 = torch.tensor(sy.subSy_list[0].ker)
#KK0 = torch.tensor(collect_kernel_byParts(
# q, n_0, sdc, rank, numranks, comm, parts, partsCoreHalo, sy
#))
#dn2dt2 = -torch.matmul(KK0, Res)
dn2dt2 = 0
# ker = sy.subSy_list[0].ker
#else:
# sy.subSy_list[0].ker = ker
renew = 0
if MD_step > 0:
for i, subSy in enumerate(sy.subSy_list):
subSy.ker = deepcopy(syk.subSy_list[i].ker)
dn2dt2 = -rankN_update_byParts(
q, n_0, 6, sdc, rank, numranks, comm, parts, partsCoreHalo, sy, mu=mu
)
#dn2dt2 = -rankN_update_byParts(
# q, n_0, 6, sdc, eng, rank, numranks, comm, partsk, partsCoreHalok, syk, hindex, mu=mu
# )
#dn2dt2 = -torch.tensor(apply_kernel_byParts(
# q, n_0, sdc, rank, numranks, comm, partsk, syk
#))
#breakpoint()
else:
dn2dt2 = 0.8 * Res
# Propagating charge vector n for a better initial guess
# Or Propagating charge vector for shadow MD
n = (
2 * n_0
- n_1
+ 0.8 * kappa * (q - n_0)
+ kappa * dn2dt2
+ alpha
* (
C0 * n_0
Expand Down Expand Up @@ -210,50 +255,74 @@ def main(args):
# Update sy.coords in the system object
sy.coords = coords.numpy()
# Update neighbor list
nl, nlTrX, nlTrY, nlTrZ = build_nlist(
sy.coords,
sy.latticeVectors,
sdc.rcut,
api="old",
rank=rank,
numranks=numranks,
verb=False,
)
comm.Barrier()
#nl, nlTrX, nlTrY, nlTrZ = build_nlist(
# sy.coords,
# sy.latticeVectors,
# sdc.rcut,
# api="old",
# rank=rank,
# numranks=numranks,
# verb=False,
#)
#comm.Barrier()
# Create initial graph based on distances
if rank == 0:
graphNL = get_initial_graph(sy.coords, nl, sdc.rcut, sdc.maxDeg)
graphNL = comm.bcast(graphNL, root=0)
#if rank == 0:
# graphNL = get_initial_graph(sy.coords, nl, sdc.rcut, sdc.maxDeg)
#graphNL = comm.bcast(graphNL, root=0)
if not shadow_md:
# Perform a graph-adaptive calculation of the charges with SCF cycles
graphDH, sy.charges, mu, parts, subSysOnRank = get_adaptiveSCFDM(
sdc, eng, comm, rank, numranks, sy, hindex, graphNL, mu
sdc, eng, comm, rank, numranks, sy, hindex, graphDH, mu
)
#else:
# graphDH = graphNL

if MD_step % 100 == 99:
# Partition the graph
#parts = graph_partition(
# sdc, eng, graphDH, sdc.partitionType, sdc.nparts, sy.coords, True
#)

renew = 1

njumps = 1
partsCoreHalo = []
numCores = []
for i in range(sdc.nparts):
coreHalo, nc, nh = get_coreHaloIndices(parts[i], graphDH, njumps)
partsCoreHalo.append(coreHalo)
numCores.append(nc)
print("MD_step, core,halo size:", MD_step, i, "=", nc, nh)
# Perform a single-point graph-adaptive calculation of the energy and forces
graphDH, sy.charges, EPOT, FTOT, mu, parts, subSysOnRank = (
graphDH, sy.charges, EPOT, FTOT, mu, parts, partsCoreHalo, subSysOnRank = (
get_adaptive_sp_energy_forces(
sdc,
eng,
comm,
rank,
numranks,
sy,
parts,
partsCoreHalo,
hindex,
graphDH,
mu,
shadow_md=shadow_md,
)
)
q = torch.tensor(sy.charges)
# Constant shift in charges to maintain exact charge neutrality
#q = q - (torch.sum(q)/len(q))
# Convert the energy and forces to tensors
EPOT = torch.tensor(EPOT)
FTOT = torch.tensor(FTOT)

# dR2(1)/dt2: V(1/2)->V(1)
V = V + 0.5 * dt * F2V * FTOT / Mnuc.unsqueeze(1)

MD_xyz.close()
Energy_dat.close()

if rank == 0:
MD_xyz.close()
Energy_dat.close()


if __name__ == "__main__":
Expand All @@ -272,7 +341,7 @@ def main(args):
default="input.in",
)
parser.add_argument(
"--md_iter", type=int, default=100000, help="Number of timesteps"
"--md_iter", type=int, default=10000, help="Number of timesteps"
)
parser.add_argument("--dt", type=float, default=0.5, help="Timestep size (fs)")
parser.add_argument(
Expand All @@ -287,6 +356,12 @@ def main(args):
default=1,
help="Set to 1/0 to enable/disable shadow MD",
)
parser.add_argument(
"--use_kernel",
type=int,
default=1,
help="Set to 1/0 to enable/disable kernel calculation",
)
args = parser.parse_args()
if args.use_torch:
args.device = available_device()
Expand Down
Loading
Loading