Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,4 @@ $RECYCLE.BIN/
# End of https://www.toptal.com/developers/gitignore/api/python,macos,windows,linux

# pycharm
.idea
.idea
37 changes: 37 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
repos:
# This should be before any formatting hooks like isort
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.12.1"
hooks:
- id: ruff
args: ["--fix"]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/PyCQA/isort
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-ast
- id: check-case-conflict
- id: trailing-whitespace
- id: check-yaml
- id: debug-statements
- id: check-added-large-files
args: ["--enforce-all", "--maxkb=1054"]
exclude: ""
- id: end-of-file-fixer
- id: mixed-line-ending
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
hooks:
- id: codespell
additional_dependencies:
- tomli
args: ["--write-changes"]

ci:
autofix_prs: false
autoupdate_schedule: "quarterly"
6 changes: 3 additions & 3 deletions Dockerfile

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be one RUN block based on docker best practices.

Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
FROM python:3.11
FROM python:3.11-bookworm

WORKDIR /code

# Install essential system dependencies
RUN apt-get update && apt-get install -y \
libgl1-mesa-glx \
libglib2.0-0 \
git \
&& rm -rf /var/lib/apt/lists/*
git
RUN rm -rf /var/lib/apt/lists/*


# Set environment variables
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@

`docker run -p 8000:80 arcaff-api:0.1.0`

## Configuration


By default the model and data files will be store at `/arcnet/data` and `/arccnet/models`. These locations can be overridden by setting the `DATAPATH` and, or `MODESLPATH` environment variables. Additionally thes can be mount to host file system e.g. `-v /host/path:/arccnet/data`


# Test
Expand Down Expand Up @@ -114,4 +115,4 @@ curl -X 'POST' \
"mcintosh_class": "Axx"
}
]
```
```
36 changes: 24 additions & 12 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,42 @@
router = APIRouter()


@router.get("/arcnet/classify_cutout/", tags=['AR Cutout Classification'])
async def classify_cutout(classification_request: ARCutoutClassificationInput = Depends()) -> ARCutoutClassificationResult:
@router.get("/arcnet/classify_cutout/", tags=["AR Cutout Classification"])
async def classify_cutout(
classification_request: ARCutoutClassificationInput = Depends(),
) -> ARCutoutClassificationResult:
r"""
Classify a cutout generated from a magnetogram at the given date and location as URL parameters.
"""
classification = classify(time=classification_request.time, hgs_latitude=classification_request.hgs_latitude,
hgs_longitude=classification_request.hgs_longitude)
classification = classify(
time=classification_request.time,
hgs_latitude=classification_request.hgs_latitude,
hgs_longitude=classification_request.hgs_longitude,
)
classification_result = ARCutoutClassificationResult.model_validate(classification)
return classification_result


@router.post("/arcnet/classify_cutout/", tags=['AR Cutout Classification'])
async def classify_cutout(classification_request: ARCutoutClassificationInput) -> ARCutoutClassificationResult:
@router.post("/arcnet/classify_cutout/", tags=["AR Cutout Classification"])
async def classify_cutout(
classification_request: ARCutoutClassificationInput,
) -> ARCutoutClassificationResult:
r"""
Classify an AR cutout generated from a magnetogram at the given date and location as json data.
"""
classification = classify(time=classification_request.time, hgs_latitude=classification_request.hgs_latitude,
hgs_longitude=classification_request.hgs_longitude)
classification = classify(
time=classification_request.time,
hgs_latitude=classification_request.hgs_latitude,
hgs_longitude=classification_request.hgs_longitude,
)
classification_result = ARCutoutClassificationResult.model_validate(classification)
return classification_result


@router.get("/arcnet/full_disk_detection", tags=['Full disk AR Detection'])
async def full_disk_detection(detection_request: ARDetectionInput = Depends()) -> List[ARDetection]:
@router.get("/arcnet/full_disk_detection", tags=["Full disk AR Detection"])
async def full_disk_detection(
detection_request: ARDetectionInput = Depends(),
) -> List[ARDetection]:
r"""
Detect and classify all ARs in a magnetogram at the given date as a URL parameter.
"""
Expand All @@ -40,11 +52,11 @@ async def full_disk_detection(detection_request: ARDetectionInput = Depends()) -
return detection_result


@router.post("/arcnet/full_disk_detection", tags=['Full disk AR Detection'])
@router.post("/arcnet/full_disk_detection", tags=["Full disk AR Detection"])
async def full_disk_detection(detection_request: ARDetectionInput) -> List[ARDetection]:
r"""
Detect and classify all ARs in a magnetogram at the given date as json data.
"""
detections = detect(detection_request.time)
detection_result = [ARDetection.model_validate(d) for d in detections]
return detection_result
return detection_result
16 changes: 8 additions & 8 deletions app/classify.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from datetime import datetime, timedelta
from pathlib import Path

import astropy.units as u
import numpy as np
import torch
from astropy.coordinates import SkyCoord
from astropy.time import Time
from sunpy.map import Map
from sunpy.net import Fido
from sunpy.net import attrs as a

from app.config import settings
from app.fulldisk_model import yolo_detection
from app.hale_model import hale_classification
from app.mcintosh_encoders import decode_predicted_classes_to_original
Expand All @@ -18,13 +15,16 @@
CUTOUT = [800, 400] * u.pix


def download_magnetogram(time):
def download_magnetogram(time, download_path):
r"""
Download magnetogram and prepare for use

Parameters
----------
time : datetime.datetime
Time to search for closest magnetogram
download_path : Path
Path to datadownload directory

"""
query = Fido.search(
Expand All @@ -34,7 +34,7 @@ def download_magnetogram(time):
)
if not query:
raise MagNotFoundError()
files = Fido.fetch(query["vso"][0])
files = Fido.fetch(query["vso"][0], path=download_path)
if not files:
raise MagDownloadError()
mag_map = Map(files[0])
Expand All @@ -59,7 +59,7 @@ def classify(time: datetime, hgs_latitude: float, hgs_longitude: float):
Classification result

"""
mag_map = download_magnetogram(time)
mag_map = download_magnetogram(time, settings.data_path)
if "hmi" in mag_map.detector.casefold():
size = CUTOUT
else:
Expand Down Expand Up @@ -143,7 +143,7 @@ def detect(time: datetime):
List of detection results with bounding boxes and classifications
"""
# Download magnetogram
mag_map = download_magnetogram(time)
mag_map = download_magnetogram(time, settings.data_path)
detections = yolo_detection(mag_map)
return detections

Expand Down
12 changes: 12 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os
from pathlib import Path

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
data_path: Path = Path(os.getenv("DATAPATH", "/arccnet/data"))
model_path: Path = Path(os.getenv("MODESLPATH", "/arccnet/models"))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MODESLPATH is a typo and it should be MODELPATH?



settings = Settings()
16 changes: 8 additions & 8 deletions app/fulldisk_model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
import logging
from pathlib import Path

import astropy.units as u
import numpy as np
import torch
from arccnet.visualisation import utils as ut_v
from astropy.coordinates import SkyCoord
from PIL import Image
from sunpy.coordinates import SphericalScreen
from ultralytics import YOLO

from app.model_utils import (
download_and_extract_model,
logger,
)
from app.config import settings
from app.model_utils import download_and_extract_model, logger

device = "cuda" if torch.cuda.is_available() else "cpu"

Expand All @@ -24,7 +19,12 @@ def download_yolo_model():
"""
model_url = "https://www.comet.com/api/registry/model/item/download?modelItemId=9iPvriGnFaFjE6dNzGYYYZoEG"

weights_path = download_and_extract_model(model_url, "yolo_detection", "best.pt")
weights_path = download_and_extract_model(
model_url,
"yolo_detection",
extracted_weights_filename="best.pt",
model_data_path=settings.model_path,
)

try:
# Load YOLO model with downloaded weights
Expand Down
16 changes: 7 additions & 9 deletions app/hale_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import numpy as np
import timm
import torch

from app.model_utils import (
download_and_extract_model,
load_state_dict,
logger,
preprocess_data,
safe_inference,
)
from arccnet.models import train_utils as ut_t

from app.config import settings
from app.model_utils import (download_and_extract_model, load_state_dict,
logger, preprocess_data, safe_inference)


def download_model():
"""
Expand All @@ -23,7 +19,9 @@ def download_model():
num_classes = 5 # qs, ia, a, b, bg
model_url = "https://www.comet.com/api/registry/model/item/download?modelItemId=2Y3HZMoq3XXffVgzkL9wE9IZb"

weights_path = download_and_extract_model(model_url, model_name)
weights_path = download_and_extract_model(
model_url, model_name, model_data_path=settings.model_path
)
state_dict = load_state_dict(weights_path)

# Load and create model
Expand Down
8 changes: 4 additions & 4 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@
"name": "ARCAFF",
"url": "http://www.arcaff.eu",
},
openapi_tags=tags_metadata
openapi_tags=tags_metadata,
)

app.include_router(router, prefix='')
app.include_router(router, prefix="")
app.add_middleware(
CORSMiddleware,
allow_origins='*',
allow_origins="*",
allow_credentials=True,
allow_methods=["get", "post", "options"],
allow_headers=["*"],
Expand All @@ -62,4 +62,4 @@ async def add_process_time_header(request: Request, call_next):
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
return response
27 changes: 11 additions & 16 deletions app/mcintosh_model.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
import logging
import zipfile
from pathlib import Path

import numpy as np
import requests
import torch
import torch.nn as nn

from app.mcintosh_encoders import c_classes, create_encoders, p_classes, z_classes
from app.model_utils import (
download_and_extract_model,
load_state_dict,
logger,
preprocess_data,
safe_inference,
)
from arccnet.models import train_utils as ut_t
from arccnet.models.cutouts.mcintosh.models import HierarchicalResNet

from app.config import settings
from app.mcintosh_encoders import (c_classes, create_encoders, p_classes,
z_classes)
from app.model_utils import (download_and_extract_model, load_state_dict,
logger, preprocess_data, safe_inference)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "resnet18"
num_classes_Z = 5 # A, B, C, H, LG
Expand All @@ -31,7 +22,11 @@ def download_model():
"""
model_url = "https://www.comet.com/api/registry/model/item/download?modelItemId=ZkTcrrYWpJwlQ3Kmlp6GCJGiK"

weights_path = download_and_extract_model(model_url, f"{model_name}_mcintosh")
weights_path = download_and_extract_model(
model_url,
f"{model_name}_mcintosh",
model_data_path=settings.model_path,
)
state_dict = load_state_dict(weights_path, device)

# Load and create model
Expand Down
Loading