diff --git a/prepare_species/extract_species_data_psql.py b/prepare_species/extract_species_data_psql.py
index 3350664..3dfa314 100644
--- a/prepare_species/extract_species_data_psql.py
+++ b/prepare_species/extract_species_data_psql.py
@@ -5,7 +5,7 @@
 import os
 import sys
 from functools import partial
-from multiprocessing import Pool
+from multiprocessing import Pool, cpu_count
 from pathlib import Path
 from typing import Optional
 
@@ -257,7 +257,11 @@ def extract_data_per_species(
     excludes_path: Optional[Path],
     output_directory_path: Path,
     target_projection: Optional[str],
+    processes_count: int,
 ) -> None:
+    # The limiting factor here is how many concurrent connections the database can handle
+    worker_count = min(processes_count, 20)
+    logger.info("Using %d workers", worker_count)
     connection = psycopg2.connect(DB_CONFIG)
     cursor = connection.cursor()
 
@@ -293,12 +297,11 @@ def extract_data_per_species(
     if overrides_path:
         results = apply_overrides(overrides_path, results)
 
-    # The limiting amount here is how many concurrent connections the database can take
     try:
-        with Pool(processes=20) as pool:
+        with Pool(processes=worker_count) as pool:
             reports = pool.map(
                 partial(process_row, class_name, era_output_directory_path, target_projection, presence),
-                results
+                results,
             )
     except psycopg2.OperationalError:
         sys.exit("Database connection failed for some rows, aborting")
@@ -316,6 +319,7 @@
     "excludes": "input.excludes",
     "output_directory_path": "params.output_dir",
     "target_projection": "params.projection",
+    "processes_count": "threads",
 })
 def main() -> None:
     parser = argparse.ArgumentParser(description="Process agregate species data to per-species-file.")
@@ -357,6 +361,14 @@ def main() -> None:
         dest="target_projection",
         default="ESRI:54017"
     )
+    parser.add_argument(
+        "-j",
+        type=int,
+        required=False,
+        default=max(1, cpu_count() // 2),
+        dest="processes_count",
+        help="Number of concurrent worker processes to use."
+    )
 
     args = parser.parse_args()
     extract_data_per_species(
@@ -364,7 +376,8 @@ def main() -> None:
         args.overrides,
         args.excludes,
         args.output_directory_path,
-        args.target_projection
+        args.target_projection,
+        args.processes_count,
     )
 
 if __name__ == "__main__":
diff --git a/threats/threat_summation.py b/threats/threat_summation.py
index b05d38c..e095fba 100644
--- a/threats/threat_summation.py
+++ b/threats/threat_summation.py
@@ -215,7 +215,7 @@ def main() -> None:
     threat_summation(
         args.rasters_directory,
         args.output_directory,
-        args.processes_count
+        args.processes_count,
     )
 
 if __name__ == "__main__":
diff --git a/workflow/rules/species.smk b/workflow/rules/species.smk
index 720d353..5e6fee5 100644
--- a/workflow/rules/species.smk
+++ b/workflow/rules/species.smk
@@ -35,6 +35,7 @@ checkpoint extract_species_data:
         # Serialise DB access scripts - only one extraction script at a time
        # as it will make many concurrent connections internally
         db_connections=1,
+    threads: workflow.cores
     script:
         str(SRCDIR / "prepare_species" / "extract_species_data_psql.py")
 
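The concurrency change above reduces to one pattern: size the multiprocessing pool from the requested process count, but cap it at the database's connection budget. Below is a minimal runnable sketch of that pattern under the diff's assumptions; the constant MAX_DB_CONNECTIONS and the helper map_with_capped_pool are illustrative names, not part of the patch.

# Sketch of the worker-capping pattern from extract_species_data_psql.py,
# assuming the same hard-coded 20-connection database budget as the diff.
from multiprocessing import Pool, cpu_count

MAX_DB_CONNECTIONS = 20  # hypothetical constant for the hard-coded cap above

def map_with_capped_pool(worker, work_items, processes_count=None):
    """Map worker over work_items without exceeding the DB connection budget."""
    if processes_count is None:
        # Mirrors the -j default: half the available cores, but at least one,
        # so Pool() never receives a process count of zero.
        processes_count = max(1, cpu_count() // 2)
    worker_count = min(processes_count, MAX_DB_CONNECTIONS)
    with Pool(processes=worker_count) as pool:
        return pool.map(worker, work_items)

Capping at the pool rather than in the Snakemake rule keeps the limit next to the code that opens the connections, while the rule's db_connections=1 resource still serialises whole extraction runs against each other.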