@@ -435,11 +435,43 @@ def _parse_xyz_coordinates(xyz_file: Union[str, Path]) -> np.ndarray:
435435 return np .asarray (coords , dtype = float )
436436
437437
438+ def _parse_restart_template_names_by_iface_count (lines : list [str ]) -> Dict [int , str ]:
439+ """Infer molecule names from the MolTemplates section keyed by unique interface counts."""
440+ molecule_section_idx = None
441+ for idx , line in enumerate (lines ):
442+ stripped = line .strip ()
443+ if stripped == "#All Molecules and coordinates" :
444+ molecule_section_idx = idx
445+ break
446+
447+ if molecule_section_idx is None :
448+ return {}
449+
450+ names_by_iface_count : Dict [int , set [str ]] = defaultdict (set )
451+ for line in lines [:molecule_section_idx ]:
452+ parts = line .split ()
453+ if (
454+ len (parts ) != 3
455+ or not parts [0 ].isdigit ()
456+ or not parts [2 ].isdigit ()
457+ or not parts [1 ].isalpha ()
458+ ):
459+ continue
460+ names_by_iface_count [int (parts [2 ])].add (parts [1 ])
461+
462+ return {
463+ iface_count : next (iter (names ))
464+ for iface_count , names in names_by_iface_count .items ()
465+ if len (names ) == 1
466+ }
467+
468+
438469def _parse_restart_snapshot (
439470 restart_file : Union [str , Path ],
440- ) -> Tuple [Dict [int , set [int ]], Dict [int , Tuple [float , float , float ]]]:
441- """Parse molecule connectivity and COM coordinates from a NERDSS restart file."""
471+ ) -> Tuple [Dict [int , set [int ]], Dict [int , Tuple [float , float , float ]], Dict [ int , str ] ]:
472+ """Parse molecule connectivity, COM coordinates, and inferred molecule names from a NERDSS restart file."""
442473 lines = Path (restart_file ).read_text (encoding = "utf-8" , errors = "replace" ).splitlines ()
474+ template_names_by_iface_count = _parse_restart_template_names_by_iface_count (lines )
443475
444476 start_idx = None
445477 for idx , line in enumerate (lines ):
@@ -456,6 +488,7 @@ def _parse_restart_snapshot(
456488
457489 adjacency : Dict [int , set [int ]] = defaultdict (set )
458490 restart_coords : Dict [int , Tuple [float , float , float ]] = {}
491+ restart_mol_names : Dict [int , str ] = {}
459492 idx = start_idx + 2
460493 for _ in range (molecule_count ):
461494 header = lines [idx ].split ()
@@ -502,6 +535,7 @@ def _parse_restart_snapshot(
502535 if not iface_count_line :
503536 raise ValueError (f"Malformed interface count in restart file near line { idx + 1 } " )
504537 iface_count = int (iface_count_line [0 ])
538+ restart_mol_names [mol_id ] = template_names_by_iface_count .get (iface_count , str (iface_count ))
505539 idx += 1
506540
507541 if bound_list_size != partner_count :
@@ -542,15 +576,45 @@ def _parse_restart_snapshot(
542576 raise ValueError (f"Truncated { list_name } list in restart file near line { idx + 1 } " )
543577 idx += 1
544578
545- return adjacency , restart_coords
579+ return adjacency , restart_coords , restart_mol_names
546580
547581
548582def _parse_restart_molecule_partners (restart_file : Union [str , Path ]) -> Dict [int , set [int ]]:
549583 """Parse final-frame molecule connectivity from a NERDSS restart file."""
550- adjacency , _ = _parse_restart_snapshot (restart_file )
584+ adjacency , _ , _ = _parse_restart_snapshot (restart_file )
551585 return adjacency
552586
553587
588+ def _find_matching_components (
589+ adjacency : Mapping [int , set [int ]],
590+ mol_id_to_name : Mapping [int , str ],
591+ target_counts : Mapping [str , int ],
592+ ) -> list [list [int ]]:
593+ """Return connected components whose composition matches the target counts."""
594+ matching_components : list [list [int ]] = []
595+ visited : set [int ] = set ()
596+ for mol_id in sorted (mol_id_to_name ):
597+ if mol_id in visited :
598+ continue
599+
600+ component : list [int ] = []
601+ stack = [mol_id ]
602+ visited .add (mol_id )
603+ while stack :
604+ current = stack .pop ()
605+ component .append (current )
606+ for neighbor in sorted (adjacency .get (current , ())):
607+ if neighbor in mol_id_to_name and neighbor not in visited :
608+ visited .add (neighbor )
609+ stack .append (neighbor )
610+
611+ component_counts = Counter (mol_id_to_name [node_id ] for node_id in component )
612+ if dict (component_counts ) == dict (target_counts ):
613+ matching_components .append (sorted (component ))
614+
615+ return matching_components
616+
617+
554618def _restart_snapshot_sort_key (restart_path : Path ) -> Tuple [Tuple [int , ...], str ]:
555619 """Sort restart snapshot files from earliest to latest by embedded numeric suffixes."""
556620 numeric_parts = tuple (int (part ) for part in re .findall (r"\d+" , restart_path .stem ))
@@ -589,12 +653,12 @@ def _extract_observed_com_coordinates(
589653 xyz_coords = None
590654 if final_coords_file is not None and Path (final_coords_file ).exists ():
591655 xyz_coords = _parse_xyz_coordinates (final_coords_file )
592- adjacency , restart_coords = _parse_restart_snapshot (restart_file )
656+ adjacency , restart_coords , restart_mol_names = _parse_restart_snapshot (restart_file )
593657
594- mol_id_to_name : Dict [int , str ] = {}
595658 mol_id_to_coord : Dict [int , Tuple [float , float , float ]] = {}
659+ psf_mol_id_to_name : Dict [int , str ] = {}
596660 for atom_index , mol_id , mol_name in com_records :
597- mol_id_to_name [mol_id ] = mol_name
661+ psf_mol_id_to_name [mol_id ] = mol_name
598662 if mol_id in restart_coords :
599663 mol_id_to_coord [mol_id ] = restart_coords [mol_id ]
600664 elif xyz_coords is not None :
@@ -603,26 +667,17 @@ def _extract_observed_com_coordinates(
603667 raise ValueError (f"Missing coordinates for molecule id { mol_id } in restart snapshot { restart_file } " )
604668 adjacency .setdefault (mol_id , set ())
605669
606- matching_components : list [list [int ]] = []
607- visited : set [int ] = set ()
608- for mol_id in sorted (mol_id_to_name ):
609- if mol_id in visited :
610- continue
611-
612- component : list [int ] = []
613- stack = [mol_id ]
614- visited .add (mol_id )
615- while stack :
616- current = stack .pop ()
617- component .append (current )
618- for neighbor in sorted (adjacency .get (current , ())):
619- if neighbor in mol_id_to_name and neighbor not in visited :
620- visited .add (neighbor )
621- stack .append (neighbor )
670+ name_maps_to_try : list [Dict [int , str ]] = [psf_mol_id_to_name ]
671+ if restart_mol_names and restart_mol_names != psf_mol_id_to_name :
672+ name_maps_to_try .append (restart_mol_names )
622673
623- component_counts = Counter (mol_id_to_name [node_id ] for node_id in component )
624- if dict (component_counts ) == dict (target_counts ):
625- matching_components .append (sorted (component ))
674+ matching_components : list [list [int ]] = []
675+ selected_mol_id_to_name = psf_mol_id_to_name
676+ for candidate_name_map in name_maps_to_try :
677+ matching_components = _find_matching_components (adjacency , candidate_name_map , target_counts )
678+ if matching_components :
679+ selected_mol_id_to_name = candidate_name_map
680+ break
626681
627682 if not matching_components :
628683 raise ValueError (
@@ -641,7 +696,7 @@ def _extract_observed_com_coordinates(
641696 observed = {}
642697 type_counts : Dict [str , int ] = {}
643698 for mol_id in selected_component :
644- mol_name = mol_id_to_name [mol_id ]
699+ mol_name = selected_mol_id_to_name [mol_id ]
645700 copy_idx = type_counts .get (mol_name , 0 )
646701 type_counts [mol_name ] = copy_idx + 1
647702
0 commit comments