Skip to content

Commit f845f71

Browse files
When psf based method does not work fall back to restart file only parsing
1 parent a56a6f0 commit f845f71

File tree

2 files changed

+151
-27
lines changed

2 files changed

+151
-27
lines changed

ionerdss/model/pdb/structure_validation.py

Lines changed: 82 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -435,11 +435,43 @@ def _parse_xyz_coordinates(xyz_file: Union[str, Path]) -> np.ndarray:
435435
return np.asarray(coords, dtype=float)
436436

437437

438+
def _parse_restart_template_names_by_iface_count(lines: list[str]) -> Dict[int, str]:
439+
"""Infer molecule names from the MolTemplates section keyed by unique interface counts."""
440+
molecule_section_idx = None
441+
for idx, line in enumerate(lines):
442+
stripped = line.strip()
443+
if stripped == "#All Molecules and coordinates":
444+
molecule_section_idx = idx
445+
break
446+
447+
if molecule_section_idx is None:
448+
return {}
449+
450+
names_by_iface_count: Dict[int, set[str]] = defaultdict(set)
451+
for line in lines[:molecule_section_idx]:
452+
parts = line.split()
453+
if (
454+
len(parts) != 3
455+
or not parts[0].isdigit()
456+
or not parts[2].isdigit()
457+
or not parts[1].isalpha()
458+
):
459+
continue
460+
names_by_iface_count[int(parts[2])].add(parts[1])
461+
462+
return {
463+
iface_count: next(iter(names))
464+
for iface_count, names in names_by_iface_count.items()
465+
if len(names) == 1
466+
}
467+
468+
438469
def _parse_restart_snapshot(
439470
restart_file: Union[str, Path],
440-
) -> Tuple[Dict[int, set[int]], Dict[int, Tuple[float, float, float]]]:
441-
"""Parse molecule connectivity and COM coordinates from a NERDSS restart file."""
471+
) -> Tuple[Dict[int, set[int]], Dict[int, Tuple[float, float, float]], Dict[int, str]]:
472+
"""Parse molecule connectivity, COM coordinates, and inferred molecule names from a NERDSS restart file."""
442473
lines = Path(restart_file).read_text(encoding="utf-8", errors="replace").splitlines()
474+
template_names_by_iface_count = _parse_restart_template_names_by_iface_count(lines)
443475

