Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,11 @@ class SubmitModel(
"from $submissionParams.submitter with UploadId $uploadId"
}
val now = dateProvider.getCurrentDateTime()
val maxSequencesPerEntry = backendConfig.getInstanceConfig(submissionParams.organism)
.schema
.submissionDataTypes
.maxSequencesPerEntry

try {
when (submissionParams) {
is SubmissionParams.OriginalSubmissionParams -> {
metadataEntryStreamAsSequence(metadataStream, maxSequencesPerEntry)
metadataEntryStreamAsSequence(metadataStream)
.chunked(batchSize)
.forEach { batch ->
uploadDatabaseService.batchInsertMetadataInAuxTable(
Expand All @@ -280,7 +276,7 @@ class SubmitModel(
}

is SubmissionParams.RevisionSubmissionParams -> {
revisionEntryStreamAsSequence(metadataStream, maxSequencesPerEntry)
revisionEntryStreamAsSequence(metadataStream)
.chunked(batchSize)
.forEach { batch ->
uploadDatabaseService.batchInsertRevisedMetadataInAuxTable(
Expand Down
29 changes: 5 additions & 24 deletions backend/src/main/kotlin/org/loculus/backend/utils/MetadataEntry.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@ fun findAndValidateSubmissionIdHeader(headerNames: List<String>): String {
return submissionIdHeaders.first()
}

fun extractAndValidateFastaIds(
record: CSVRecord,
submissionId: String,
recordNumber: Int,
maxSequencesPerEntry: Int? = null,
): Set<FastaId> {
fun extractAndValidateFastaIds(record: CSVRecord, submissionId: String, recordNumber: Int): Set<FastaId> {
val headerNames = record.parser.headerNames
return when (headerNames.contains(FASTA_IDS_HEADER)) {
true -> {
Expand All @@ -74,14 +69,6 @@ fun extractAndValidateFastaIds(
)
}

if (maxSequencesPerEntry != null && fastaIds.size > maxSequencesPerEntry) {
throw UnprocessableEntityException(
"In metadata file: record #$recordNumber with id '$submissionId': " +
"found ${fastaIds.size} fasta ids but the maximum allowed number of " +
"sequences per entry is $maxSequencesPerEntry",
)
}

fastaIds.toSet()
}

Expand Down Expand Up @@ -134,10 +121,7 @@ private fun throwWithCsvExceptionUnwrapped(e: Exception): Nothing {
throw e
}

fun metadataEntryStreamAsSequence(
metadataInputStream: InputStream,
maxSequencesPerEntry: Int? = null,
): Sequence<MetadataEntry> {
fun metadataEntryStreamAsSequence(metadataInputStream: InputStream): Sequence<MetadataEntry> {
val csvParser = setUpCsvParser(metadataInputStream)

val headerNames = csvParser.headerNames
Expand All @@ -150,7 +134,7 @@ fun metadataEntryStreamAsSequence(

val submissionId = getValueAndValidateNoWhitespace(record, submissionIdHeader, recordNumber)

val fastaIds = extractAndValidateFastaIds(record, submissionId, recordNumber, maxSequencesPerEntry)
val fastaIds = extractAndValidateFastaIds(record, submissionId, recordNumber)

val metadata = record.toMap().filterKeys {
it != submissionIdHeader &&
Expand All @@ -174,10 +158,7 @@ data class RevisionEntry(
val fastaIds: Set<FastaId>? = null,
)

fun revisionEntryStreamAsSequence(
metadataInputStream: InputStream,
maxSequencesPerEntry: Int? = null,
): Sequence<RevisionEntry> {
fun revisionEntryStreamAsSequence(metadataInputStream: InputStream): Sequence<RevisionEntry> {
val csvParser = setUpCsvParser(metadataInputStream)

val headerNames = csvParser.headerNames
Expand All @@ -197,7 +178,7 @@ fun revisionEntryStreamAsSequence(
val submissionId = getValueAndValidateNoWhitespace(record, submissionIdHeader, recordNumber)
val accession = getValueAndValidateNoWhitespace(record, ACCESSION_HEADER, recordNumber)

val fastaIds = extractAndValidateFastaIds(record, submissionId, recordNumber, maxSequencesPerEntry)
val fastaIds = extractAndValidateFastaIds(record, submissionId, recordNumber)

val metadata = record.toMap().filterKeys {
it != submissionIdHeader && it != ACCESSION_HEADER &&
Expand Down
133 changes: 9 additions & 124 deletions backend/src/test/kotlin/org/loculus/backend/utils/MetadataEntryTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -129,104 +129,18 @@ class MetadataEntryTest {
}

@Test
fun `test maxSequencesPerEntry not set allows multiple sequences`() {
fun `test multiple fasta IDs are accepted without limit`() {
val str = """
submissionId${'\t'}fastaIds${'\t'}Country
foo${'\t'}seq1 seq2 seq3${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val entries = metadataEntryStreamAsSequence(inputStream, maxSequencesPerEntry = null).toList()
assertThat(entries, hasSize(1))
assertThat(entries[0].submissionId, equalTo("foo"))
assertThat(entries[0].fastaIds, equalTo(setOf("seq1", "seq2", "seq3")))
}

@Test
fun `test maxSequencesPerEntry allows sequences within limit`() {
val str = """
submissionId${'\t'}fastaIds${'\t'}Country
foo${'\t'}seq1 seq2${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val entries = metadataEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 3).toList()
assertThat(entries, hasSize(1))
assertThat(entries[0].submissionId, equalTo("foo"))
assertThat(entries[0].fastaIds, equalTo(setOf("seq1", "seq2")))
}

@Test
fun `test maxSequencesPerEntry allows sequences at exact limit`() {
val str = """
submissionId${'\t'}fastaIds${'\t'}Country
foo${'\t'}seq1 seq2 seq3${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val entries = metadataEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 3).toList()
val entries = metadataEntryStreamAsSequence(inputStream).toList()
assertThat(entries, hasSize(1))
assertThat(entries[0].submissionId, equalTo("foo"))
assertThat(entries[0].fastaIds, equalTo(setOf("seq1", "seq2", "seq3")))
}

@Test
fun `test maxSequencesPerEntry rejects sequences exceeding limit`() {
val str = """
submissionId${'\t'}fastaIds${'\t'}Country
foo${'\t'}seq1 seq2 seq3 seq4${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val exception = assertThrows<UnprocessableEntityException> {
metadataEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 3).toList()
}
assertThat(exception.message, containsString("record #1"))
assertThat(exception.message, containsString("foo"))
assertThat(exception.message, containsString("found 4 fasta ids"))
assertThat(exception.message, containsString("maximum allowed number of sequences per entry is 3"))
}

@Test
fun `test maxSequencesPerEntry with single sequence limit`() {
val str = """
submissionId${'\t'}fastaIds${'\t'}Country
foo${'\t'}seq1 seq2${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val exception = assertThrows<UnprocessableEntityException> {
metadataEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 1).toList()
}
assertThat(exception.message, containsString("record #1"))
assertThat(exception.message, containsString("foo"))
assertThat(exception.message, containsString("found 2 fasta ids"))
assertThat(exception.message, containsString("maximum allowed number of sequences per entry is 1"))
}

@Test
fun `test maxSequencesPerEntry allows single sequence when limit is 1`() {
val str = """
submissionId${'\t'}fastaIds${'\t'}Country
foo${'\t'}seq1${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val entries = metadataEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 1).toList()
assertThat(entries, hasSize(1))
assertThat(entries[0].submissionId, equalTo("foo"))
assertThat(entries[0].fastaIds, equalTo(setOf("seq1")))
}

@Test
fun `test maxSequencesPerEntry correct record number for multiple rows`() {
val str = """
submissionId${'\t'}fastaIds${'\t'}Country
foo1${'\t'}seq1${'\t'}bar
foo2${'\t'}seq2 seq3${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val exception = assertThrows<UnprocessableEntityException> {
metadataEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 1).toList()
}
assertThat(exception.message, containsString("record #2"))
assertThat(exception.message, containsString("foo2"))
}

@Test
fun `test multiple duplicate fasta IDs are all reported`() {
val str = """
Expand All @@ -245,14 +159,14 @@ class MetadataEntryTest {
}

@Test
fun `test duplicate detection works with maxSequencesPerEntry`() {
fun `test duplicate detection works`() {
val str = """
submissionId${'\t'}fastaIds${'\t'}Country
foo${'\t'}seq1 seq1${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val exception = assertThrows<UnprocessableEntityException> {
metadataEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 3).toList()
metadataEntryStreamAsSequence(inputStream).toList()
}
assertThat(exception.message, containsString("duplicate fasta ids"))
assertThat(exception.message, containsString("seq1"))
Expand Down Expand Up @@ -349,45 +263,16 @@ class RevisionEntryTest {
}

@Test
fun `test revision maxSequencesPerEntry allows sequences within limit`() {
val str = """
submissionId${'\t'}accession${'\t'}fastaIds${'\t'}Country
foo${'\t'}ACC123${'\t'}seq1 seq2${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val entries = revisionEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 3).toList()
assertThat(entries, hasSize(1))
assertThat(entries[0].submissionId, equalTo("foo"))
assertThat(entries[0].fastaIds, equalTo(setOf("seq1", "seq2")))
}

@Test
fun `test revision maxSequencesPerEntry rejects sequences exceeding limit`() {
fun `test revision multiple fasta IDs are accepted`() {
val str = """
submissionId${'\t'}accession${'\t'}fastaIds${'\t'}Country
foo${'\t'}ACC123${'\t'}seq1 seq2 seq3${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val exception = assertThrows<UnprocessableEntityException> {
revisionEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 2).toList()
}
assertThat(exception.message, containsString("record #1"))
assertThat(exception.message, containsString("foo"))
assertThat(exception.message, containsString("found 3 fasta ids"))
assertThat(exception.message, containsString("maximum allowed number of sequences per entry is 2"))
}

@Test
fun `test revision maxSequencesPerEntry with single sequence limit`() {
val str = """
submissionId${'\t'}accession${'\t'}fastaIds${'\t'}Country
foo${'\t'}ACC123${'\t'}seq1${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val entries = revisionEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 1).toList()
val entries = revisionEntryStreamAsSequence(inputStream).toList()
assertThat(entries, hasSize(1))
assertThat(entries[0].submissionId, equalTo("foo"))
assertThat(entries[0].fastaIds, equalTo(setOf("seq1")))
assertThat(entries[0].fastaIds, equalTo(setOf("seq1", "seq2", "seq3")))
}

@Test
Expand All @@ -407,14 +292,14 @@ class RevisionEntryTest {
}

@Test
fun `test revision duplicate detection works with maxSequencesPerEntry`() {
fun `test revision duplicate detection works`() {
val str = """
submissionId${'\t'}accession${'\t'}fastaIds${'\t'}Country
foo${'\t'}ACC123${'\t'}seq1 seq1${'\t'}bar
""".trimIndent()
val inputStream = ByteArrayInputStream(str.toByteArray())
val exception = assertThrows<UnprocessableEntityException> {
revisionEntryStreamAsSequence(inputStream, maxSequencesPerEntry = 3).toList()
revisionEntryStreamAsSequence(inputStream).toList()
}
assertThat(exception.message, containsString("duplicate fasta ids"))
assertThat(exception.message, containsString("seq1"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ data:
preprocessing-config.yaml: |
organism: {{ $organism }}
{{- $processingConfig.configFile | toYaml | nindent 4 }}
{{- if and (hasKey $organismConfig.schema "submissionDataTypes") (hasKey $organismConfig.schema.submissionDataTypes "maxSequencesPerEntry") }}
max_sequences_per_entry: {{ $organismConfig.schema.submissionDataTypes.maxSequencesPerEntry }}
{{- end }}
processing_spec:
{{- $args := dict "metadata" $metadata "referenceGenomes" $organismConfig.referenceGenomes }}
{{- include "loculus.preprocessingSpecs" $args | nindent 6 }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class Config(BaseModel):
keycloak_token_path: str = "realms/loculus/protocol/openid-connect/token" # noqa: S105

organism: str = "mpox"
max_sequences_per_entry: int | None = None
segments: list[Segment] = Field(default_factory=list)
processing_spec: dict[str, ProcessingSpec] = Field(default_factory=dict)
processing_order: tuple[str, ...] = ()
Expand Down
56 changes: 39 additions & 17 deletions preprocessing/nextclade/src/loculus_preprocessing/nextclade.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import pandas as pd
from Bio import SeqIO

from loculus_preprocessing.sequence_checks import error_on_excess_sequences

from .config import AlignmentRequirement, Config, NextcladeSequenceAndDataset, SequenceName
from .datatypes import (
AccessionVersion,
Expand Down Expand Up @@ -763,7 +765,42 @@ def load_aligned_aa_sequences(
return aligned_aminoacid_sequences


def enrich_with_nextclade( # noqa: C901, PLR0914
def assign_segment_for_alignment(
unprocessed: Sequence[UnprocessedEntry], config: Config, dataset_dir: str
) -> SequenceAssignmentBatch:
errors = {}
for entry in unprocessed:
errors[entry.accessionVersion] = error_on_excess_sequences(
len(entry.data.unalignedNucleotideSequences),
config,
)
if not config.multi_datasets:
batch = assign_all_single_segments(unprocessed, config=config)
else:
match config.segment_classification_method:
case SegmentClassificationMethod.DIAMOND:
batch = assign_segment_with_diamond(
unprocessed, config=config, dataset_dir=dataset_dir
)
case SegmentClassificationMethod.MINIMIZER:
batch = assign_segment_with_nextclade_sort(
unprocessed, config=config, dataset_dir=dataset_dir
)
case SegmentClassificationMethod.ALIGN:
batch = assign_segment_with_nextclade_align(
unprocessed, config=config, dataset_dir=dataset_dir
)
batch.alerts = {
id: Alert(
errors=[*batch.alerts[id].errors, *error] if error else batch.alerts[id].errors,
warnings=batch.alerts[id].warnings,
)
for id, error in errors.items()
}
return batch


def enrich_with_nextclade( # noqa: PLR0914
unprocessed: Sequence[UnprocessedEntry], dataset_dir: str, config: Config
) -> dict[AccessionVersion, UnprocessedAfterNextclade]:
"""
Expand Down Expand Up @@ -791,22 +828,7 @@ def enrich_with_nextclade( # noqa: C901, PLR0914
for entry in unprocessed
}

if not config.multi_datasets:
batch = assign_all_single_segments(unprocessed, config=config)
else:
match config.segment_classification_method:
case SegmentClassificationMethod.DIAMOND:
batch = assign_segment_with_diamond(
unprocessed, config=config, dataset_dir=dataset_dir
)
case SegmentClassificationMethod.MINIMIZER:
batch = assign_segment_with_nextclade_sort(
unprocessed, config=config, dataset_dir=dataset_dir
)
case SegmentClassificationMethod.ALIGN:
batch = assign_segment_with_nextclade_align(
unprocessed, config=config, dataset_dir=dataset_dir
)
batch = assign_segment_for_alignment(unprocessed, config=config, dataset_dir=dataset_dir)
unaligned_nucleotide_sequences = batch.unalignedNucleotideSequences
segment_assignment_map = batch.sequenceNameToFastaId
alerts: Alerts = batch.alerts
Expand Down
Loading
Loading