23 changes: 18 additions & 5 deletions prepare_species/extract_species_data_psql.py
@@ -5,7 +5,7 @@
import os
import sys
from functools import partial
from multiprocessing import Pool
from multiprocessing import Pool, cpu_count
from pathlib import Path
from typing import Optional

@@ -257,7 +257,11 @@ def extract_data_per_species(
excludes_path: Optional[Path],
output_directory_path: Path,
target_projection: Optional[str],
processes_count: int,
) -> None:
# The limiting amount here is how many concurrent connections the database can take
worker_count = min(processes_count, 20)
logger.info("Using %d workers", worker_count)

connection = psycopg2.connect(DB_CONFIG)
cursor = connection.cursor()
@@ -293,12 +297,11 @@ def extract_data_per_species(
if overrides_path:
results = apply_overrides(overrides_path, results)

# The limiting amount here is how many concurrent connections the database can take
try:
with Pool(processes=20) as pool:
with Pool(processes=worker_count) as pool:
reports = pool.map(
partial(process_row, class_name, era_output_directory_path, target_projection, presence),
results
results,
)
except psycopg2.OperationalError:
sys.exit("Database connection failed for some rows, aborting")
@@ -316,6 +319,7 @@ def extract_data_per_species(
"excludes": "input.excludes",
"output_directory_path": "params.output_dir",
"target_projection": "params.projection",
"processes_count": "threads",
})
def main() -> None:
parser = argparse.ArgumentParser(description="Process aggregate species data to per-species-file.")
@@ -357,14 +361,23 @@ def main() -> None:
dest="target_projection",
default="ESRI:54017"
)
parser.add_argument(
"-j",
type=int,
required=False,
default=cpu_count() // 2,
dest="processes_count",
help="Number of concurrent threads to use."
)
args = parser.parse_args()

extract_data_per_species(
args.classname,
args.overrides,
args.excludes,
args.output_directory_path,
args.target_projection
args.target_projection,
args.processes_count,
)

if __name__ == "__main__":
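For context on this file's change: the new `-j` flag defaults to half the machine's cores, and the resulting count is clamped to the database's connection budget before the `Pool` is created. Below is a minimal, self-contained sketch of that pattern; the `process_row` stand-in, the sample rows, and the cap of 20 connections are illustrative assumptions, not the repository's actual code.

```python
import argparse
from functools import partial
from multiprocessing import Pool, cpu_count


# Hypothetical stand-in for the real per-species worker; the actual
# process_row opens its own database connection, hence the cap below.
def process_row(class_name: str, row: tuple) -> str:
    return f"{class_name}: processed {row}"


def run(class_name: str, rows: list, processes_count: int) -> list:
    # Cap workers at the assumed maximum number of concurrent DB connections.
    worker_count = min(processes_count, 20)
    with Pool(processes=worker_count) as pool:
        return pool.map(partial(process_row, class_name), rows)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-j",
        type=int,
        default=cpu_count() // 2,
        dest="processes_count",
        help="Number of concurrent worker processes to use.",
    )
    args = parser.parse_args()
    print(run("AVES", [(1,), (2,), (3,)], args.processes_count))
```

With the default of `cpu_count() // 2` the flag stays optional, and the `min(..., 20)` cap means asking for more workers than the database can serve degrades gracefully instead of exhausting connections.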
2 changes: 1 addition & 1 deletion threats/threat_summation.py
@@ -215,7 +215,7 @@ def main() -> None:
threat_summation(
args.rasters_directory,
args.output_directory,
args.processes_count
args.processes_count,
)

if __name__ == "__main__":
1 change: 1 addition & 0 deletions workflow/rules/species.smk
@@ -35,6 +35,7 @@ checkpoint extract_species_data:
# Serialise DB access scripts - only one extraction script at a time
# as it will make many concurrent connections internally
db_connections=1,
threads: workflow.cores
script:
str(SRCDIR / "prepare_species" / "extract_species_data_psql.py")

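The `threads: workflow.cores` directive hands the workflow's full core allowance to the script (the Python side receives it as `processes_count` via the `"processes_count": "threads"` mapping), while the `db_connections=1` resource keeps DB-heavy extraction jobs from overlapping when Snakemake is run with a matching `--resources db_connections=1` limit. A hedged sketch of how those directives sit together; the rule name and script path come from the diff above, but the output path and the rest of the rule body are placeholders.

```python
# Hypothetical Snakemake rule sketch. Run Snakemake with
# `--resources db_connections=1` so the custom resource actually
# serialises these jobs.
checkpoint extract_species_data:
    output:
        directory("output/species"),
    resources:
        # Only one DB-heavy extraction job at a time; the script opens many
        # concurrent connections internally, so parallelism lives inside it.
        db_connections=1,
    threads: workflow.cores  # visible to the script, e.g. as snakemake.threads
    script:
        "prepare_species/extract_species_data_psql.py"
```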