Skip to content
52 changes: 27 additions & 25 deletions preprocessing/nextclade/src/loculus_preprocessing/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import jwt
import pytz
import requests
from pydantic import ValidationError

from .config import Config
from .datatypes import (
BackendEntry,
FileUploadInfo,
InternalMetadata,
ProcessedEntry,
UnprocessedData,
UnprocessedEntry,
)
from .processing_functions import trim_ns

Expand Down Expand Up @@ -74,8 +76,26 @@ def get_jwt(config: Config) -> str:
raise Exception(error_msg)


def parse_ndjson(ndjson_data: str) -> Sequence[UnprocessedEntry]:
entries: list[UnprocessedEntry] = []
def _backend_entry_to_unprocessed(entry: BackendEntry) -> UnprocessedData:
    """Convert a validated backend entry into an ``UnprocessedData`` record.

    The accession and version are joined as ``"{accession}.{version}"`` and
    stored, together with the submitter, group id, submission timestamp and
    submission id, in an ``InternalMetadata`` object. The entry's metadata is
    passed through unchanged.

    Every unaligned nucleotide sequence is run through ``trim_ns``; falsy
    sequences (``None`` or the empty string) are stored as ``None`` instead.
    """
    payload = entry.data

    # Built first so evaluation order matches the keyword-argument order
    # of the original single-expression form.
    provenance = InternalMetadata(
        accession_version=f"{entry.accession}.{entry.version}",
        submitter=entry.submitter,
        group_id=entry.groupId,
        submitted_at=entry.submittedAt,
        submission_id=entry.submissionId,
    )
    input_metadata = payload.metadata

    trimmed_sequences: dict[str, str | None] = {}
    for name, raw_sequence in payload.unalignedNucleotideSequences.items():
        # Missing/empty sequences are never handed to trim_ns.
        trimmed_sequences[name] = trim_ns(raw_sequence) if raw_sequence else None

    return UnprocessedData(
        metadata=input_metadata,
        internal_metadata=provenance,
        unalignedNucleotideSequences=trimmed_sequences,
    )


def parse_ndjson(ndjson_data: str) -> Sequence[UnprocessedData]:
entries: list[UnprocessedData] = []
if len(ndjson_data) == 0:
return entries
for json_str in ndjson_data.split("\n"):
Expand All @@ -84,35 +104,17 @@ def parse_ndjson(ndjson_data: str) -> Sequence[UnprocessedEntry]:
# Loculus currently cannot handle non-breaking spaces.
json_str_processed = json_str.replace("\N{NO-BREAK SPACE}", " ")
try:
json_object = json.loads(json_str_processed)
except json.JSONDecodeError as e:
backend_entry = BackendEntry.model_validate_json(json_str_processed)
except (json.JSONDecodeError, ValidationError) as e:
error_msg = f"Failed to parse JSON: {json_str_processed}"
raise ValueError(error_msg) from e
unaligned_nucleotide_sequences = json_object["data"]["unalignedNucleotideSequences"]
trimmed_unaligned_nucleotide_sequences = {
key: trim_ns(value) if value else None
for key, value in unaligned_nucleotide_sequences.items()
}
unprocessed_data = UnprocessedData(
submitter=json_object["submitter"],
group_id=json_object["groupId"],
submittedAt=json_object["submittedAt"],
metadata=json_object["data"]["metadata"],
unalignedNucleotideSequences=trimmed_unaligned_nucleotide_sequences
if unaligned_nucleotide_sequences
else {},
)
entry = UnprocessedEntry(
accessionVersion=f"{json_object['accession']}.{json_object['version']}",
data=unprocessed_data,
)
entries.append(entry)
entries.append(_backend_entry_to_unprocessed(backend_entry))
return entries


def fetch_unprocessed_sequences(
etag: str | None, config: Config
) -> tuple[str | None, Sequence[UnprocessedEntry] | None]:
) -> tuple[str | None, Sequence[UnprocessedData] | None]:
request_id = str(uuid.uuid4())
n = config.batch_size
url = config.backend_host.rstrip("/") + "/extract-unprocessed-data"
Expand Down
46 changes: 37 additions & 9 deletions preprocessing/nextclade/src/loculus_preprocessing/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from enum import StrEnum, unique
from typing import Any, Final

