diff --git a/preprocessing/nextclade/src/loculus_preprocessing/backend.py b/preprocessing/nextclade/src/loculus_preprocessing/backend.py index cdc6fe3095..7d6c43ed4d 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/backend.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/backend.py @@ -14,13 +14,15 @@ import jwt import pytz import requests +from pydantic import ValidationError from .config import Config from .datatypes import ( + BackendEntry, FileUploadInfo, + InternalMetadata, ProcessedEntry, UnprocessedData, - UnprocessedEntry, ) from .processing_functions import trim_ns @@ -74,8 +76,26 @@ def get_jwt(config: Config) -> str: raise Exception(error_msg) -def parse_ndjson(ndjson_data: str) -> Sequence[UnprocessedEntry]: - entries: list[UnprocessedEntry] = [] +def _backend_entry_to_unprocessed(entry: BackendEntry) -> UnprocessedData: + accession_version = f"{entry.accession}.{entry.version}" + return UnprocessedData( + internal_metadata=InternalMetadata( + accession_version=accession_version, + submitter=entry.submitter, + group_id=entry.groupId, + submitted_at=entry.submittedAt, + submission_id=entry.submissionId, + ), + metadata=entry.data.metadata, + unalignedNucleotideSequences={ + key: trim_ns(value) if value else None + for key, value in entry.data.unalignedNucleotideSequences.items() + }, + ) + + +def parse_ndjson(ndjson_data: str) -> Sequence[UnprocessedData]: + entries: list[UnprocessedData] = [] if len(ndjson_data) == 0: return entries for json_str in ndjson_data.split("\n"): @@ -84,35 +104,17 @@ def parse_ndjson(ndjson_data: str) -> Sequence[UnprocessedEntry]: # Loculus currently cannot handle non-breaking spaces. json_str_processed = json_str.replace("\N{NO-BREAK SPACE}", " ") try: - json_object = json.loads(json_str_processed) - except json.JSONDecodeError as e: + backend_entry = BackendEntry.model_validate_json(json_str_processed) + except (json.JSONDecodeError, ValidationError) as e: error_msg = f"Failed to parse JSON: {json_str_processed}" raise ValueError(error_msg) from e - unaligned_nucleotide_sequences = json_object["data"]["unalignedNucleotideSequences"] - trimmed_unaligned_nucleotide_sequences = { - key: trim_ns(value) if value else None - for key, value in unaligned_nucleotide_sequences.items() - } - unprocessed_data = UnprocessedData( - submitter=json_object["submitter"], - group_id=json_object["groupId"], - submittedAt=json_object["submittedAt"], - metadata=json_object["data"]["metadata"], - unalignedNucleotideSequences=trimmed_unaligned_nucleotide_sequences - if unaligned_nucleotide_sequences - else {}, - ) - entry = UnprocessedEntry( - accessionVersion=f"{json_object['accession']}.{json_object['version']}", - data=unprocessed_data, - ) - entries.append(entry) + entries.append(_backend_entry_to_unprocessed(backend_entry)) return entries def fetch_unprocessed_sequences( etag: str | None, config: Config -) -> tuple[str | None, Sequence[UnprocessedEntry] | None]: +) -> tuple[str | None, Sequence[UnprocessedData] | None]: request_id = str(uuid.uuid4()) n = config.batch_size url = config.backend_host.rstrip("/") + "/extract-unprocessed-data" diff --git a/preprocessing/nextclade/src/loculus_preprocessing/datatypes.py b/preprocessing/nextclade/src/loculus_preprocessing/datatypes.py index 9fc88bcb8d..a95a1fe344 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/datatypes.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/datatypes.py @@ -3,6 +3,8 @@ from enum import StrEnum, unique from typing import Any, Final +from pydantic import BaseModel + AccessionVersion = str GeneName = str SegmentName = str @@ -74,28 +76,55 @@ def from_single(cls, name: str, type, message: str): return cls.from_fields([name], [name], type, message) +class BackendEntryData(BaseModel): + metadata: InputMetadata + unalignedNucleotideSequences: dict[str, str | None] # noqa: N815 + files: dict[str, list[dict[str, str]]] | None = None # filename to list of {fileId, name} + + +class BackendEntry(BaseModel): + accession: str + version: int + submitter: str + groupId: int # noqa: N815 + submittedAt: int # noqa: N815 # Unix timestamp + submissionId: str # noqa: N815 + data: BackendEntryData + + @dataclass -class UnprocessedData: +class InternalMetadata: + accession_version: AccessionVersion # {accession}.{version} submitter: str group_id: int - submittedAt: str # timestamp # noqa: N815 - metadata: InputMetadata - unalignedNucleotideSequences: dict[SequenceName, NucleotideSequence | None] # noqa: N815 + submitted_at: int # timestamp + submission_id: str @dataclass -class UnprocessedEntry: - accessionVersion: AccessionVersion # {accession}.{version} # noqa: N815 - data: UnprocessedData +class UnprocessedData: + metadata: InputMetadata + internal_metadata: InternalMetadata + unalignedNucleotideSequences: dict[SequenceName, NucleotideSequence | None] # noqa: N815 FunctionInputs = dict[ArgName, InputField] FunctionArgs = dict[ArgName, ArgValue] +@dataclass +class ProcessingFunctionCallArgs: + args: FunctionArgs + output_field: str + input_fields: list[str] + input_data: InputMetadata + internal_metadata: InternalMetadata + + @dataclass class UnprocessedAfterNextclade: inputMetadata: InputMetadata # noqa: N815 + internal_metadata: InternalMetadata # Derived metadata produced by Nextclade nextcladeMetadata: dict[SequenceName, Any] | None # noqa: N815 unalignedNucleotideSequences: dict[SequenceName, NucleotideSequence | None] # noqa: N815 @@ -156,8 +185,7 @@ class SubmissionData: but the annotations need to be uploaded separately.""" processed_entry: ProcessedEntry - submitter: str | None - group_id: int | None = None + internal_metadata: InternalMetadata annotations: dict[str, Any] | None = None diff --git a/preprocessing/nextclade/src/loculus_preprocessing/nextclade.py b/preprocessing/nextclade/src/loculus_preprocessing/nextclade.py index 2ddc171c61..f4657aeabd 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/nextclade.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/nextclade.py @@ -26,6 +26,7 @@ FastaId, GeneName, GenericSequence, + InternalMetadata, NucleotideInsertion, NucleotideSequence, ProcessingAnnotation, @@ -35,7 +36,7 @@ SequenceAssignment, SequenceAssignmentBatch, UnprocessedAfterNextclade, - UnprocessedEntry, + UnprocessedData, ) # https://stackoverflow.com/questions/15063936 @@ -346,7 +347,7 @@ def check_nextclade_sort_matches( # noqa: PLR0913, PLR0917 def write_nextclade_input_fasta( - unprocessed: Sequence[UnprocessedEntry], input_file: str + unprocessed: Sequence[UnprocessedData], input_file: str ) -> defaultdict[tuple[AccessionVersion, FastaId], str]: """ Write unprocessed sequences to a fasta file for nextclade input @@ -355,8 +356,8 @@ def write_nextclade_input_fasta( os.makedirs(os.path.dirname(input_file), exist_ok=True) with open(input_file, "w", encoding="utf-8") as f: for entry in unprocessed: - accession_version = entry.accessionVersion - for fasta_id, seq in entry.data.unalignedNucleotideSequences.items(): + accession_version = entry.internal_metadata.accession_version + for fasta_id, seq in entry.unalignedNucleotideSequences.items(): id = f"{accession_version}__{fasta_id}" id_map[accession_version, fasta_id] = id f.write(f">{id}\n") @@ -377,7 +378,7 @@ def is_valid_dataset_match(method, best_dataset_id, dataset): def assign_segment( # noqa: C901 - entry: UnprocessedEntry, + entry: UnprocessedData, id_map: dict[tuple[AccessionVersion, FastaId], str], best_hits: pd.DataFrame, config: Config, @@ -398,8 +399,8 @@ def assign_segment( # noqa: C901 has_unaligned_sequence = False has_duplicate_segments = False - for fasta_id in entry.data.unalignedNucleotideSequences: - seq_id = id_map[entry.accessionVersion, fasta_id] + for fasta_id in entry.unalignedNucleotideSequences: + seq_id = id_map[entry.internal_metadata.accession_version, fasta_id] if seq_id not in best_hits[SequenceIdentifier].unique(): has_unaligned_sequence = True method = config.segment_classification_method.display_name @@ -454,7 +455,7 @@ def assign_segment( # noqa: C901 sequence_assignment.sequenceNameToFastaId[ids[0].name] = ids[0].fasta_id sequence_assignment.unalignedNucleotideSequences[ids[0].name] = ( - entry.data.unalignedNucleotideSequences[ids[0].fasta_id] + entry.unalignedNucleotideSequences[ids[0].fasta_id] ) if ( @@ -473,7 +474,7 @@ def assign_segment( # noqa: C901 def assign_segment_with_nextclade_align( - unprocessed: Sequence[UnprocessedEntry], config: Config, dataset_dir: str + unprocessed: Sequence[UnprocessedData], config: Config, dataset_dir: str ) -> SequenceAssignmentBatch: """ Run nextclade align @@ -523,7 +524,7 @@ def assign_segment_with_nextclade_align( best_hits, config, ) - accession_version = entry.accessionVersion + accession_version = entry.internal_metadata.accession_version batch.sequenceNameToFastaId[accession_version] = sequence_assignment.sequenceNameToFastaId batch.unalignedNucleotideSequences[accession_version] = ( sequence_assignment.unalignedNucleotideSequences @@ -534,7 +535,7 @@ def assign_segment_with_nextclade_align( def assign_segment_with_nextclade_sort( - unprocessed: Sequence[UnprocessedEntry], config: Config, dataset_dir: str + unprocessed: Sequence[UnprocessedData], config: Config, dataset_dir: str ) -> SequenceAssignmentBatch: """ Run nextclade sort @@ -566,7 +567,7 @@ def assign_segment_with_nextclade_sort( best_hits, config, ) - accession_version = entry.accessionVersion + accession_version = entry.internal_metadata.accession_version batch.sequenceNameToFastaId[accession_version] = sequence_assignment.sequenceNameToFastaId batch.unalignedNucleotideSequences[accession_version] = ( sequence_assignment.unalignedNucleotideSequences @@ -576,7 +577,7 @@ def assign_segment_with_nextclade_sort( def assign_segment_with_diamond( - unprocessed: Sequence[UnprocessedEntry], config: Config, dataset_dir: str + unprocessed: Sequence[UnprocessedData], config: Config, dataset_dir: str ) -> SequenceAssignmentBatch: """ Run diamond @@ -608,7 +609,7 @@ def assign_segment_with_diamond( best_hits, config, ) - accession_version = entry.accessionVersion + accession_version = entry.internal_metadata.accession_version batch.sequenceNameToFastaId[accession_version] = sequence_assignment.sequenceNameToFastaId batch.unalignedNucleotideSequences[accession_version] = ( sequence_assignment.unalignedNucleotideSequences @@ -642,13 +643,13 @@ def assign_single_segment( def assign_all_single_segments( - unprocessed: Sequence[UnprocessedEntry], config: Config + unprocessed: Sequence[UnprocessedData], config: Config ) -> SequenceAssignmentBatch: batch = SequenceAssignmentBatch() for entry in unprocessed: - accession_version = entry.accessionVersion + accession_version = entry.internal_metadata.accession_version sequence_assignment = assign_single_segment( - entry.data.unalignedNucleotideSequences, + entry.unalignedNucleotideSequences, config=config, ) batch.sequenceNameToFastaId[accession_version] = sequence_assignment.sequenceNameToFastaId @@ -766,13 +767,14 @@ def load_aligned_aa_sequences( def enrich_with_nextclade( # noqa: C901, PLR0914 - unprocessed: Sequence[UnprocessedEntry], dataset_dir: str, config: Config + unprocessed: Sequence[UnprocessedData], dataset_dir: str, config: Config ) -> dict[AccessionVersion, UnprocessedAfterNextclade]: """ For each unprocessed segment of each unprocessed sequence use nextclade run to perform alignment and QC. The result is a mapping from each AccessionVersion to an `UnprocessedAfterNextclade( inputMetadata: InputMetadata + internal_metadata: InternalMetadata nextcladeMetadata: dict[SegmentName, Any] | None unalignedNucleotideSequences: dict[SegmentName, NucleotideSequence | None] alignedNucleotideSequences: dict[SegmentName, NucleotideSequence | None] @@ -783,13 +785,10 @@ def enrich_with_nextclade( # noqa: C901, PLR0914 )` object. """ input_metadata: dict[AccessionVersion, dict[str, Any]] = { - entry.accessionVersion: { - **entry.data.metadata, - "submitter": entry.data.submitter, - "submittedAt": entry.data.submittedAt, - "group_id": entry.data.group_id, - } - for entry in unprocessed + entry.internal_metadata.accession_version: entry.metadata for entry in unprocessed + } + internal_metadata: dict[AccessionVersion, InternalMetadata] = { + entry.internal_metadata.accession_version: entry.internal_metadata for entry in unprocessed } if not config.multi_datasets: @@ -897,6 +896,7 @@ def enrich_with_nextclade( # noqa: C901, PLR0914 return { id: UnprocessedAfterNextclade( inputMetadata=input_metadata[id], + internal_metadata=internal_metadata[id], nextcladeMetadata=nextclade_metadata[id], unalignedNucleotideSequences=unaligned_nucleotide_sequences[id], alignedNucleotideSequences=aligned_nucleotide_sequences[id], diff --git a/preprocessing/nextclade/src/loculus_preprocessing/prepro.py b/preprocessing/nextclade/src/loculus_preprocessing/prepro.py index dd6d4c36e0..4c17871d88 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/prepro.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/prepro.py @@ -26,6 +26,7 @@ GeneName, InputData, InputMetadata, + InternalMetadata, NucleotideInsertion, NucleotideSequence, ProcessedData, @@ -34,13 +35,13 @@ ProcessedMetadataValue, ProcessingAnnotation, ProcessingAnnotationAlignment, + ProcessingFunctionCallArgs, ProcessingResult, SegmentClassificationMethod, SegmentName, SubmissionData, UnprocessedAfterNextclade, UnprocessedData, - UnprocessedEntry, ) from .embl import create_flatfile from .nextclade import ( @@ -128,7 +129,7 @@ def truncate_after_wildcard(path: str, separator: str = ".") -> str: return path -def add_nextclade_metadata( +def add_nextclade_metadata( # noqa: PLR0911 spec: ProcessingSpec, unprocessed: UnprocessedAfterNextclade, nextclade_path: str, @@ -230,27 +231,26 @@ def add_input_metadata( def _call_processing_function( # noqa: PLR0913, PLR0917 - accession_version: AccessionVersion, spec: ProcessingSpec, output_field: str, - group_id: int | None, - submitted_at: str | None, input_data: InputMetadata, + internal_metadata: InternalMetadata, input_fields: list[str], config: Config, ) -> ProcessingResult: args = dict(spec.args) if spec.args else {} - args["is_insdc_ingest_group"] = config.insdc_ingest_group_id == group_id - args["submittedAt"] = submitted_at - args["ACCESSION_VERSION"] = accession_version + args["is_insdc_ingest_group"] = config.insdc_ingest_group_id == internal_metadata.group_id try: processing_result = ProcessingFunctions.call_function( - spec.function, - args, - input_data, - output_field, - input_fields, + function_name=spec.function, + call_args=ProcessingFunctionCallArgs( + args=args, + output_field=output_field, + input_fields=input_fields, + input_data=input_data, + internal_metadata=internal_metadata, + ), ) except Exception as e: msg = f"Processing for spec: {spec} with input data: {input_data} failed with {e}" @@ -259,8 +259,7 @@ def _call_processing_function( # noqa: PLR0913, PLR0917 return processing_result -def processed_entry_no_alignment( # noqa: PLR0913, PLR0917 - accession_version: AccessionVersion, +def processed_entry_no_alignment( unprocessed: UnprocessedData, output_metadata: ProcessedMetadata, errors: list[ProcessingAnnotation], @@ -268,7 +267,7 @@ def processed_entry_no_alignment( # noqa: PLR0913, PLR0917 sequenceNameToFastaId: dict[SequenceName, str], # noqa: N803 ) -> SubmissionData: """Process a single sequence without alignment""" - + accession_version = unprocessed.internal_metadata.accession_version aligned_nucleotide_sequences: dict[SequenceName, NucleotideSequence | None] = {} aligned_aminoacid_sequences: dict[GeneName, AminoAcidSequence | None] = {} nucleotide_insertions: dict[SequenceName, list[NucleotideInsertion]] = {} @@ -290,7 +289,7 @@ def processed_entry_no_alignment( # noqa: PLR0913, PLR0917 errors=errors, warnings=warnings, ), - submitter=unprocessed.submitter, + internal_metadata=unprocessed.internal_metadata, ) @@ -305,7 +304,6 @@ def get_sequence_length( def get_output_metadata( - accession_version: AccessionVersion, unprocessed: UnprocessedData | UnprocessedAfterNextclade, config: Config, ) -> tuple[ProcessedMetadata, list[ProcessingAnnotation], list[ProcessingAnnotation]]: @@ -354,27 +352,17 @@ def get_output_metadata( errors.extend(input_metadata.errors) warnings.extend(input_metadata.warnings) input_fields.append(input_path) - group_id = ( - int(unprocessed.inputMetadata["group_id"]) - if unprocessed.inputMetadata["group_id"] - else None - ) - submitted_at = unprocessed.inputMetadata["submittedAt"] else: input_data[arg_name] = unprocessed.metadata.get(input_path) input_fields.append(input_path) - group_id = unprocessed.group_id - submitted_at = unprocessed.submittedAt processing_result = _call_processing_function( - accession_version=accession_version, spec=spec, output_field=output_field, - group_id=group_id, - submitted_at=submitted_at, input_data=input_data, input_fields=input_fields, config=config, + internal_metadata=unprocessed.internal_metadata, ) output_metadata[output_field] = processing_result.datum @@ -383,7 +371,7 @@ def get_output_metadata( if ( null_per_backend(processing_result.datum) and spec.required - and group_id != config.insdc_ingest_group_id + and unprocessed.internal_metadata.group_id != config.insdc_ingest_group_id ): errors.append( ProcessingAnnotation.from_fields( @@ -393,7 +381,7 @@ def get_output_metadata( message=f"Metadata field {output_field} is required.", ) ) - logger.debug(f"Processed {accession_version}: {output_metadata}") + logger.debug(f"Processed {unprocessed.internal_metadata.accession_version}: {output_metadata}") return output_metadata, errors, warnings @@ -482,9 +470,7 @@ def process_single( config, ) - output_metadata, metadata_errors, metadata_warnings = get_output_metadata( - accession_version, unprocessed, config - ) + output_metadata, metadata_errors, metadata_warnings = get_output_metadata(unprocessed, config) processed_entry = ProcessedEntry( accession=accession_from_str(accession_version), @@ -505,13 +491,11 @@ def process_single( return SubmissionData( processed_entry=processed_entry, annotations=unpack_annotations(config, unprocessed.nextcladeMetadata), - group_id=int(str(unprocessed.inputMetadata["group_id"])), - submitter=str(unprocessed.inputMetadata["submitter"]), + internal_metadata=unprocessed.internal_metadata, ) def process_single_unaligned( - accession_version: AccessionVersion, unprocessed: UnprocessedData, config: Config, ) -> SubmissionData: @@ -523,12 +507,9 @@ def process_single_unaligned( unprocessed.unalignedNucleotideSequences = segment_assignment.unalignedNucleotideSequences iupac_errors = errors_if_non_iupac(unprocessed.unalignedNucleotideSequences) - output_metadata, metadata_errors, metadata_warnings = get_output_metadata( - accession_version, unprocessed, config - ) + output_metadata, metadata_errors, metadata_warnings = get_output_metadata(unprocessed, config) return processed_entry_no_alignment( - accession_version=accession_version, unprocessed=unprocessed, output_metadata=output_metadata, errors=list(set(iupac_errors + metadata_errors + segment_assignment.alert.errors)), @@ -537,7 +518,8 @@ def process_single_unaligned( ) -def processed_entry_with_errors(id) -> SubmissionData: +def processed_entry_with_errors(internal_metadata: InternalMetadata) -> SubmissionData: + id = internal_metadata.accession_version return SubmissionData( processed_entry=ProcessedEntry( accession=accession_from_str(id), @@ -563,12 +545,12 @@ def processed_entry_with_errors(id) -> SubmissionData: ], warnings=[], ), - submitter=None, + internal_metadata=internal_metadata, ) def process_all( - unprocessed: Sequence[UnprocessedEntry], dataset_dir: str, config: Config + unprocessed: Sequence[UnprocessedData], dataset_dir: str, config: Config ) -> Sequence[SubmissionData]: processed_results = [] logger.debug(f"Processing {len(unprocessed)} unprocessed sequences") @@ -579,17 +561,17 @@ def process_all( processed_single = process_single(id, result, config) except Exception as e: logger.error(f"Processing failed for {id} with error: {e}") - processed_single = processed_entry_with_errors(id) + processed_single = processed_entry_with_errors(result.internal_metadata) processed_results.append(processed_single) else: for entry in unprocessed: try: - processed_single = process_single_unaligned( - entry.accessionVersion, entry.data, config - ) + processed_single = process_single_unaligned(entry, config) except Exception as e: - logger.error(f"Processing failed for {entry.accessionVersion} with error: {e}") - processed_single = processed_entry_with_errors(entry.accessionVersion) + logger.error( + f"Processing failed for {entry.internal_metadata.accession_version} with error: {e}" + ) + processed_single = processed_entry_with_errors(entry.internal_metadata) processed_results.append(processed_single) return processed_results @@ -600,12 +582,9 @@ def upload_flatfiles(processed: Sequence[SubmissionData], config: Config) -> Non accession = submission_data.processed_entry.accession version = submission_data.processed_entry.version try: - if submission_data.group_id is None: - msg = "Group ID is required for EMBL file upload" - raise ValueError(msg) file_content = create_flatfile(config, submission_data) file_name = f"{accession}.{version}.embl" - upload_info = request_upload(submission_data.group_id, 1, config)[0] + upload_info = request_upload(submission_data.internal_metadata.group_id, 1, config)[0] file_id = upload_info.fileId url = upload_info.url upload_embl_file_to_presigned_url(file_content, url) diff --git a/preprocessing/nextclade/src/loculus_preprocessing/processing_functions.py b/preprocessing/nextclade/src/loculus_preprocessing/processing_functions.py index 85b4fc3449..55dc797d32 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/processing_functions.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/processing_functions.py @@ -10,7 +10,7 @@ import math import re import unicodedata -from dataclasses import dataclass +from dataclasses import dataclass, replace from datetime import datetime from typing import Any @@ -25,6 +25,7 @@ InputMetadata, ProcessedMetadataValue, ProcessingAnnotation, + ProcessingFunctionCallArgs, ProcessingResult, ) @@ -185,10 +186,7 @@ class ProcessingFunctions: def call_function( cls, function_name: str, - args: FunctionArgs, - input_data: InputMetadata, - output_field: str, - input_fields: list[str], + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: if not hasattr(cls, function_name): msg = ( @@ -198,11 +196,11 @@ def call_function( raise ValueError(msg) func = getattr(cls, function_name) try: - result = func(input_data, output_field, input_fields=input_fields, args=args) + result = func(call_args) except Exception as e: message = ( - f"Error calling function {function_name} for output field {output_field} " - f"with input {input_data} and args {args}: {e}" + f"Error calling function {function_name} for output field {call_args.output_field} " + f"with input {call_args.input_data} and args {call_args.args}: {e}" ) logger.exception(message) return ProcessingResult( @@ -211,15 +209,17 @@ def call_function( errors=[ ProcessingAnnotation( processedFields=[ - AnnotationSource(name=output_field, type=AnnotationSourceType.METADATA) + AnnotationSource( + name=call_args.output_field, type=AnnotationSourceType.METADATA + ) ], unprocessedFields=[ AnnotationSource(name=field, type=AnnotationSourceType.METADATA) - for field in input_fields + for field in call_args.input_fields ], message=( f"Internal Error: Function {function_name} did not return " - f"ProcessingResult with input {input_data} and args {args}, " + f"ProcessingResult with input {call_args.input_data} and args {call_args.args}, " "please contact the administrator." ), ) @@ -228,7 +228,7 @@ def call_function( if not isinstance(result, ProcessingResult): logger.error( f"ERROR: Function {function_name} did not return ProcessingResult " - f"given input {input_data} and args {args}. " + f"given input {call_args.input_data} and args {call_args.args}. " "This is likely a preprocessing bug." ) return ProcessingResult( @@ -236,12 +236,12 @@ def call_function( warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( f"Internal Error: Function {function_name} did not return " - f"ProcessingResult with input {input_data} and args {args}, " + f"ProcessingResult with input {call_args.input_data} and args {call_args.args}, " "please contact the administrator." ), ) @@ -251,17 +251,14 @@ def call_function( @staticmethod def check_date( - input_data: InputMetadata, - output_field: str, - input_fields: list[str], - args: FunctionArgs, # args is essential - even if Pylance says it's not used + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """Check that date is complete YYYY-MM-DD If not according to format return error If in future, return warning Expects input_data to be an ordered dictionary with a single key "date" """ - date = input_data["date"] + date = call_args.input_data["date"] if not date: return ProcessingResult( @@ -277,8 +274,8 @@ def check_date( if parsed_date > datetime.now(tz=pytz.utc): warnings.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message="Date is in the future.", ) @@ -293,8 +290,8 @@ def check_date( warnings=warnings, errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=error_message, ) @@ -302,11 +299,8 @@ def check_date( ) @staticmethod - def parse_date_into_range( - input_data: InputMetadata, - output_field: str, - input_fields: list[str], - args: FunctionArgs, # args is essential - even if Pylance says it's not used + def parse_date_into_range( # noqa: C901, PLR0912, PLR0915 + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """Parse date string (`input.date`) formatted as one of YYYY | YYYY-MM | YYYY-MM-DD into a range using upper bound (`input.releaseDate`) @@ -314,33 +308,32 @@ def parse_date_into_range( fieldType: "dateRangeString" | "dateRangeLower" | "dateRangeUpper" Default fieldType is "dateRangeString" """ - if not args: - args = {"fieldType": "dateRangeString"} - - logger.debug(f"input_data: {input_data}") - - input_date_str = input_data["date"] + logger.debug(f"input_data: {call_args.input_data}") - release_date_str = input_data.get("releaseDate", "") or "" + input_date_str = call_args.input_data["date"] + release_date_str = call_args.input_data.get("releaseDate", "") or "" + args = call_args.args or {"fieldType": "dateRangeString"} try: release_date = dateutil.parse(release_date_str).replace(tzinfo=pytz.utc) except Exception: release_date = None try: - submitted_at = datetime.fromtimestamp(float(str(args["submittedAt"])), tz=pytz.utc) + submitted_at = datetime.fromtimestamp( + call_args.internal_metadata.submitted_at, tz=pytz.utc + ) except Exception: return ProcessingResult( datum=None, warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( f"Internal Error: Function parse_into_ranges did not receive valid " - f"submittedAt date, with input {input_data} and args {args}, " + f"submittedAt date, with input {call_args.input_data} and args {call_args.args}, " "please contact the administrator." ), ) @@ -415,10 +408,11 @@ class DateRange: if message: warnings.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=f"Metadata field {output_field}:'{input_date_str}' - " + message, + message=f"Metadata field {call_args.output_field}:'{input_date_str}' - " + + message, ) ) @@ -428,11 +422,11 @@ class DateRange: ) errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( - f"Metadata field {output_field}:'{input_date_str}' is in the future." + f"Metadata field {call_args.output_field}:'{input_date_str}' is in the future." ), ) ) @@ -441,11 +435,11 @@ class DateRange: logger.debug(f"Lower range of date: {parsed_date} > release_date: {release_date}") errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( - f"Metadata field {output_field}:'{input_date_str}'" + f"Metadata field {call_args.output_field}:'{input_date_str}'" "is after release date." ), ) @@ -475,10 +469,10 @@ class DateRange: warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=f"Metadata field {output_field}: " + message=f"Metadata field {call_args.output_field}: " f"Date {input_date_str} could not be parsed.", ) ], @@ -486,18 +480,15 @@ class DateRange: @staticmethod def parse_and_assert_past_date( # noqa: C901 - input_data: InputMetadata, - output_field, - input_fields: list[str], - args: FunctionArgs, # args is essential - even if Pylance says it's not used + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """Parse date string. If it's incomplete, add 01-01, if no year, return null and error input_data: date: str, date string to parse release_date: str, optional release date to compare against if None use today """ - logger.debug(f"input_data: {input_data}") - date_str = input_data["date"] + logger.debug(f"input_data: {call_args.input_data}") + date_str = call_args.input_data["date"] if not date_str: return ProcessingResult( @@ -505,7 +496,7 @@ def parse_and_assert_past_date( # noqa: C901 warnings=[], errors=[], ) - release_date_str = input_data.get("release_date", "") or "" + release_date_str = call_args.input_data.get("release_date", "") or "" try: release_date = dateutil.parse(release_date_str) except Exception: @@ -538,10 +529,11 @@ def parse_and_assert_past_date( # noqa: C901 if message: warnings.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=f"Metadata field {output_field}:'{date_str}' - " + message, + message=f"Metadata field {call_args.output_field}:'{date_str}' - " + + message, ) ) @@ -549,10 +541,10 @@ def parse_and_assert_past_date( # noqa: C901 logger.debug(f"parsed_date: {parsed_date} > {datetime.now(tz=pytz.utc)}") errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=f"Metadata field {output_field}:'{date_str}' is in the future.", + message=f"Metadata field {call_args.output_field}:'{date_str}' is in the future.", ) ) @@ -560,11 +552,11 @@ def parse_and_assert_past_date( # noqa: C901 logger.debug(f"parsed_date: {parsed_date} > release_date: {release_date}") errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( - f"Metadata field {output_field}:'{date_str}'is after release date." + f"Metadata field {call_args.output_field}:'{date_str}'is after release date." ), ) ) @@ -579,23 +571,20 @@ def parse_and_assert_past_date( # noqa: C901 warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=f"Metadata field {output_field}: Date format is not recognized.", + message=f"Metadata field {call_args.output_field}: Date format is not recognized.", ) ], ) @staticmethod def parse_timestamp( - input_data: InputMetadata, - output_field: str, - input_fields: list[str], - args: FunctionArgs, # args is essential - even if Pylance says it's not used + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """Parse a timestamp string, e.g. 2022-11-01T00:00:00Z and return a YYYY-MM-DD string""" - timestamp = input_data["timestamp"] + timestamp = call_args.input_data["timestamp"] if not timestamp: return ProcessingResult( @@ -622,8 +611,8 @@ def parse_timestamp( datum=None, errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=error_message, ) @@ -632,11 +621,8 @@ def parse_timestamp( ) @staticmethod - def concatenate( - input_data: InputMetadata, - output_field: str, - input_fields: list[str], - args: FunctionArgs, + def concatenate( # noqa: C901, PLR0911 + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """Concatenates input fields using the "/" separator in the order specified by the order argument. Optionally, a 'fallback_value' argument can be provided. @@ -646,47 +632,50 @@ def concatenate( warnings: list[ProcessingAnnotation] = [] errors: list[ProcessingAnnotation] = [] - if not isinstance(args["ACCESSION_VERSION"], str): + if not isinstance(call_args.internal_metadata.accession_version, str): return ProcessingResult( datum=None, warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( "Internal Error: Function concatenate did not receive " - f"accession_version ProcessingResult with input {input_data} " - f"and args {args}, please contact the administrator." + f"accession_version ProcessingResult with input {call_args.input_data} " + f"and args {call_args.args}, please contact the administrator." ), ) ], ) - accession_version: str = args["ACCESSION_VERSION"] - order = args["order"] - field_types = args["type"] + order = call_args.args["order"] + field_types = call_args.args["type"] fallback_value = ( - str(args["fallback_value"]).strip() if args.get("fallback_value") is not None else "" + str(call_args.args["fallback_value"]).strip() + if call_args.args.get("fallback_value") is not None + else "" ) def add_errors(): errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message="Concatenation failed." "This may be a configuration error, please contact the administrator.", ) ) - if not isinstance(order, list): + def logger_error(message): logger.error( - f"Concatenate: Expected order field to be a list. " - f"This is probably a configuration error. (ACCESSION_VERSION: {accession_version})" + f"{message} (ACCESSION_VERSION: {call_args.internal_metadata.accession_version}) " ) + + if not isinstance(order, list): + logger_error("Concatenate: Expected order field to be a list. ") add_errors() return ProcessingResult( datum=None, @@ -694,14 +683,11 @@ def add_errors(): errors=errors, ) - n_inputs = len(input_data.keys()) - # exclude ACCESSION_VERSION as it's provided by _call_preprocessing_function() and should not be an input_metadata field + n_inputs = len(call_args.input_data.keys()) + # exclude ACCESSION_VERSION as it's provided via internal_metadata and is not in input_metadata n_expected = len([i for i in order if i != "ACCESSION_VERSION"]) if n_inputs != n_expected: - logger.error( - f"Concatenate: Expected {n_expected} fields, got {n_inputs}. " - f"This is probably a configuration error. (ACCESSION_VERSION: {accession_version})" - ) + logger_error(f"Concatenate: Expected {n_expected} fields, got {n_inputs}. ") add_errors() return ProcessingResult( datum=None, @@ -709,10 +695,7 @@ def add_errors(): errors=errors, ) if not isinstance(field_types, list): - logger.error( - f"Concatenate: Expected type field to be a list. " - f"This is probably a configuration error. (ACCESSION_VERSION: {accession_version})" - ) + logger_error("Concatenate: Expected type field to be a list. ") add_errors() return ProcessingResult( datum=None, @@ -724,36 +707,37 @@ def add_errors(): try: for i in range(len(order)): if field_types[i] == "date": - processed = ProcessingFunctions.parse_and_assert_past_date( - {"date": input_data[order[i]]}, output_field, input_fields, args + new_call_args = replace( + call_args, + input_data={"date": call_args.input_data[order[i]]}, ) + processed = ProcessingFunctions.parse_and_assert_past_date(new_call_args) formatted_input_data.append( fallback_value if null_per_backend(processed.datum) else str(processed.datum) ) elif field_types[i] == "timestamp": - processed = ProcessingFunctions.parse_timestamp( - {"timestamp": input_data[order[i]]}, output_field, input_fields, args + new_call_args = replace( + call_args, + input_data={"timestamp": call_args.input_data[order[i]]}, ) + processed = ProcessingFunctions.parse_timestamp(new_call_args) formatted_input_data.append( fallback_value if null_per_backend(processed.datum) else str(processed.datum) ) elif field_types[i] == "ACCESSION_VERSION": - formatted_input_data.append(accession_version) - elif order[i] in input_data: + formatted_input_data.append(call_args.internal_metadata.accession_version) + elif order[i] in call_args.input_data: formatted_input_data.append( fallback_value - if null_per_backend(input_data[order[i]]) - else str(input_data[order[i]]).strip() + if null_per_backend(call_args.input_data[order[i]]) + else str(call_args.input_data[order[i]]).strip() ) else: - logger.error( - f"Concatenate: cannot find field {order[i]} in input_data" - f"This is probably a configuration error. (ACCESSION_VERSION: {accession_version})" - ) + logger_error(f"Concatenate: cannot find field {order[i]} in input_data") add_errors() return ProcessingResult( datum=None, @@ -768,14 +752,14 @@ def add_errors(): return ProcessingResult(datum=result, warnings=warnings, errors=errors) except ValueError as e: - logger.error(f"Concatenate failed with {e} (ACCESSION_VERSION: {accession_version})") + logger_error(f"Concatenate failed with {e} ") errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( - f"Concatenation failed for {output_field}. This is a technical error, " + f"Concatenation failed for {call_args.output_field}. This is a technical error, " "please contact the administrator." ), ) @@ -788,12 +772,9 @@ def add_errors(): @staticmethod def check_authors( - input_data: InputMetadata, - output_field: str, - input_fields: list[str], - args: FunctionArgs, + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: - authors = input_data["authors"] + authors = call_args.input_data["authors"] author_format_description = ( "Please ensure that " @@ -812,7 +793,9 @@ def check_authors( warnings=warnings, errors=errors, ) - errors, warnings = check_latin_characters(authors, input_fields, output_field) + errors, warnings = check_latin_characters( + authors, call_args.input_fields, call_args.output_field + ) if errors or warnings: return ProcessingResult( datum=None, @@ -829,8 +812,8 @@ def check_authors( ) warnings.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=warning_message, ) @@ -843,8 +826,8 @@ def check_authors( ) warnings.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=warning_message, ) @@ -868,8 +851,8 @@ def check_authors( datum=None, errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=error_message, ) @@ -879,10 +862,7 @@ def check_authors( @staticmethod def extract_regex( - input_data: InputMetadata, - output_field: str, - input_fields: list[str], - args: FunctionArgs, + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """ Extracts a substring from the `regex_field` using the provided regex `pattern` @@ -890,34 +870,38 @@ def extract_regex( e.g. ^(?P[^-]+)-(?P[^-]+)$ where segment or subtype could be used as a capture_group to extract their respective value from the regex_field. """ - regex_field = input_data["regex_field"] + regex_field = call_args.input_data["regex_field"] warnings: list[ProcessingAnnotation] = [] errors: list[ProcessingAnnotation] = [] - pattern = args.get("pattern") - capture_group = args.get("capture_group") - uppercase = args.get("uppercase", False) + pattern = call_args.args.get("pattern") + capture_group = call_args.args.get("capture_group") + uppercase = call_args.args.get("uppercase", False) if not regex_field: return ProcessingResult(datum=None, warnings=warnings, errors=errors) if not isinstance(pattern, str): errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=regex_error("extract_regex", "pattern", input_data, args), + message=regex_error( + "extract_regex", "pattern", call_args.input_data, call_args.args + ), ) ) return ProcessingResult(datum=None, warnings=warnings, errors=errors) if not isinstance(capture_group, str): errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=regex_error("extract_regex", "capture_group", input_data, args), + message=regex_error( + "extract_regex", "capture_group", call_args.input_data, call_args.args + ), ) ) return ProcessingResult(datum=None, warnings=warnings, errors=errors) @@ -931,8 +915,8 @@ def extract_regex( except IndexError: errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( f"The pattern '{pattern}' does not contain a capture group: " @@ -944,8 +928,8 @@ def extract_regex( else: errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( f"The value '{regex_field}' does not match the expected regex " @@ -957,31 +941,30 @@ def extract_regex( @staticmethod def check_regex( - input_data: InputMetadata, - output_field: str, - input_fields: list[str], - args: FunctionArgs, + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """ Validates that the field regex_field matches the regex expression. If not return error """ - regex_field = input_data["regex_field"] + regex_field = call_args.input_data["regex_field"] warnings: list[ProcessingAnnotation] = [] errors: list[ProcessingAnnotation] = [] - pattern = args["pattern"] + pattern = call_args.args["pattern"] if not regex_field: return ProcessingResult(datum=None, warnings=warnings, errors=errors) if not isinstance(pattern, str): errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=regex_error("check_regex", "pattern", input_data, args), + message=regex_error( + "check_regex", "pattern", call_args.input_data, call_args.args + ), ) ) return ProcessingResult(datum=None, warnings=warnings, errors=errors) @@ -990,8 +973,8 @@ def check_regex( return ProcessingResult(datum=regex_field, warnings=warnings, errors=errors) errors.append( ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( f"The value '{regex_field}' does not match the expected regex " @@ -1003,37 +986,39 @@ def check_regex( @staticmethod def identity( # noqa: C901, PLR0912 - input_data: InputMetadata, output_field: str, input_fields: list[str], args: FunctionArgs + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """Identity function, takes input_data["input"] and returns it as output""" - if "input" not in input_data: + if "input" not in call_args.input_data: return ProcessingResult( datum=None, warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=f"No data found for output field: {output_field}", + message=f"No data found for output field: {call_args.output_field}", ) ], ) - input_datum = input_data["input"] + input_datum = call_args.input_data["input"] if not input_datum: return ProcessingResult(datum=None, warnings=[], errors=[]) errors: list[ProcessingAnnotation] = [] output_datum: ProcessedMetadataValue - if args and "type" in args: - match args["type"]: + if call_args.args and "type" in call_args.args: + match call_args.args["type"]: case "int": try: output_datum = int(input_datum) except ValueError: output_datum = None errors.append( - invalid_value_annotation(input_datum, output_field, input_fields, "int") + invalid_value_annotation( + input_datum, call_args.output_field, call_args.input_fields, "int" + ) ) case "float": try: @@ -1046,7 +1031,7 @@ def identity( # noqa: C901, PLR0912 output_datum = None errors.append( invalid_value_annotation( - input_datum, output_field, input_fields, "float" + input_datum, call_args.output_field, call_args.input_fields, "float" ) ) case "boolean": @@ -1058,7 +1043,10 @@ def identity( # noqa: C901, PLR0912 output_datum = None errors.append( invalid_value_annotation( - input_datum, output_field, input_fields, "boolean" + input_datum, + call_args.output_field, + call_args.input_fields, + "boolean", ) ) case _: @@ -1072,48 +1060,46 @@ def identity( # noqa: C901, PLR0912 @staticmethod def process_options( - input_data: InputMetadata, output_field: str, input_fields: list[str], args: FunctionArgs + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """Checks that option is in options""" - if "options" not in args or not isinstance(args["options"], list): + if "options" not in call_args.args or not isinstance(call_args.args["options"], list): return ProcessingResult( datum=None, warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( "Website configuration error: no options list specified for field " - f"{output_field}, please contact an administrator." + f"{call_args.output_field}, please contact an administrator." ), ) ], ) - input_datum = input_data["input"] + input_datum = call_args.input_data["input"] if not input_datum: return ProcessingResult(datum=None, warnings=[], errors=[]) output_datum: ProcessedMetadataValue standardized_input_datum = standardize_option(input_datum) - if output_field in options_cache: - options = options_cache[output_field] + if call_args.output_field in options_cache: + options = options_cache[call_args.output_field] else: - options = compute_options_cache(output_field, args["options"]) - error_msg = ( - f"Metadata field {output_field}:'{input_datum}' - not in list of accepted options." - ) + options = compute_options_cache(call_args.output_field, call_args.args["options"]) + error_msg = f"Metadata field {call_args.output_field}:'{input_datum}' - not in list of accepted options." if standardized_input_datum in options: output_datum = options[standardized_input_datum] # Allow ingested data to include fields not in options - elif args["is_insdc_ingest_group"]: + elif call_args.args["is_insdc_ingest_group"]: return ProcessingResult( datum=input_datum, warnings=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=error_msg, ) @@ -1126,8 +1112,8 @@ def process_options( warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=error_msg, ) @@ -1137,30 +1123,30 @@ def process_options( @staticmethod def is_above_threshold( - input_data: InputMetadata, output_field: str, input_fields: list[str], args: FunctionArgs + call_args: ProcessingFunctionCallArgs, ) -> ProcessingResult: """Flag if input value is above a threshold specified in args""" - if "threshold" not in args: + if "threshold" not in call_args.args: return ProcessingResult( datum=None, warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, message=( - f"Field {output_field} is missing threshold argument." + f"Field {call_args.output_field} is missing threshold argument." " Please report this error to the administrator." ), ) ], ) - input_datum = input_data["input"] + input_datum = call_args.input_data["input"] if not input_datum: return ProcessingResult(datum=None, warnings=[], errors=[]) try: - threshold = float(args["threshold"]) # type: ignore + threshold = float(call_args.args["threshold"]) # type: ignore input = float(input_datum) except (ValueError, TypeError): return ProcessingResult( @@ -1168,10 +1154,12 @@ def is_above_threshold( warnings=[], errors=[ ProcessingAnnotation.from_fields( - input_fields, - [output_field], + call_args.input_fields, + [call_args.output_field], AnnotationSourceType.METADATA, - message=(f"Field {output_field} has non-numeric threshold value."), + message=( + f"Field {call_args.output_field} has non-numeric threshold value." + ), ) ], ) diff --git a/preprocessing/nextclade/tests/factory_methods.py b/preprocessing/nextclade/tests/factory_methods.py index 933cc5f66d..ec4ec806e4 100644 --- a/preprocessing/nextclade/tests/factory_methods.py +++ b/preprocessing/nextclade/tests/factory_methods.py @@ -9,6 +9,7 @@ from loculus_preprocessing.datatypes import ( AnnotationSource, AnnotationSourceType, + InternalMetadata, NucleotideSequence, ProcessedData, ProcessedEntry, @@ -17,20 +18,19 @@ ProcessingAnnotationAlignment, SegmentName, UnprocessedData, - UnprocessedEntry, ) -def ts_from_ymd(year: int, month: int, day: int) -> str: - """Convert a year, month, and day into a UTC timestamp string.""" +def ts_from_ymd(year: int, month: int, day: int) -> int: + """Convert a year, month, and day into a UTC timestamp integer.""" dt = datetime(year, month, day, tzinfo=pytz.UTC) - return str(dt.timestamp()) + return int(dt.timestamp()) @dataclass class ProcessingTestCase: name: str - input: UnprocessedEntry + input: UnprocessedData expected_output: ProcessedEntry @@ -69,26 +69,35 @@ class ProcessedAlignment: ) +def get_dummy_internal_metadata( + accession: str = "LOC_01.1", group_id: int = 2, submission_id: str = "test_submission_id" +) -> InternalMetadata: + return InternalMetadata( + accession_version=accession, + submission_id=submission_id, + submitter="test_submitter", + group_id=group_id, + submitted_at=int(ts_from_ymd(2021, 12, 15)), + ) + + @dataclass -class UnprocessedEntryFactory: +class UnprocessedDataFactory: @staticmethod def create_unprocessed_entry( metadata_dict: dict[str, str | None], accession_id: str, sequences: dict[SegmentName, NucleotideSequence | None], group_id: int = 2, - ) -> UnprocessedEntry: - return UnprocessedEntry( - accessionVersion=f"LOC_{accession_id}.1", - data=UnprocessedData( - submitter="test_submitter", - submittedAt=str( - datetime.strptime("2021-12-15", "%Y-%m-%d").replace(tzinfo=pytz.utc).timestamp() - ), + ) -> UnprocessedData: + return UnprocessedData( + metadata=metadata_dict, + internal_metadata=get_dummy_internal_metadata( + accession=f"LOC_{accession_id}.1", group_id=group_id, - metadata=metadata_dict, - unalignedNucleotideSequences=sequences, + submission_id=f"SUB_{accession_id}", ), + unalignedNucleotideSequences=sequences, ) @@ -173,7 +182,7 @@ class Case: def create_test_case(self, factory_custom: ProcessedEntryFactory) -> ProcessingTestCase: if not self.expected_processed_alignment: self.expected_processed_alignment = ProcessedAlignment() - unprocessed_entry = UnprocessedEntryFactory.create_unprocessed_entry( + unprocessed_entry = UnprocessedDataFactory.create_unprocessed_entry( metadata_dict=self.input_metadata, accession_id=self.accession_id, sequences=self.input_sequence, @@ -181,7 +190,7 @@ def create_test_case(self, factory_custom: ProcessedEntryFactory) -> ProcessingT ) expected_output = factory_custom.create_processed_entry( metadata_dict=self.expected_metadata, - accession=unprocessed_entry.accessionVersion.split(".")[0], + accession=unprocessed_entry.internal_metadata.accession_version.split(".")[0], errors=self.expected_errors or [], warnings=self.expected_warnings or [], processed_alignment=self.expected_processed_alignment, diff --git a/preprocessing/nextclade/tests/test_metadata_processing_functions.py b/preprocessing/nextclade/tests/test_metadata_processing_functions.py index 06196480f6..f034c6c78c 100644 --- a/preprocessing/nextclade/tests/test_metadata_processing_functions.py +++ b/preprocessing/nextclade/tests/test_metadata_processing_functions.py @@ -6,7 +6,7 @@ ProcessingAnnotationHelper, ProcessingTestCase, build_processing_annotations, - ts_from_ymd, + get_dummy_internal_metadata, verify_processed_entry, ) @@ -15,8 +15,8 @@ FunctionArgs, InputMetadata, ProcessedEntry, + ProcessingFunctionCallArgs, UnprocessedData, - UnprocessedEntry, ) from loculus_preprocessing.prepro import process_all from loculus_preprocessing.processing_functions import ( @@ -705,18 +705,13 @@ def test_preprocessing(test_case_def: Case, config: Config, factory_custom: Proc def test_preprocessing_without_consensus_sequences(config: Config) -> None: sequence_name = "entry without sequences" - sequence_entry_data = UnprocessedEntry( - accessionVersion="LOC_01.1", - data=UnprocessedData( - submitter="test_submitter", - group_id=2, - submittedAt=ts_from_ymd(2021, 12, 15), - metadata={ - "ncbi_required_collection_date": "2024-01-01", - "name_required": sequence_name, - }, - unalignedNucleotideSequences={}, - ), + sequence_entry_data = UnprocessedData( + internal_metadata=get_dummy_internal_metadata(), + metadata={ + "ncbi_required_collection_date": "2024-01-01", + "name_required": sequence_name, + }, + unalignedNucleotideSequences={}, ) result = process_all([sequence_entry_data], "temp_dataset_dir", config) @@ -754,148 +749,82 @@ def test_format_authors() -> None: raise AssertionError(msg) +def generate_call_args(input_data: InputMetadata, field_type: str) -> ProcessingFunctionCallArgs: + return ProcessingFunctionCallArgs( + input_data=input_data, + output_field="field_name", + input_fields=["field_name"], + args={ + "fieldType": field_type, + }, + internal_metadata=get_dummy_internal_metadata(), + ) + + def test_parse_date_into_range() -> None: assert ( ProcessingFunctions.parse_date_into_range( - {"date": "2021-12"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeString", - "submittedAt": ts_from_ymd(2021, 12, 15), - }, + generate_call_args({"date": "2021-12"}, "dateRangeString") ).datum == "2021-12" ), "dateRangeString: 2021-12 should be returned as is." assert ( ProcessingFunctions.parse_date_into_range( - {"date": "2021-12"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeLower", - "submittedAt": ts_from_ymd(2021, 12, 15), - }, + generate_call_args({"date": "2021-12"}, "dateRangeLower") ).datum == "2021-12-01" ), "dateRangeLower: 2021-12 should be returned as 2021-12-01." assert ( ProcessingFunctions.parse_date_into_range( - {"date": "2021-12"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeUpper", - "submittedAt": ts_from_ymd(2022, 12, 15), - }, + generate_call_args({"date": "2020-12"}, "dateRangeUpper") ).datum - == "2021-12-31" - ), "dateRangeUpper: 2021-12 should be returned as 2021-12-31." + == "2020-12-31" + ), "dateRangeUpper: 2020-12 should be returned as 2020-12-31." assert ( ProcessingFunctions.parse_date_into_range( - {"date": "2021-12"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeUpper", - "submittedAt": ts_from_ymd(2021, 12, 15), - }, + generate_call_args({"date": "2021-12"}, "dateRangeUpper") ).datum == "2021-12-15" ), "dateRangeUpper: 2021-12 should be returned as submittedAt time: 2021-12-15." assert ( ProcessingFunctions.parse_date_into_range( - {"date": "2021-02"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeUpper", - "submittedAt": ts_from_ymd(2021, 3, 15), - }, + generate_call_args({"date": "2021-02"}, "dateRangeUpper") ).datum == "2021-02-28" ), "dateRangeUpper: 2021-02 should be returned as 2021-02-28." assert ( ProcessingFunctions.parse_date_into_range( - {"date": "2021"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeUpper", - "submittedAt": ts_from_ymd(2021, 12, 15), - }, + generate_call_args({"date": "2021"}, "dateRangeUpper") ).datum == "2021-12-15" ), "dateRangeUpper: 2021 should be returned as 2021-12-15." assert ( ProcessingFunctions.parse_date_into_range( - {"date": "2021"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeUpper", - "submittedAt": ts_from_ymd(2022, 1, 15), - }, + generate_call_args({"date": "2021-12", "releaseDate": "2021-12-14"}, "dateRangeUpper") ).datum - == "2021-12-31" - ), "dateRangeUpper: 2021 should be returned as 2021-12-31." + == "2021-12-14" + ), "dateRangeUpper: 2021-12 with releaseDate 2021-12-14 should be returned as 2021-12-14." assert ( ProcessingFunctions.parse_date_into_range( - {"date": "2021-12", "releaseDate": "2021-12-15"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeUpper", - "submittedAt": ts_from_ymd(2021, 12, 16), - }, + generate_call_args({"date": "", "releaseDate": "2021-12-14"}, "dateRangeUpper") ).datum - == "2021-12-15" - ), "dateRangeUpper: 2021-12 with releaseDate 2021-12-15 should be returned as 2021-12-15." - assert ( - ProcessingFunctions.parse_date_into_range( - {"date": "", "releaseDate": "2021-12-15"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeUpper", - "submittedAt": ts_from_ymd(2021, 12, 16), - }, - ).datum - == "2021-12-15" - ), "dateRangeUpper: empty date with releaseDate 2021-12-15 should be returned as 2021-12-15." + == "2021-12-14" + ), "dateRangeUpper: empty date with releaseDate 2021-12-14 should be returned as 2021-12-15." assert ( ProcessingFunctions.parse_date_into_range( - {"date": ""}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeString", - "submittedAt": ts_from_ymd(2021, 12, 16), - }, + generate_call_args({"date": ""}, "dateRangeString") ).datum is None ), "dateRangeString: empty date should be returned as None." assert ( ProcessingFunctions.parse_date_into_range( - {"date": "not.date"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeString", - "submittedAt": ts_from_ymd(2021, 12, 16), - }, + generate_call_args({"date": "not.date"}, "dateRangeString") ).datum is None ), "dateRangeString: invalid date should be returned as None." assert ( ProcessingFunctions.parse_date_into_range( - {"date": "", "releaseDate": "2021-12-15"}, - "field_name", - ["field_name"], - { - "fieldType": "dateRangeLower", - "submittedAt": ts_from_ymd(2021, 12, 16), - }, + generate_call_args({"date": "", "releaseDate": "2021-12-15"}, "dateRangeLower") ).datum is None ), "dateRangeLower: empty date should be returned as None." @@ -910,60 +839,75 @@ def test_concatenate() -> None: output_field: str = "displayName" input_fields: list[str] = ["geoLocCountry", "sampleCollectionDate"] args: FunctionArgs = { - "ACCESSION_VERSION": "version.1", + "ACCESSION_VERSION": "LOC_01.1", "order": ["someInt", "geoLocCountry", "ACCESSION_VERSION", "sampleCollectionDate"], "type": ["integer", "string", "ACCESSION_VERSION", "date"], } args_no_accession_version: FunctionArgs = { - "ACCESSION_VERSION": "version.1", + "ACCESSION_VERSION": "LOC_01.1", "order": ["someInt", "geoLocCountry", "sampleCollectionDate"], "type": ["integer", "string", "date"], "fallback_value": "unknown", } res_no_fallback_no_int = ProcessingFunctions.concatenate( - input_data, - output_field, - input_fields, - args, + ProcessingFunctionCallArgs( + args, + output_field, + input_fields, + input_data, + get_dummy_internal_metadata(), + ) ) input_data["someInt"] = "0" res_no_fallback = ProcessingFunctions.concatenate( - input_data, - output_field, - input_fields, - args, + ProcessingFunctionCallArgs( + args, + output_field, + input_fields, + input_data, + get_dummy_internal_metadata(), + ) ) args["fallback_value"] = "unknown" res_fallback = ProcessingFunctions.concatenate( - input_data, - output_field, - input_fields, - args, + ProcessingFunctionCallArgs( + args, + output_field, + input_fields, + input_data, + get_dummy_internal_metadata(), + ) ) res_fallback_no_accession_version = ProcessingFunctions.concatenate( - input_data, - output_field, - input_fields, - args_no_accession_version, + ProcessingFunctionCallArgs( + args_no_accession_version, + output_field, + input_fields, + input_data, + get_dummy_internal_metadata(), + ) ) input_data["sampleCollectionDate"] = None res_fallback_explicit_null = ProcessingFunctions.concatenate( - input_data, - output_field, - input_fields, - args, + ProcessingFunctionCallArgs( + args, + output_field, + input_fields, + input_data, + get_dummy_internal_metadata(), + ) ) - assert res_no_fallback_no_int.datum == "version.1/2025-01-01" - assert res_no_fallback.datum == "0//version.1/2025-01-01" - assert res_fallback.datum == "0/unknown/version.1/2025-01-01" + assert res_no_fallback_no_int.datum == "LOC_01.1/2025-01-01" + assert res_no_fallback.datum == "0//LOC_01.1/2025-01-01" + assert res_fallback.datum == "0/unknown/LOC_01.1/2025-01-01" assert res_fallback_no_accession_version.datum == "0/unknown/2025-01-01" - assert res_fallback_explicit_null.datum == "0/unknown/version.1/unknown" + assert res_fallback_explicit_null.datum == "0/unknown/LOC_01.1/unknown" if __name__ == "__main__": diff --git a/preprocessing/nextclade/tests/test_nextclade_preprocessing.py b/preprocessing/nextclade/tests/test_nextclade_preprocessing.py index 8f8d6e859e..eed734d22a 100644 --- a/preprocessing/nextclade/tests/test_nextclade_preprocessing.py +++ b/preprocessing/nextclade/tests/test_nextclade_preprocessing.py @@ -14,7 +14,7 @@ ProcessingAnnotationHelper, ProcessingTestCase, build_processing_annotations, - ts_from_ymd, + get_dummy_internal_metadata, verify_processed_entry, ) @@ -24,7 +24,6 @@ SegmentClassificationMethod, SubmissionData, UnprocessedData, - UnprocessedEntry, ) from loculus_preprocessing.embl import create_flatfile, reformat_authors_from_loculus_to_embl_style from loculus_preprocessing.prepro import process_all @@ -1193,18 +1192,13 @@ def test_preprocessing_multi_segment_none_requirement(test_case_def: Case): def test_preprocessing_without_metadata() -> None: config = get_config(MULTI_SEGMENT_CONFIG, ignore_args=True) - sequence_entry_data = UnprocessedEntry( - accessionVersion="LOC_01.1", - data=UnprocessedData( - group_id=2, - submitter="test_submitter", - submittedAt=ts_from_ymd(2021, 12, 15), - metadata={}, - unalignedNucleotideSequences={ - "ebola-sudan": sequence_with_mutation("ebola-sudan"), - "ebola-zaire": sequence_with_mutation("ebola-zaire"), - }, - ), + sequence_entry_data = UnprocessedData( + internal_metadata=get_dummy_internal_metadata(), + metadata={}, + unalignedNucleotideSequences={ + "ebola-sudan": sequence_with_mutation("ebola-sudan"), + "ebola-zaire": sequence_with_mutation("ebola-zaire"), + }, ) config.processing_spec = {} @@ -1309,21 +1303,16 @@ def test_create_flatfile(): embl_fields = get_config(EMBL_METADATA, ignore_args=True).processing_spec config.processing_spec.update(embl_fields) config.create_embl_file = True - sequence_entry_data = UnprocessedEntry( - accessionVersion="LOC_01.1", - data=UnprocessedData( - submitter="test_submitter", - group_id=2, - submittedAt=ts_from_ymd(2021, 12, 15), - metadata={ - "sampleCollectionDate": "2024-01-01", - "geoLocCountry": "Netherlands", - "geoLocAdmin1": "North Holland", - "geoLocCity": "Amsterdam", - "authors": "Smith, Doe A;", - }, - unalignedNucleotideSequences={"main": sequence_with_mutation("single")}, - ), + sequence_entry_data = UnprocessedData( + internal_metadata=get_dummy_internal_metadata(), + metadata={ + "sampleCollectionDate": "2024-01-01", + "geoLocCountry": "Netherlands", + "geoLocAdmin1": "North Holland", + "geoLocCity": "Amsterdam", + "authors": "Smith, Doe A;", + }, + unalignedNucleotideSequences={"main": sequence_with_mutation("single")}, ) result = process_all([sequence_entry_data], EBOLA_SUDAN_DATASET, config)