444476
start_idx = None
445477
for idx, line in enumerate(lines):
@@ -456,6 +488,7 @@ def _parse_restart_snapshot(
456488

457489
adjacency: Dict[int, set[int]] = defaultdict(set)
458490
restart_coords: Dict[int, Tuple[float, float, float]] = {}
491+
restart_mol_names: Dict[int, str] = {}
459492
idx = start_idx + 2
460493
for _ in range(molecule_count):
461494
header = lines[idx].split()
@@ -502,6 +535,7 @@ def _parse_restart_snapshot(
502535
if not iface_count_line:
503536
raise ValueError(f"Malformed interface count in restart file near line {idx + 1}")
504537
iface_count = int(iface_count_line[0])
538+
restart_mol_names[mol_id] = template_names_by_iface_count.get(iface_count, str(iface_count))
505539
idx += 1
506540

507541
if bound_list_size != partner_count:
@@ -542,15 +576,45 @@ def _parse_restart_snapshot(
542576
raise ValueError(f"Truncated {list_name} list in restart file near line {idx + 1}")
543577
idx += 1
544578

545-
return adjacency, restart_coords
579+
return adjacency, restart_coords, restart_mol_names
546580

547581

548582
def _parse_restart_molecule_partners(restart_file: Union[str, Path]) -> Dict[int, set[int]]:
549583
"""Parse final-frame molecule connectivity from a NERDSS restart file."""
550-
adjacency, _ = _parse_restart_snapshot(restart_file)
584+
adjacency, _, _ = _parse_restart_snapshot(restart_file)
551585
return adjacency
552586

553587

588+
def _find_matching_components(
589+
adjacency: Mapping[int, set[int]],
590+
mol_id_to_name: Mapping[int, str],
591+
target_counts: Mapping[str, int],
592+
) -> list[list[int]]:
593+
"""Return connected components whose composition matches the target counts."""
594+
matching_components: list[list[int]] = []
595+
visited: set[int] = set()
596+
for mol_id in sorted(mol_id_to_name):
597+
if mol_id in visited:
598+
continue
599+
600+
component: list[int] = []
601+
stack = [mol_id]
602+
visited.add(mol_id)
603+
while stack:
604+
current = stack.pop()
605+
component.append(current)
606+
for neighbor in sorted(adjacency.get(current, ())):
607+
if neighbor in mol_id_to_name and neighbor not in visited:
608+
visited.add(neighbor)
609+
stack.append(neighbor)
610+
611+
component_counts = Counter(mol_id_to_name[node_id] for node_id in component)
612+
if dict(component_counts) == dict(target_counts):
613+
matching_components.append(sorted(component))
614+
615+
return matching_components
616+
617+
554618
def _restart_snapshot_sort_key(restart_path: Path) -> Tuple[Tuple[int, ...], str]:
555619
"""Sort restart snapshot files from earliest to latest by embedded numeric suffixes."""
556620
numeric_parts = tuple(int(part) for part in re.findall(r"\d+", restart_path.stem))
@@ -589,12 +653,12 @@ def _extract_observed_com_coordinates(
589653
xyz_coords = None
590654
if final_coords_file is not None and Path(final_coords_file).exists():
591655
xyz_coords = _parse_xyz_coordinates(final_coords_file)
592-
adjacency, restart_coords = _parse_restart_snapshot(restart_file)
656+
adjacency, restart_coords, restart_mol_names = _parse_restart_snapshot(restart_file)
593657

594-
mol_id_to_name: Dict[int, str] = {}
595658
mol_id_to_coord: Dict[int, Tuple[float, float, float]] = {}
659+
psf_mol_id_to_name: Dict[int, str] = {}
596660
for atom_index, mol_id, mol_name in com_records:
597-
mol_id_to_name[mol_id] = mol_name
661+
psf_mol_id_to_name[mol_id] = mol_name
598662
if mol_id in restart_coords:
599663
mol_id_to_coord[mol_id] = restart_coords[mol_id]
600664
elif xyz_coords is not None:
@@ -603,26 +667,17 @@ def _extract_observed_com_coordinates(
603667
raise ValueError(f"Missing coordinates for molecule id {mol_id} in restart snapshot {restart_file}")
604668
adjacency.setdefault(mol_id, set())
605669

606-
matching_components: list[list[int]] = []
607-
visited: set[int] = set()
608-
for mol_id in sorted(mol_id_to_name):
609-
if mol_id in visited:
610-
continue
611-
612-
component: list[int] = []
613-
stack = [mol_id]
614-
visited.add(mol_id)
615-
while stack:
616-
current = stack.pop()
617-
component.append(current)
618-
for neighbor in sorted(adjacency.get(current, ())):
619-
if neighbor in mol_id_to_name and neighbor not in visited:
620-
visited.add(neighbor)
621-
stack.append(neighbor)
670+
name_maps_to_try: list[Dict[int, str]] = [psf_mol_id_to_name]
671+
if restart_mol_names and restart_mol_names != psf_mol_id_to_name:
672+
name_maps_to_try.append(restart_mol_names)
622673

623-
component_counts = Counter(mol_id_to_name[node_id] for node_id in component)
624-
if dict(component_counts) == dict(target_counts):
625-
matching_components.append(sorted(component))
674+
matching_components: list[list[int]] = []
675+
selected_mol_id_to_name = psf_mol_id_to_name
676+
for candidate_name_map in name_maps_to_try:
677+
matching_components = _find_matching_components(adjacency, candidate_name_map, target_counts)
678+
if matching_components:
679+
selected_mol_id_to_name = candidate_name_map
680+
break
626681

627682
if not matching_components:
628683
raise ValueError(
@@ -641,7 +696,7 @@ def _extract_observed_com_coordinates(
641696
observed = {}
642697
type_counts: Dict[str, int] = {}
643698
for mol_id in selected_component:
644-
mol_name = mol_id_to_name[mol_id]
699+
mol_name = selected_mol_id_to_name[mol_id]
645700
copy_idx = type_counts.get(mol_name, 0)
646701
type_counts[mol_name] = copy_idx + 1
647702

ionerdss/tests/test_structure_validation.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,72 @@ def _restart_block(mol_id: int, partners: list[int], coord) -> list[str]:
442442
"H": (1.0, 0.0, 0.0),
443443
"L": (0.0, 1.0, 0.0),
444444
}
445+
446+
447+
def test_observed_structure_extraction_falls_back_to_restart_native_molecule_types():
448+
with tempfile.TemporaryDirectory() as tmpdir:
449+
tmp_path = Path(tmpdir)
450+
psf_path = tmp_path / "system.psf"
451+
restart_path = tmp_path / "restart.dat"
452+
453+
# Deliberately mismap the COM entries so PSF-based composition matching fails.
454+
psf_path.write_text(
455+
"\n".join(
456+
[
457+
"PSF",
458+
"",
459+
" 3 !NATOM",
460+
" 1 A 0 COM O 0 0 1.0 0",
461+
" 2 A 1 COM O 0 0 1.0 0",
462+
" 3 A 2 COM O 0 0 1.0 0",
463+
]
464+
),
465+
encoding="utf-8",
466+
)
467+
468+
def _restart_block(mol_id: int, coord, iface_count: int, partners: list[int]) -> list[str]:
469+
lines = [
470+
f"{mol_id} 0 0 0 0",
471+
"1.0 0 0 0 0 0",
472+
f"{coord[0]} {coord[1]} {coord[2]}",
473+
" ".join([str(iface_count)] + [str(i) for i in range(iface_count)]),
474+
" ".join([str(len(partners))] + [str(i) for i in range(len(partners))]),
475+
" ".join([str(len(partners))] + [str(pid) for pid in partners]),
476+
str(iface_count),
477+
]
478+
479+
bound_iface_indexes = set(range(max(iface_count - len(partners), 0), iface_count))
480+
for iface_idx in range(iface_count):
481+
is_bound = 1 if iface_idx in bound_iface_indexes else 0
482+
lines.append(f"{iface_idx} {iface_idx} 0 0 0 {is_bound}")
483+
lines.append("0.0 0.0 0.0")
484+
if is_bound:
485+
lines.append("0 0 0")
486+
487+
lines.extend(["0", "0", "0", "0", "0", "0"])
488+
return lines
489+
490+
restart_lines = [
491+
"#MolTemplates",
492+
"0 A 4",
493+
"1 B 2",
494+
"#All Molecules and coordinates",
495+
"3 3",
496+
]
497+
restart_lines.extend(_restart_block(0, (0.0, 0.0, 0.0), 4, [2]))
498+
restart_lines.extend(_restart_block(1, (1.0, 0.0, 0.0), 4, [2]))
499+
restart_lines.extend(_restart_block(2, (0.0, 1.0, 0.0), 2, [0, 1]))
500+
restart_path.write_text("\n".join(restart_lines), encoding="utf-8")
501+
502+
observed = _extract_observed_com_coordinates(
503+
system_psf_file=psf_path,
504+
final_coords_file=None,
505+
restart_file=restart_path,
506+
target_counts={"A": 2, "B": 1},
507+
)
508+
509+
assert observed == {
510+
"A_0": (0.0, 0.0, 0.0),
511+
"A_1": (1.0, 0.0, 0.0),
512+
"B": (0.0, 1.0, 0.0),
513+
}

0 commit comments

Comments
 (0)