from pydantic import BaseModel

AccessionVersion = str
GeneName = str
SegmentName = str
Expand Down Expand Up @@ -74,28 +76,55 @@ def from_single(cls, name: str, type, message: str):
return cls.from_fields([name], [name], type, message)


class BackendEntryData(BaseModel):
    """Pydantic model for the nested ``data`` object of a ``BackendEntry``.

    Field names are matched against the keys of the parsed JSON, which is why
    the camelCase spelling is kept (see the ``noqa: N815`` markers).
    """

    # Metadata of the entry, typed as ``InputMetadata``.
    metadata: InputMetadata
    # Maps a sequence key to its unaligned nucleotide sequence; a value may be None.
    unalignedNucleotideSequences: dict[str, str | None] # noqa: N815
    # Optional: defaults to None when the key is absent.
    files: dict[str, list[dict[str, str]]] | None = None # filename to list of {fileId, name}


class BackendEntry(BaseModel):
    """Pydantic model of a single entry as received from the backend.

    ``parse_ndjson`` in ``backend.py`` validates each NDJSON line against this
    model via ``model_validate_json``. Field names are matched against the JSON
    keys, hence the camelCase spelling.
    """

    # ``accession`` and ``version`` are combined as "{accession}.{version}"
    # to form the accession version used elsewhere in the pipeline.
    accession: str
    version: int
    submitter: str
    groupId: int # noqa: N815
    submittedAt: int # noqa: N815 # Unix timestamp
    submissionId: str # noqa: N815
    # Nested payload holding the metadata and the unaligned sequences.
    data: BackendEntryData


@dataclass
class UnprocessedData:
class InternalMetadata:
accession_version: AccessionVersion # {accession}.{version}
submitter: str
group_id: int
submittedAt: str # timestamp # noqa: N815
metadata: InputMetadata
unalignedNucleotideSequences: dict[SequenceName, NucleotideSequence | None] # noqa: N815
submitted_at: int # timestamp
submission_id: str


@dataclass
class UnprocessedEntry:
accessionVersion: AccessionVersion # {accession}.{version} # noqa: N815
data: UnprocessedData
class UnprocessedData:
metadata: InputMetadata
internal_metadata: InternalMetadata
unalignedNucleotideSequences: dict[SequenceName, NucleotideSequence | None] # noqa: N815


FunctionInputs = dict[ArgName, InputField]
FunctionArgs = dict[ArgName, ArgValue]


@dataclass
class ProcessingFunctionCallArgs:
    """Groups the values associated with one processing-function call.

    NOTE(review): how each field is consumed is not visible in this module;
    confirm the descriptions below against the call sites.
    """

    # Mapping of argument name to argument value (``FunctionArgs``).
    args: FunctionArgs
    # presumably the name of the output field being produced — TODO confirm
    output_field: str
    # presumably the names of the input fields the call reads — TODO confirm
    input_fields: list[str]
    # Input metadata, typed as ``InputMetadata``.
    input_data: InputMetadata
    # Entry-level internal metadata (accession version, submitter, group id,
    # submission timestamp and submission id), typed as ``InternalMetadata``.
    internal_metadata: InternalMetadata


