2 changes: 1 addition & 1 deletion docs/examples/load_and_visualize_structures.py
@@ -82,7 +82,7 @@
print("Available annotations:")
annotations = atom_array.get_annotation_categories()
for i, annotation in enumerate(annotations):
print(f" {i+1:2d}. {annotation}")
print(f" {i + 1:2d}. {annotation}")


# %%
2 changes: 1 addition & 1 deletion src/atomworks/biotite_patch.py
@@ -184,7 +184,7 @@ def array(atoms: list[Atom]) -> AtomArray:
for i, atom in enumerate(atoms):
if sorted(atom._annot.keys()) != names:
raise ValueError(
f"The atom at index {i} does not share the same " f"annotation categories as the atom at index 0"
f"The atom at index {i} does not share the same annotation categories as the atom at index 0"
)
array = AtomArray(len(atoms))

13 changes: 8 additions & 5 deletions src/atomworks/io/parser.py
@@ -125,7 +125,7 @@ def _build_cache_file_path(
def parse(
filename: os.PathLike | io.StringIO | io.BytesIO,
*,
file_type: Literal["cif", "pdb"] | None = None,
file_type: Literal["cif", "pdb", "mmjson"] | None = None,
ccd_mirror_path: os.PathLike | None = CCD_MIRROR_PATH,
cache_dir: os.PathLike | None = None,
save_to_cache: bool = False,
@@ -163,7 +163,7 @@ def parse(
atomic-level structure (e.g. .cif, .bcif, .cif.gz, .pdb), although .cif files are strongly recommended.

**Wrapper arguments:**
file_type (Literal["cif", "pdb"] | None, optional): The file type of the structure file.
file_type (Literal["cif", "pdb", "mmjson"] | None, optional): The file type of the structure file.
If not provided, the file type will be inferred automatically.
load_from_cache (bool, optional): Whether to load pre-compiled results from cache. Defaults to False.
cache_dir (PathLike, optional): Directory path to save pre-compiled results. Defaults to None.
@@ -329,9 +329,10 @@ def parse(
build_assembly=build_assembly,
extra_fields=extra_fields,
)
elif file_type in ("cif", "bcif"):
elif file_type in ("cif", "bcif", "mmjson"):
result = _parse_from_cif(
filename=filename,
+            file_type=file_type,
ccd_mirror_path=ccd_mirror_path,
add_missing_atoms=add_missing_atoms,
add_id_and_entity_annotations=add_id_and_entity_annotations,
@@ -684,7 +685,9 @@ def parse_atom_array(
return data_dict


-def _parse_from_cif(filename: os.PathLike | io.StringIO | io.BytesIO, **kwargs) -> dict[str, Any]:
+def _parse_from_cif(
+    filename: os.PathLike | io.StringIO | io.BytesIO, file_type: str | None = None, **kwargs
+) -> dict[str, Any]:
"""Parse the CIF file.

Return chain information, residue information, atom array, and metadata.
@@ -696,7 +699,7 @@ def _parse_from_cif(filename: os.PathLike | io.StringIO | io.BytesIO, **kwargs)
data_dict = {"extra_info": {}}

# ... read the CIF file into the dictionary (we will clean up the dictionary before returning)
-    cif_file = read_any(filename)
+    cif_file = read_any(filename, file_type=file_type)
data_dict["cif_block"] = cif_file.block

# ... load metadata into "metadata" key (either from RCSB standard fields, or from the custom `extra_metadata` field)
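Reviewer note: with these changes `parse` accepts mmJSON end to end. A minimal usage sketch (the file name is hypothetical; `file_type` could also be left as `None` and inferred from the extension):

```python
from atomworks.io.parser import parse

result = parse("2hhb.json.gz", file_type="mmjson")
atom_array = result["asym_unit"]  # same result keys as for CIF input, per the new test below
```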
83 changes: 68 additions & 15 deletions src/atomworks/io/utils/io_utils.py
@@ -15,6 +15,7 @@

import gzip
import io
+import json
import logging
import os
import re
@@ -44,7 +45,18 @@

logger = logging.getLogger("atomworks.io")

CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}
CIF_LIKE_EXTENSIONS = {
".cif",
".pdb",
".bcif",
".cif.gz",
".pdb.gz",
".bcif.gz",
".json",
".json.gz",
".mmjson",
".mmjson.gz",
}


@contextmanager
@@ -86,7 +98,7 @@ def _get_logged_in_user() -> str:

def load_any(
file_or_buffer: os.PathLike | io.StringIO | io.BytesIO,
file_type: Literal["cif", "mmcif", "pdbx", "pdb", "pdb1", "bcif"] | None = None,
file_type: Literal["cif", "mmcif", "pdbx", "pdb", "pdb1", "bcif", "mmjson"] | None = None,
*,
extra_fields: list[str] | Literal["all"] = [],
include_bonds: bool = True,
@@ -320,7 +332,28 @@ def get_structure(
return atom_array_stack


-def infer_pdb_file_type(path_or_buffer: os.PathLike | io.StringIO | io.BytesIO) -> Literal["cif", "pdb", "bcif", "sdf"]:
+def _infer_file_type_from_buffer(buffer: io.BytesIO | io.StringIO) -> str:
+    """Infer file type from buffer contents."""
+    if isinstance(buffer, io.BytesIO):
+        return "bcif"
+
+    # StringIO - peek at contents to determine format
+    buffer.seek(0)
+    first_char = buffer.read(1)
+    buffer.readline()  # finish first line
+    second_line = buffer.readline()
+    buffer.seek(0)
+
+    if first_char == "{":
+        return "mmjson"
+    if second_line.startswith("#"):
+        return "cif"
+    return "pdb"
+
+
+def infer_pdb_file_type(
+    path_or_buffer: os.PathLike | io.StringIO | io.BytesIO,
+) -> Literal["cif", "pdb", "bcif", "sdf", "mmjson"]:
"""
Infer the file type of a PDB file or buffer.
"""
@@ -329,15 +362,8 @@ def infer_pdb_file_type(path_or_buffer: os.PathLike | io.StringIO | io.BytesIO)
path_or_buffer = Path(path_or_buffer)

# Determine file type and open context
-    if isinstance(path_or_buffer, io.BytesIO):
-        return "bcif"
-    elif isinstance(path_or_buffer, io.StringIO):
-        # ... if second line starts with '#', it is very likely a cif file
-        path_or_buffer.seek(0)
-        path_or_buffer.readline()  # Skip the first line
-        second_line = path_or_buffer.readline().strip()
-        path_or_buffer.seek(0)
-        return "cif" if second_line.startswith("#") else "pdb"
+    if isinstance(path_or_buffer, io.StringIO | io.BytesIO):
+        return _infer_file_type_from_buffer(path_or_buffer)
elif isinstance(path_or_buffer, Path):
if path_or_buffer.suffix in (".gz", ".gzip"):
inferred_file_type = Path(path_or_buffer.stem).suffix.lstrip(".")
@@ -353,21 +379,40 @@
return "bcif"
elif inferred_file_type == "sdf":
return "sdf"
elif inferred_file_type in ("json", "mmjson"):
return "mmjson"
else:
raise ValueError(f"Unsupported file type: {inferred_file_type}")


+def _read_mmjson(file_obj: io.StringIO | io.BytesIO | io.TextIOWrapper) -> pdbx.CIFFile:
+    """Read an mmjson file into a CIFFile object."""
+    data = json.load(file_obj)
+    cif_file = pdbx.CIFFile()
+    for block_name, block_data in data.items():
+        cif_block = pdbx.CIFBlock()
+        for cat_name, cat_data in block_data.items():
+            cif_category = pdbx.CIFCategory()
+            for col_name, col_data in cat_data.items():
+                # Convert None to "?" and ensure all elements are strings
+                processed_data = [str(x) if x is not None else "?" for x in col_data]
+                cif_category[col_name] = pdbx.CIFColumn(processed_data)
+            cif_block[cat_name] = cif_category
+        cif_file[block_name] = cif_block
+    return cif_file

[Review comment from a Collaborator on this hunk: "nice job!"]
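Reviewer note: mmJSON (as served by PDBj) mirrors the mmCIF data model as nested JSON, block → category → column. A hand-written fragment of the shape `_read_mmjson` walks, illustrating the `None` → `"?"` conversion (the data values are hypothetical):

```python
import io
import json

from atomworks.io.utils.io_utils import _read_mmjson

mmjson = {"data_2HHB": {"atom_site": {"id": [1, 2], "type_symbol": ["N", "C"], "pdbx_formal_charge": [None, None]}}}
cif_file = _read_mmjson(io.StringIO(json.dumps(mmjson)))  # None entries become "?" in the resulting CIF columns
```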


def read_any(
path_or_buffer: os.PathLike | io.StringIO | io.BytesIO,
file_type: Literal["cif", "pdb", "bcif", "sdf"] | None = None,
file_type: Literal["cif", "pdb", "bcif", "sdf", "mmjson"] | None = None,
) -> pdbx.CIFFile | biotite_pdb.PDBFile | pdbx.BinaryCIFFile:
"""
Reads any of the allowed file types into the appropriate Biotite file object.

Args:
path_or_buffer (PathLike | io.StringIO | io.BytesIO): The path to the file or a buffer to read from.
If a buffer, it's highly recommended to specify the file_type.
file_type (Literal["cif", "pdb", "bcif"], optional): Type of the file.
file_type (Literal["cif", "pdb", "bcif", "mmjson"], optional): Type of the file.
If None, it will be inferred from the file extension. When using a buffer, the file type must be specified.

Returns:
@@ -379,7 +424,8 @@ def read_any(
# Determine file type
if file_type is None:
file_type = infer_pdb_file_type(path_or_buffer)
open_mode = "rb" if file_type == "bcif" else "rt"

open_mode = "rb" if file_type == "bcif" else "rt"

# Convert string paths to Path objects and decompress if necessary
if isinstance(path_or_buffer, str | Path):
@@ -398,6 +444,13 @@ def read_any(
file_cls = pdbx.BinaryCIFFile
elif file_type == "sdf":
file_cls = mol.SDFile
elif file_type == "mmjson":
# Special handling for mmjson
if isinstance(path_or_buffer, io.StringIO | io.BytesIO):
return _read_mmjson(path_or_buffer)
else:
with open(path_or_buffer) as f:
return _read_mmjson(f)
else:
raise ValueError(f"Unsupported file type: {file_type}")

2 changes: 1 addition & 1 deletion src/atomworks/io/utils/query.py
@@ -200,7 +200,7 @@ def _ensure_bool_array(mask: Any, expected_length: int) -> np.ndarray:
# Check length
if len(mask) != expected_length:
raise ValueError(
f"Query resulted in mask of length {len(mask)}, " f"but AtomArray has length {expected_length}"
f"Query resulted in mask of length {len(mask)}, but AtomArray has length {expected_length}"
)

return mask
6 changes: 3 additions & 3 deletions src/atomworks/ml/datasets/concat_dataset.py
@@ -222,7 +222,7 @@ def __getitem__(self, idxs: tuple[int, ...]) -> Any:

for i, idx in enumerate(idxs):
dataset = self.dataset if i == 0 else self.fallback_dataset
dataset_name = "Primary dataset" if i == 0 else f"Fallback {i}/{len(idxs)-1}"
dataset_name = "Primary dataset" if i == 0 else f"Fallback {i}/{len(idxs) - 1}"

try:
return dataset[idx]
@@ -238,11 +238,11 @@ def __getitem__(self, idxs: tuple[int, ...]) -> Any:

# Log fallback attempt if not the last one
if i < len(idxs) - 1:
logger.warning(f"({dataset_name}): Trying fallback index {idxs[i+1]}.{example_id}")
logger.warning(f"({dataset_name}): Trying fallback index {idxs[i + 1]}.{example_id}")

# All attempts failed
logger.error(
f"(Exceeded all {len(idxs)-1} fallbacks. Training will crash now. Errors: {error_list} for examples: {example_id_list})"
f"(Exceeded all {len(idxs) - 1} fallbacks. Training will crash now. Errors: {error_list} for examples: {example_id_list})"
)
raise RuntimeError(f"All attempts failed for indices {idxs}. See error_list for details.") from ExceptionGroup(
"All fallback attempts failed", error_list
2 changes: 1 addition & 1 deletion src/atomworks/ml/preprocessing/msa/finding.py
@@ -110,7 +110,7 @@ def _build_msa_file_paths(

for shard_depth in shard_depths:
# Build shard path like "ab/cd/" for depth 2 with hash "abcd123..."
shard_path = "".join([f"{sequence_hash[(i*2):(i+1)*2]}/" for i in range(shard_depth)])
shard_path = "".join([f"{sequence_hash[(i * 2) : (i + 1) * 2]}/" for i in range(shard_depth)])

for extension in extensions:
file_path = msa_dir / shard_path / f"{sequence_hash}{extension}"
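Reviewer note: a worked example of the shard slicing reformatted above. With depth 2 and a hash starting `"abcd"`, the slices `[0:2]` and `[2:4]` produce the `"ab/cd/"` path from the comment:

```python
sequence_hash = "abcd123"
shard_path = "".join([f"{sequence_hash[(i * 2) : (i + 1) * 2]}/" for i in range(2)])
assert shard_path == "ab/cd/"
```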
2 changes: 1 addition & 1 deletion src/atomworks/ml/transforms/design_task.py
@@ -85,7 +85,7 @@ def __init__(

if not design_tasks_to_use:
logger.warning(
"No design tasks with non-zero frequency found. " "SampleDesignTask will act as an identity transform."
"No design tasks with non-zero frequency found. SampleDesignTask will act as an identity transform."
)

self.design_tasks = design_tasks_to_use
2 changes: 1 addition & 1 deletion src/atomworks_cli/find.py
@@ -145,7 +145,7 @@ def find(
)
typer.secho(f" Found MSAs: {found_count:,} ({coverage_percent:.1f}%)", fg=found_color)
typer.secho(
f" Missing MSAs: {missing_count:,} ({100-coverage_percent:.1f}%)",
f" Missing MSAs: {missing_count:,} ({100 - coverage_percent:.1f}%)",
fg=typer.colors.RED if missing_count > 0 else typer.colors.GREEN,
)

Binary file added tests/data/io/2hhb.cif.gz
Binary file added tests/data/io/2hhb.json.gz
28 changes: 28 additions & 0 deletions tests/io/tools/test_mmjson.py
@@ -0,0 +1,28 @@
+from atomworks.io.parser import parse
+from atomworks.io.utils.io_utils import infer_pdb_file_type
+from atomworks.io.utils.testing import assert_same_atom_array
+from tests.conftest import TEST_DATA_DIR
+
+
+def test_mmjson_inference_and_parsing():
+    json_path = TEST_DATA_DIR / "io" / "2hhb.json.gz"
+    cif_path = TEST_DATA_DIR / "io" / "2hhb.cif.gz"
+
+    assert json_path.exists(), f"mmJSON file not found at {json_path}"
+    assert cif_path.exists(), f"CIF file not found at {cif_path}"
+
+    # 1. Test File Type Inference
+    inferred_type = infer_pdb_file_type(json_path)
+    assert inferred_type == "mmjson", f"Failed to infer 'mmjson'. Got: {inferred_type}"
+
+    # 2. Parse mmJSON
+    result_json = parse(json_path, file_type="mmjson")
+    atoms_json = result_json["asym_unit"]
+
+    # 3. Parse CIF for Comparison
+    result_cif = parse(cif_path, file_type="cif")
+    atoms_cif = result_cif["asym_unit"]
+
+    # 4. Compare Results
+    # This utility checks atom count, coordinates, and other annotations
+    assert_same_atom_array(atoms_json, atoms_cif)
[Review comment from a Collaborator on this file: "Great!"]
