diff --git a/.gitignore b/.gitignore
index 866eaff..30fb92c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -250,4 +250,4 @@ $RECYCLE.BIN/
 # End of https://www.toptal.com/developers/gitignore/api/python,macos,windows,linux
 
 # pycharm
-.idea
\ No newline at end of file
+.idea
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..380dafc
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,37 @@
+repos:
+  # This should run before any formatting hooks like isort
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: "v0.12.1"
+    hooks:
+      - id: ruff
+        args: ["--fix"]
+      # Run the formatter.
+      - id: ruff-format
+  - repo: https://github.com/PyCQA/isort
+    rev: 6.0.1
+    hooks:
+      - id: isort
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+      - id: check-ast
+      - id: check-case-conflict
+      - id: trailing-whitespace
+      - id: check-yaml
+      - id: debug-statements
+      - id: check-added-large-files
+        args: ["--enforce-all", "--maxkb=1054"]
+        exclude: ""
+      - id: end-of-file-fixer
+      - id: mixed-line-ending
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.4.1
+    hooks:
+      - id: codespell
+        additional_dependencies:
+          - tomli
+        args: ["--write-changes"]
+
+ci:
+  autofix_prs: false
+  autoupdate_schedule: "quarterly"
diff --git a/Dockerfile b/Dockerfile
index 09ab680..d76e45c 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,4 +1,4 @@
-FROM python:3.11
+FROM python:3.11-bookworm
 
 WORKDIR /code
 
@@ -6,8 +6,8 @@ WORKDIR /code
 RUN apt-get update && apt-get install -y \
     libgl1-mesa-glx \
     libglib2.0-0 \
-    git \
-    && rm -rf /var/lib/apt/lists/*
+    git
+RUN rm -rf /var/lib/apt/lists/*
 
 
 # Set environment variables
diff --git a/README.md b/README.md
index 29c7a71..a8a958e 100644
--- a/README.md
+++ b/README.md
@@ -23,8 +23,9 @@
 
 `docker run -p 8000:80 arcaff-api:0.1.0`
 
+## Configuration
 
-
+By default the data and model files are stored at `/arccnet/data` and `/arccnet/models`. These locations can be overridden by setting the `DATAPATH` and/or `MODELSPATH` environment variables, and the directories can also be mounted from the host file system, e.g. `-v /host/path:/arccnet/data`.
 
 # Test
 
@@ -114,4 +115,4 @@ curl -X 'POST' \
     "mcintosh_class": "Axx"
   }
 ]
-```
\ No newline at end of file
+```
diff --git a/app/api.py b/app/api.py
index 276350f..6a1f714 100644
--- a/app/api.py
+++ b/app/api.py
@@ -8,30 +8,42 @@
 router = APIRouter()
 
 
-@router.get("/arcnet/classify_cutout/", tags=['AR Cutout Classification'])
-async def classify_cutout(classification_request: ARCutoutClassificationInput = Depends()) -> ARCutoutClassificationResult:
+@router.get("/arcnet/classify_cutout/", tags=["AR Cutout Classification"])
+async def classify_cutout(
+    classification_request: ARCutoutClassificationInput = Depends(),
+) -> ARCutoutClassificationResult:
     r"""
     Classify a cutout generated from a magnetogram at the given date and location as URL parameters.
     """
-    classification = classify(time=classification_request.time, hgs_latitude=classification_request.hgs_latitude,
-                              hgs_longitude=classification_request.hgs_longitude)
+    classification = classify(
+        time=classification_request.time,
+        hgs_latitude=classification_request.hgs_latitude,
+        hgs_longitude=classification_request.hgs_longitude,
+    )
     classification_result = ARCutoutClassificationResult.model_validate(classification)
     return classification_result
 
 
-@router.post("/arcnet/classify_cutout/", tags=['AR Cutout Classification'])
-async def classify_cutout(classification_request: ARCutoutClassificationInput) -> ARCutoutClassificationResult:
+@router.post("/arcnet/classify_cutout/", tags=["AR Cutout Classification"])
+async def classify_cutout(
+    classification_request: ARCutoutClassificationInput,
+) -> ARCutoutClassificationResult:
     r"""
     Classify an AR cutout generated from a magnetogram at the given date and location as JSON data.
     """
-    classification = classify(time=classification_request.time, hgs_latitude=classification_request.hgs_latitude,
-                              hgs_longitude=classification_request.hgs_longitude)
+    classification = classify(
+        time=classification_request.time,
+        hgs_latitude=classification_request.hgs_latitude,
+        hgs_longitude=classification_request.hgs_longitude,
+    )
     classification_result = ARCutoutClassificationResult.model_validate(classification)
     return classification_result
 
 
-@router.get("/arcnet/full_disk_detection", tags=['Full disk AR Detection'])
-async def full_disk_detection(detection_request: ARDetectionInput = Depends()) -> List[ARDetection]:
+@router.get("/arcnet/full_disk_detection", tags=["Full disk AR Detection"])
+async def full_disk_detection(
+    detection_request: ARDetectionInput = Depends(),
+) -> List[ARDetection]:
     r"""
     Detect and classify all ARs in a magnetogram at the given date as a URL parameter.
     """
@@ -40,11 +52,11 @@ async def full_disk_detection(detection_request: ARDetectionInput = Depends()) -> List[ARDetection]:
     return detection_result
 
 
-@router.post("/arcnet/full_disk_detection", tags=['Full disk AR Detection'])
+@router.post("/arcnet/full_disk_detection", tags=["Full disk AR Detection"])
 async def full_disk_detection(detection_request: ARDetectionInput) -> List[ARDetection]:
     r"""
     Detect and classify all ARs in a magnetogram at the given date as JSON data.
     """
     detections = detect(detection_request.time)
     detection_result = [ARDetection.model_validate(d) for d in detections]
-    return detection_result
\ No newline at end of file
+    return detection_result
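The reformatted endpoints keep the same request schema. For reference, a minimal client sketch against a locally running container (port mapping taken from the README's `docker run` line; the payload fields mirror the `classify(...)` call above, and `requests` is already an app dependency via `app/model_utils.py`):

```python
# Minimal sketch, assuming the API is running on localhost:8000 and that
# ARCutoutClassificationInput accepts these three fields as used in app/api.py.
import requests

resp = requests.post(
    "http://localhost:8000/arcnet/classify_cutout/",
    json={
        "time": "2024-01-01T00:00:00",
        "hgs_latitude": 20.0,
        "hgs_longitude": -35.0,
    },
)
resp.raise_for_status()
print(resp.json())  # e.g. includes "mcintosh_class": "Axx" (see README output)
```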
""" - classification = classify(time=classification_request.time, hgs_latitude=classification_request.hgs_latitude, - hgs_longitude=classification_request.hgs_longitude) + classification = classify( + time=classification_request.time, + hgs_latitude=classification_request.hgs_latitude, + hgs_longitude=classification_request.hgs_longitude, + ) classification_result = ARCutoutClassificationResult.model_validate(classification) return classification_result -@router.post("/arcnet/classify_cutout/", tags=['AR Cutout Classification']) -async def classify_cutout(classification_request: ARCutoutClassificationInput) -> ARCutoutClassificationResult: +@router.post("/arcnet/classify_cutout/", tags=["AR Cutout Classification"]) +async def classify_cutout( + classification_request: ARCutoutClassificationInput, +) -> ARCutoutClassificationResult: r""" Classify an AR cutout generated from a magnetogram at the given date and location as json data. """ - classification = classify(time=classification_request.time, hgs_latitude=classification_request.hgs_latitude, - hgs_longitude=classification_request.hgs_longitude) + classification = classify( + time=classification_request.time, + hgs_latitude=classification_request.hgs_latitude, + hgs_longitude=classification_request.hgs_longitude, + ) classification_result = ARCutoutClassificationResult.model_validate(classification) return classification_result -@router.get("/arcnet/full_disk_detection", tags=['Full disk AR Detection']) -async def full_disk_detection(detection_request: ARDetectionInput = Depends()) -> List[ARDetection]: +@router.get("/arcnet/full_disk_detection", tags=["Full disk AR Detection"]) +async def full_disk_detection( + detection_request: ARDetectionInput = Depends(), +) -> List[ARDetection]: r""" Detect and classify all ARs in a magnetogram at the given date as a URL parameter. """ @@ -40,11 +52,11 @@ async def full_disk_detection(detection_request: ARDetectionInput = Depends()) - return detection_result -@router.post("/arcnet/full_disk_detection", tags=['Full disk AR Detection']) +@router.post("/arcnet/full_disk_detection", tags=["Full disk AR Detection"]) async def full_disk_detection(detection_request: ARDetectionInput) -> List[ARDetection]: r""" Detect and classify all ARs in a magnetogram at the given date as json data. 
""" detections = detect(detection_request.time) detection_result = [ARDetection.model_validate(d) for d in detections] - return detection_result \ No newline at end of file + return detection_result diff --git a/app/classify.py b/app/classify.py index 8a792ee..45f7ad5 100644 --- a/app/classify.py +++ b/app/classify.py @@ -1,15 +1,12 @@ from datetime import datetime, timedelta -from pathlib import Path import astropy.units as u -import numpy as np -import torch from astropy.coordinates import SkyCoord -from astropy.time import Time from sunpy.map import Map from sunpy.net import Fido from sunpy.net import attrs as a +from app.config import settings from app.fulldisk_model import yolo_detection from app.hale_model import hale_classification from app.mcintosh_encoders import decode_predicted_classes_to_original @@ -18,13 +15,16 @@ CUTOUT = [800, 400] * u.pix -def download_magnetogram(time): +def download_magnetogram(time, download_path): r""" Download magnetogram and prepare for use Parameters ---------- time : datetime.datetime + Time to search for closest magnetogram + download_path : Path + Path to datadownload directory """ query = Fido.search( @@ -34,7 +34,7 @@ def download_magnetogram(time): ) if not query: raise MagNotFoundError() - files = Fido.fetch(query["vso"][0]) + files = Fido.fetch(query["vso"][0], path=download_path) if not files: raise MagDownloadError() mag_map = Map(files[0]) @@ -59,7 +59,7 @@ def classify(time: datetime, hgs_latitude: float, hgs_longitude: float): Classification result """ - mag_map = download_magnetogram(time) + mag_map = download_magnetogram(time, settings.data_path) if "hmi" in mag_map.detector.casefold(): size = CUTOUT else: @@ -143,7 +143,7 @@ def detect(time: datetime): List of detection results with bounding boxes and classifications """ # Download magnetogram - mag_map = download_magnetogram(time) + mag_map = download_magnetogram(time, settings.data_path) detections = yolo_detection(mag_map) return detections diff --git a/app/config.py b/app/config.py new file mode 100644 index 0000000..3437029 --- /dev/null +++ b/app/config.py @@ -0,0 +1,12 @@ +import os +from pathlib import Path + +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + data_path: Path = Path(os.getenv("DATAPATH", "/arccnet/data")) + model_path: Path = Path(os.getenv("MODESLPATH", "/arccnet/models")) + + +settings = Settings() diff --git a/app/fulldisk_model.py b/app/fulldisk_model.py index b0d604e..ff0da81 100644 --- a/app/fulldisk_model.py +++ b/app/fulldisk_model.py @@ -1,19 +1,14 @@ -import logging from pathlib import Path import astropy.units as u import numpy as np import torch from arccnet.visualisation import utils as ut_v -from astropy.coordinates import SkyCoord from PIL import Image -from sunpy.coordinates import SphericalScreen from ultralytics import YOLO -from app.model_utils import ( - download_and_extract_model, - logger, -) +from app.config import settings +from app.model_utils import download_and_extract_model, logger device = "cuda" if torch.cuda.is_available() else "cpu" @@ -24,7 +19,12 @@ def download_yolo_model(): """ model_url = "https://www.comet.com/api/registry/model/item/download?modelItemId=9iPvriGnFaFjE6dNzGYYYZoEG" - weights_path = download_and_extract_model(model_url, "yolo_detection", "best.pt") + weights_path = download_and_extract_model( + model_url, + "yolo_detection", + extracted_weights_filename="best.pt", + model_data_path=settings.model_path, + ) try: # Load YOLO model with downloaded weights diff --git 
diff --git a/app/hale_model.py b/app/hale_model.py
index 9a88e7f..120f8d7 100644
--- a/app/hale_model.py
+++ b/app/hale_model.py
@@ -1,16 +1,12 @@
 import numpy as np
 import timm
 import torch
-
-from app.model_utils import (
-    download_and_extract_model,
-    load_state_dict,
-    logger,
-    preprocess_data,
-    safe_inference,
-)
 from arccnet.models import train_utils as ut_t
 
+from app.config import settings
+from app.model_utils import (download_and_extract_model, load_state_dict,
+                             logger, preprocess_data, safe_inference)
+
 
 def download_model():
     """
@@ -23,7 +19,9 @@ def download_model():
     num_classes = 5  # qs, ia, a, b, bg
 
     model_url = "https://www.comet.com/api/registry/model/item/download?modelItemId=2Y3HZMoq3XXffVgzkL9wE9IZb"
-    weights_path = download_and_extract_model(model_url, model_name)
+    weights_path = download_and_extract_model(
+        model_url, model_name, model_data_path=settings.model_path
+    )
     state_dict = load_state_dict(weights_path)
 
     # Load and create model
diff --git a/app/main.py b/app/main.py
index 50fadce..370d2e2 100644
--- a/app/main.py
+++ b/app/main.py
@@ -43,13 +43,13 @@
         "name": "ARCAFF",
         "url": "http://www.arcaff.eu",
     },
-    openapi_tags=tags_metadata
+    openapi_tags=tags_metadata,
 )
-app.include_router(router, prefix='')
+app.include_router(router, prefix="")
 
 app.add_middleware(
     CORSMiddleware,
-    allow_origins='*',
+    allow_origins="*",
    allow_credentials=True,
     allow_methods=["get", "post", "options"],
     allow_headers=["*"],
@@ -62,4 +62,4 @@ async def add_process_time_header(request: Request, call_next):
     response = await call_next(request)
     process_time = time.time() - start_time
     response.headers["X-Process-Time"] = str(process_time)
-    return response
\ No newline at end of file
+    return response
diff --git a/app/mcintosh_model.py b/app/mcintosh_model.py
index 7daced6..c0f5c38 100644
--- a/app/mcintosh_model.py
+++ b/app/mcintosh_model.py
@@ -1,23 +1,14 @@
-import logging
-import zipfile
-from pathlib import Path
-
-import numpy as np
-import requests
 import torch
 import torch.nn as nn
-
-from app.mcintosh_encoders import c_classes, create_encoders, p_classes, z_classes
-from app.model_utils import (
-    download_and_extract_model,
-    load_state_dict,
-    logger,
-    preprocess_data,
-    safe_inference,
-)
 from arccnet.models import train_utils as ut_t
 from arccnet.models.cutouts.mcintosh.models import HierarchicalResNet
 
+from app.config import settings
+from app.mcintosh_encoders import (c_classes, create_encoders, p_classes,
+                                   z_classes)
+from app.model_utils import (download_and_extract_model, load_state_dict,
+                             logger, preprocess_data, safe_inference)
+
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model_name = "resnet18"
 num_classes_Z = 5  # A, B, C, H, LG
@@ -31,7 +22,11 @@ def download_model():
     """
 
     model_url = "https://www.comet.com/api/registry/model/item/download?modelItemId=ZkTcrrYWpJwlQ3Kmlp6GCJGiK"
-    weights_path = download_and_extract_model(model_url, f"{model_name}_mcintosh")
+    weights_path = download_and_extract_model(
+        model_url,
+        f"{model_name}_mcintosh",
+        model_data_path=settings.model_path,
+    )
     state_dict = load_state_dict(weights_path, device)
 
     # Load and create model
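The timing middleware in `app/main.py` stamps every response with an `X-Process-Time` header, so server-side handling time can be read back from any endpoint. A minimal sketch, assuming a running instance on localhost:8000:

```python
# Sketch only: X-Process-Time is set by add_process_time_header in app/main.py
# and holds the wall-clock handling time in seconds, as a string.
import requests

resp = requests.get(
    "http://localhost:8000/arcnet/full_disk_detection",
    params={"time": "2024-01-01T00:00:00"},  # illustrative date
)
print(resp.headers.get("X-Process-Time"))  # e.g. "1.2345"
```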
extracted_weights_filename="model-data/comet-torch-model.pth" + model_url, + model_name, + *, + extracted_weights_filename="model-data/comet-torch-model.pth", + model_data_path, ): """ Downloads and extracts a model archive. @@ -29,16 +31,18 @@ def download_and_extract_model( Name of the model for cache naming extracted_weights_filename : str Expected filename of the weights in the archive + model_data_path : Path + Path to the model data directory Returns ------- Path Path to the extracted weights file """ - CACHE_DIR = Path(".cache") - CACHE_DIR.mkdir(parents=True, exist_ok=True) + model_data = model_data_path + model_data.mkdir(parents=True, exist_ok=True) - model_cache_dir = CACHE_DIR / model_name + model_cache_dir = model_data / model_name model_cache_dir.mkdir(parents=True, exist_ok=True) archive_path = model_cache_dir / f"{model_name}_archive.zip" @@ -81,10 +85,7 @@ def download_and_extract_model( # If extracted file has different name, create expected path actual_extracted_path = model_cache_dir / target_file if actual_extracted_path != extracted_path: - # Copy/move to expected location - import shutil - - shutil.move(str(actual_extracted_path), str(extracted_path)) + actual_extracted_path.move(extracted_path) if not extracted_path.exists(): raise FileNotFoundError(f"Weights file not found at {extracted_path}") diff --git a/requirements.txt b/requirements.txt index 5ab272b..765d40f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,4 @@ -fastapi >= 0.104.1 +fastapi[all] >= 0.104.1 uvicorn >= 0.24.0 -sunpy[net,map] ultralytics -arccnet[models]@git+https://github.com/ARCAFF/ARCCnet.git \ No newline at end of file +arccnet[models]@git+https://github.com/ARCAFF/ARCCnet.git