@dataclass
class UnprocessedAfterNextclade:
inputMetadata: InputMetadata # noqa: N815
internal_metadata: InternalMetadata
# Derived metadata produced by Nextclade
nextcladeMetadata: dict[SequenceName, Any] | None # noqa: N815
unalignedNucleotideSequences: dict[SequenceName, NucleotideSequence | None] # noqa: N815
Expand Down Expand Up @@ -156,8 +185,7 @@ class SubmissionData:
but the annotations need to be uploaded separately."""

processed_entry: ProcessedEntry
submitter: str | None
group_id: int | None = None
internal_metadata: InternalMetadata
annotations: dict[str, Any] | None = None


Expand Down
50 changes: 25 additions & 25 deletions preprocessing/nextclade/src/loculus_preprocessing/nextclade.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
FastaId,
GeneName,
GenericSequence,
InternalMetadata,
NucleotideInsertion,
NucleotideSequence,
ProcessingAnnotation,
Expand All @@ -35,7 +36,7 @@
SequenceAssignment,
SequenceAssignmentBatch,
UnprocessedAfterNextclade,
UnprocessedEntry,
UnprocessedData,
)

# https://stackoverflow.com/questions/15063936
Expand Down Expand Up @@ -346,7 +347,7 @@ def check_nextclade_sort_matches( # noqa: PLR0913, PLR0917


def write_nextclade_input_fasta(
unprocessed: Sequence[UnprocessedEntry], input_file: str
unprocessed: Sequence[UnprocessedData], input_file: str
) -> defaultdict[tuple[AccessionVersion, FastaId], str]:
"""
Write unprocessed sequences to a fasta file for nextclade input
Expand All @@ -355,8 +356,8 @@ def write_nextclade_input_fasta(
os.makedirs(os.path.dirname(input_file), exist_ok=True)
with open(input_file, "w", encoding="utf-8") as f:
for entry in unprocessed:
accession_version = entry.accessionVersion
for fasta_id, seq in entry.data.unalignedNucleotideSequences.items():
accession_version = entry.internal_metadata.accession_version
for fasta_id, seq in entry.unalignedNucleotideSequences.items():
id = f"{accession_version}__{fasta_id}"
id_map[accession_version, fasta_id] = id
f.write(f">{id}\n")
Expand All @@ -377,7 +378,7 @@ def is_valid_dataset_match(method, best_dataset_id, dataset):


def assign_segment( # noqa: C901
entry: UnprocessedEntry,
entry: UnprocessedData,
id_map: dict[tuple[AccessionVersion, FastaId], str],
best_hits: pd.DataFrame,
config: Config,
Expand All @@ -398,8 +399,8 @@ def assign_segment( # noqa: C901
has_unaligned_sequence = False
has_duplicate_segments = False

for fasta_id in entry.data.unalignedNucleotideSequences:
seq_id = id_map[entry.accessionVersion, fasta_id]
for fasta_id in entry.unalignedNucleotideSequences:
seq_id = id_map[entry.internal_metadata.accession_version, fasta_id]
if seq_id not in best_hits[SequenceIdentifier].unique():
has_unaligned_sequence = True
method = config.segment_classification_method.display_name
Expand Down Expand Up @@ -454,7 +455,7 @@ def assign_segment( # noqa: C901

sequence_assignment.sequenceNameToFastaId[ids[0].name] = ids[0].fasta_id
sequence_assignment.unalignedNucleotideSequences[ids[0].name] = (
entry.data.unalignedNucleotideSequences[ids[0].fasta_id]
entry.unalignedNucleotideSequences[ids[0].fasta_id]
)

if (
Expand All @@ -473,7 +474,7 @@ def assign_segment( # noqa: C901


def assign_segment_with_nextclade_align(
unprocessed: Sequence[UnprocessedEntry], config: Config, dataset_dir: str
unprocessed: Sequence[UnprocessedData], config: Config, dataset_dir: str
) -> SequenceAssignmentBatch:
"""
Run nextclade align
Expand Down Expand Up @@ -523,7 +524,7 @@ def assign_segment_with_nextclade_align(
best_hits,
config,
)
accession_version = entry.accessionVersion
accession_version = entry.internal_metadata.accession_version
batch.sequenceNameToFastaId[accession_version] = sequence_assignment.sequenceNameToFastaId
batch.unalignedNucleotideSequences[accession_version] = (
sequence_assignment.unalignedNucleotideSequences
Expand All @@ -534,7 +535,7 @@ def assign_segment_with_nextclade_align(


def assign_segment_with_nextclade_sort(
unprocessed: Sequence[UnprocessedEntry], config: Config, dataset_dir: str
unprocessed: Sequence[UnprocessedData], config: Config, dataset_dir: str
) -> SequenceAssignmentBatch:
"""
Run nextclade sort
Expand Down Expand Up @@ -566,7 +567,7 @@ def assign_segment_with_nextclade_sort(
best_hits,
config,
)
accession_version = entry.accessionVersion
accession_version = entry.internal_metadata.accession_version
batch.sequenceNameToFastaId[accession_version] = sequence_assignment.sequenceNameToFastaId
batch.unalignedNucleotideSequences[accession_version] = (
sequence_assignment.unalignedNucleotideSequences
Expand All @@ -576,7 +577,7 @@ def assign_segment_with_nextclade_sort(


def assign_segment_with_diamond(
unprocessed: Sequence[UnprocessedEntry], config: Config, dataset_dir: str
unprocessed: Sequence[UnprocessedData], config: Config, dataset_dir: str
) -> SequenceAssignmentBatch:
"""
Run diamond
Expand Down Expand Up @@ -608,7 +609,7 @@ def assign_segment_with_diamond(
best_hits,
config,
)
accession_version = entry.accessionVersion
accession_version = entry.internal_metadata.accession_version
batch.sequenceNameToFastaId[accession_version] = sequence_assignment.sequenceNameToFastaId
batch.unalignedNucleotideSequences[accession_version] = (
sequence_assignment.unalignedNucleotideSequences
Expand Down Expand Up @@ -642,13 +643,13 @@ def assign_single_segment(


def assign_all_single_segments(
unprocessed: Sequence[UnprocessedEntry], config: Config
unprocessed: Sequence[UnprocessedData], config: Config
) -> SequenceAssignmentBatch:
batch = SequenceAssignmentBatch()
for entry in unprocessed:
accession_version = entry.accessionVersion
accession_version = entry.internal_metadata.accession_version
sequence_assignment = assign_single_segment(
entry.data.unalignedNucleotideSequences,
entry.unalignedNucleotideSequences,
config=config,
)
batch.sequenceNameToFastaId[accession_version] = sequence_assignment.sequenceNameToFastaId
Expand Down Expand Up @@ -766,13 +767,14 @@ def load_aligned_aa_sequences(


def enrich_with_nextclade( # noqa: C901, PLR0914
unprocessed: Sequence[UnprocessedEntry], dataset_dir: str, config: Config
unprocessed: Sequence[UnprocessedData], dataset_dir: str, config: Config
) -> dict[AccessionVersion, UnprocessedAfterNextclade]:
"""
For each unprocessed segment of each unprocessed sequence use nextclade run to perform alignment
and QC. The result is a mapping from each AccessionVersion to an
`UnprocessedAfterNextclade(
inputMetadata: InputMetadata
internal_metadata: InternalMetadata
nextcladeMetadata: dict[SegmentName, Any] | None
unalignedNucleotideSequences: dict[SegmentName, NucleotideSequence | None]
alignedNucleotideSequences: dict[SegmentName, NucleotideSequence | None]
Expand All @@ -783,13 +785,10 @@ def enrich_with_nextclade( # noqa: C901, PLR0914
)` object.
"""
input_metadata: dict[AccessionVersion, dict[str, Any]] = {
entry.accessionVersion: {
**entry.data.metadata,
"submitter": entry.data.submitter,
"submittedAt": entry.data.submittedAt,
"group_id": entry.data.group_id,
}
for entry in unprocessed
entry.internal_metadata.accession_version: entry.metadata for entry in unprocessed
}
internal_metadata: dict[AccessionVersion, InternalMetadata] = {
entry.internal_metadata.accession_version: entry.internal_metadata for entry in unprocessed
}

if not config.multi_datasets:
Expand Down Expand Up @@ -897,6 +896,7 @@ def enrich_with_nextclade( # noqa: C901, PLR0914
return {
id: UnprocessedAfterNextclade(
inputMetadata=input_metadata[id],
internal_metadata=internal_metadata[id],
nextcladeMetadata=nextclade_metadata[id],
unalignedNucleotideSequences=unaligned_nucleotide_sequences[id],
alignedNucleotideSequences=aligned_nucleotide_sequences[id],
Expand Down
Loading
Loading