diff --git a/preprocessing/nextclade/src/loculus_preprocessing/config.py b/preprocessing/nextclade/src/loculus_preprocessing/config.py index d84d5403ca..62e358355b 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/config.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/config.py @@ -60,6 +60,7 @@ class Config: nextclade_dataset_name: str | None = None nextclade_dataset_name_map: dict[str, str] | None = None nextclade_dataset_tag: str | None = None + nextclade_dataset_tag_map: dict[str, str] | None = None nextclade_dataset_server: str = "https://data.clades.nextstrain.org/v3" nextclade_dataset_server_map: dict[str, str] | None = None diff --git a/preprocessing/nextclade/src/loculus_preprocessing/prepro.py b/preprocessing/nextclade/src/loculus_preprocessing/prepro.py index 13f9eef580..a850cd1fd4 100644 --- a/preprocessing/nextclade/src/loculus_preprocessing/prepro.py +++ b/preprocessing/nextclade/src/loculus_preprocessing/prepro.py @@ -838,10 +838,17 @@ def get_nextclade_dataset_server(config: Config, segment: SegmentName) -> str: return config.nextclade_dataset_server +def get_nextclade_dataset_tag(config: Config, segment: SegmentName) -> str | None: + if config.nextclade_dataset_tag_map and segment in config.nextclade_dataset_tag_map: + return config.nextclade_dataset_tag_map[segment] + return config.nextclade_dataset_tag + + def download_nextclade_dataset(dataset_dir: str, config: Config) -> None: for segment in config.nucleotideSequences: nextclade_dataset_name = get_nextclade_dataset_name(config, segment) nextclade_dataset_server = get_nextclade_dataset_server(config, segment) + nextclade_dataset_tag = get_nextclade_dataset_tag(config, segment) dataset_dir_seg = dataset_dir if segment == "main" else dataset_dir + "/" + segment dataset_download_command = [ @@ -853,8 +860,8 @@ def download_nextclade_dataset(dataset_dir: str, config: Config) -> None: f"--output-dir={dataset_dir_seg}", ] - if config.nextclade_dataset_tag is not None: - dataset_download_command.append(f"--tag={config.nextclade_dataset_tag}") + if nextclade_dataset_tag is not None: + dataset_download_command.append(f"--tag={nextclade_dataset_tag}") logger.info("Downloading Nextclade dataset: %s", dataset_download_command) if subprocess.run(dataset_download_command, check=False).returncode != 0: # noqa: S603