Skip to content

Commit a027be6

Browse files
Update structure validation to apply to
1 parent c02f79d commit a027be6

2 files changed

Lines changed: 213 additions & 32 deletions

File tree

ionerdss/model/pdb/structure_validation.py

Lines changed: 165 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Utilities for "lego assembly" structure validation.
2+
Utilities for structure validation.
33
44
This validation mode reduces the designed assembly to one representative copy
55
per exported molecule type, turns binding effectively irreversible by forcing
@@ -18,6 +18,7 @@
1818
import shutil
1919
import warnings
2020
from collections import Counter, defaultdict
21+
from itertools import permutations, product
2122

2223
import numpy as np
2324

@@ -102,6 +103,158 @@ def _as_xyz_array(coords: CoordinateInput, labels: Optional[Iterable[str]] = Non
102103
return ordered_labels, points
103104

104105

106+
def _strip_designed_label_to_type(label: str) -> str:
107+
"""Reduce ionerdss-style labels like `chain_type` to `type`."""
108+
if "_" not in label:
109+
return label
110+
return label.split("_", 1)[1]
111+
112+
113+
def _strip_observed_label_to_type(label: str) -> str:
114+
"""Reduce NERDSS-style labels like `type_0` to `type`."""
115+
if "_" not in label:
116+
return label
117+
118+
prefix, suffix = label.rsplit("_", 1)
119+
if suffix.isdigit():
120+
return prefix
121+
return label
122+
123+
124+
def _compute_alignment(
125+
designed_xyz: np.ndarray,
126+
observed_xyz: np.ndarray,
127+
*,
128+
backend: str,
129+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
130+
"""Return rigid transform, aligned coordinates, and RMSD."""
131+
if backend == "kabsch":
132+
rotation_matrix, translation_vector = rigid_transform_3d(observed_xyz, designed_xyz)
133+
elif backend == "biopython":
134+
try:
135+
from Bio.SVDSuperimposer import SVDSuperimposer
136+
except ImportError as exc:
137+
raise ImportError(
138+
"Biopython is required for backend='biopython'. Install biopython or use backend='kabsch'."
139+
) from exc
140+
141+
superimposer = SVDSuperimposer()
142+
superimposer.set(designed_xyz, observed_xyz)
143+
superimposer.run()
144+
rotation_matrix, translation_vector = superimposer.get_rotran()
145+
else:
146+
raise ValueError("backend must be 'kabsch' or 'biopython'.")
147+
148+
aligned_observed = apply_rigid_transform(rotation_matrix, translation_vector, observed_xyz)
149+
deltas = aligned_observed - designed_xyz
150+
rmsd = float(np.sqrt(np.mean(np.sum(deltas * deltas, axis=1))))
151+
return rotation_matrix, translation_vector, aligned_observed, rmsd
152+
153+
154+
def _match_coordinate_maps(
155+
designed_coordinates: CoordinateInput,
156+
observed_coordinates: CoordinateInput,
157+
*,
158+
labels: Optional[Iterable[str]] = None,
159+
backend: str,
160+
) -> Tuple[Tuple[str, ...], np.ndarray, np.ndarray]:
161+
"""Return labels and coordinate arrays with homomer-aware key matching for mappings."""
162+
designed_labels, designed_xyz = _as_xyz_array(designed_coordinates, labels=labels)
163+
observed_labels, observed_xyz = _as_xyz_array(observed_coordinates, labels=labels)
164+
165+
if designed_xyz.shape != observed_xyz.shape:
166+
raise ValueError("Designed and observed structures must have the same shape.")
167+
168+
if designed_labels == observed_labels:
169+
return designed_labels, designed_xyz, observed_xyz
170+
171+
if not isinstance(designed_coordinates, Mapping) or not isinstance(observed_coordinates, Mapping):
172+
raise ValueError(
173+
"Designed and observed structures must have the same ordered labels. "
174+
"Pass dictionaries keyed by molecule labels to enable automatic matching."
175+
)
176+
177+
designed_type_by_label = {
178+
label: _strip_designed_label_to_type(label)
179+
for label in designed_labels
180+
}
181+
observed_type_by_label = {
182+
label: _strip_observed_label_to_type(label)
183+
for label in observed_labels
184+
}
185+
186+
designed_type_counts = Counter(designed_type_by_label.values())
187+
observed_type_counts = Counter(observed_type_by_label.values())
188+
if designed_type_counts != observed_type_counts:
189+
raise ValueError(
190+
"Designed and observed structures do not describe the same molecule-type composition after "
191+
"normalizing homomer labels."
192+
)
193+
194+
designed_labels_by_type: Dict[str, list[str]] = defaultdict(list)
195+
observed_labels_by_type: Dict[str, list[str]] = defaultdict(list)
196+
for label in designed_labels:
197+
designed_labels_by_type[designed_type_by_label[label]].append(label)
198+
for label in observed_labels:
199+
observed_labels_by_type[observed_type_by_label[label]].append(label)
200+
201+
for type_name in designed_type_counts:
202+
designed_labels_by_type[type_name].sort()
203+
observed_labels_by_type[type_name].sort()
204+
205+
permutation_sets = [
206+
list(permutations(observed_labels_by_type[type_name]))
207+
for type_name, count in sorted(designed_type_counts.items())
208+
if count > 1
209+
]
210+
repeated_types = [
211+
type_name
212+
for type_name, count in sorted(designed_type_counts.items())
213+
if count > 1
214+
]
215+
216+
best_labels: Optional[Tuple[str, ...]] = None
217+
best_observed_xyz: Optional[np.ndarray] = None
218+
best_rmsd: Optional[float] = None
219+
220+
permutation_products = product(*permutation_sets) if permutation_sets else [()]
221+
for perm_choice in permutation_products:
222+
observed_order_by_type = {
223+
type_name: list(observed_labels_by_type[type_name])
224+
for type_name in designed_type_counts
225+
}
226+
for type_name, permuted_labels in zip(repeated_types, perm_choice):
227+
observed_order_by_type[type_name] = list(permuted_labels)
228+
229+
matched_labels = tuple(designed_type_by_label[label] for label in designed_labels)
230+
matched_observed_labels = []
231+
type_offsets: Dict[str, int] = defaultdict(int)
232+
for designed_label in designed_labels:
233+
type_name = designed_type_by_label[designed_label]
234+
idx = type_offsets[type_name]
235+
matched_observed_labels.append(observed_order_by_type[type_name][idx])
236+
type_offsets[type_name] += 1
237+
238+
candidate_observed_xyz = np.asarray(
239+
[observed_coordinates[label] for label in matched_observed_labels],
240+
dtype=float,
241+
)
242+
_, _, _, candidate_rmsd = _compute_alignment(
243+
designed_xyz,
244+
candidate_observed_xyz,
245+
backend=backend,
246+
)
247+
248+
if best_rmsd is None or candidate_rmsd < best_rmsd:
249+
best_labels = matched_labels
250+
best_observed_xyz = candidate_observed_xyz
251+
best_rmsd = candidate_rmsd
252+
253+
assert best_labels is not None
254+
assert best_observed_xyz is not None
255+
return best_labels, designed_xyz, best_observed_xyz
256+
257+
105258
def get_representative_instances(system: System) -> Dict[str, MoleculeInstance]:
106259
"""Choose one representative instance per molecule type.
107260
@@ -307,37 +460,17 @@ def align_structure_to_design(
307460
plot: bool = False,
308461
) -> StructureAlignmentResult:
309462
"""Rigidly align an observed structure onto the designed target and compute RMSD."""
310-
designed_labels, designed_xyz = _as_xyz_array(designed_coordinates, labels=labels)
311-
observed_labels, observed_xyz = _as_xyz_array(observed_coordinates, labels=labels)
312-
313-
if designed_labels != observed_labels:
314-
raise ValueError(
315-
"Designed and observed structures must have the same ordered labels. "
316-
"Pass dictionaries keyed by molecule type to match automatically."
317-
)
318-
319-
if designed_xyz.shape != observed_xyz.shape:
320-
raise ValueError("Designed and observed structures must have the same shape.")
321-
322-
if backend == "kabsch":
323-
rotation_matrix, translation_vector = rigid_transform_3d(observed_xyz, designed_xyz)
324-
elif backend == "biopython":
325-
try:
326-
from Bio.SVDSuperimposer import SVDSuperimposer
327-
except ImportError as exc:
328-
raise ImportError(
329-
"Biopython is required for backend='biopython'. Install biopython or use backend='kabsch'."
330-
) from exc
331-
332-
superimposer = SVDSuperimposer()
333-
superimposer.set(designed_xyz, observed_xyz)
334-
superimposer.run()
335-
rotation_matrix, translation_vector = superimposer.get_rotran()
336-
else:
337-
raise ValueError("backend must be 'kabsch' or 'biopython'.")
338-
aligned_observed = apply_rigid_transform(rotation_matrix, translation_vector, observed_xyz)
339-
deltas = aligned_observed - designed_xyz
340-
rmsd = float(np.sqrt(np.mean(np.sum(deltas * deltas, axis=1))))
463+
designed_labels, designed_xyz, observed_xyz = _match_coordinate_maps(
464+
designed_coordinates,
465+
observed_coordinates,
466+
labels=labels,
467+
backend=backend,
468+
)
469+
rotation_matrix, translation_vector, aligned_observed, rmsd = _compute_alignment(
470+
designed_xyz,
471+
observed_xyz,
472+
backend=backend,
473+
)
341474

342475
if plot:
343476
try:

ionerdss/tests/test_structure_validation.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,54 @@ def test_align_structure_to_design_recovers_rigid_transform():
163163
assert np.allclose(result.aligned_observed_coordinates, result.designed_coordinates)
164164

165165

166+
def test_align_structure_to_design_keeps_exact_key_matching():
167+
designed = {
168+
"A_0": [0.0, 0.0, 0.0],
169+
"A_1": [1.0, 0.0, 0.0],
170+
"B_0": [0.0, 1.0, 0.0],
171+
}
172+
173+
rotation = np.array(
174+
[
175+
[0.0, -1.0, 0.0],
176+
[1.0, 0.0, 0.0],
177+
[0.0, 0.0, 1.0],
178+
]
179+
)
180+
translation = np.array([2.0, 3.0, -1.0])
181+
182+
designed_xyz = np.asarray([designed[key] for key in sorted(designed)], dtype=float)
183+
observed_xyz = (rotation @ designed_xyz.T).T + translation
184+
observed = {
185+
key: observed_xyz[idx].tolist()
186+
for idx, key in enumerate(sorted(designed))
187+
}
188+
189+
result = align_structure_to_design(designed, observed)
190+
191+
assert result.labels == ("A_0", "A_1", "B_0")
192+
assert result.rmsd < 1e-10
193+
194+
195+
def test_align_structure_to_design_matches_homomer_labels_by_type():
196+
designed = {
197+
"chainA_A": [0.0, 0.0, 0.0],
198+
"chainB_A": [2.0, 0.0, 0.0],
199+
"chainC_B": [0.0, 3.0, 0.0],
200+
}
201+
observed = {
202+
"A_0": [2.0, 0.0, 0.0],
203+
"A_1": [0.0, 0.0, 0.0],
204+
"B_0": [0.0, 3.0, 0.0],
205+
}
206+
207+
result = align_structure_to_design(designed, observed)
208+
209+
assert result.labels == ("A", "A", "B")
210+
assert result.rmsd < 1e-10
211+
assert np.allclose(result.aligned_observed_coordinates, result.designed_coordinates)
212+
213+
166214
def test_validation_module_exposes_prepare_and_compare():
167215
assert hasattr(pdb, "validation")
168216
assert callable(pdb.validation.prepare)

0 commit comments

Comments
 (0)