Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions python/valis/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sdssdb.peewee.sdss5db import database as pdb
from sdssdb.sqlalchemy.sdss5db import database as sdb

from valis.utils.versions import get_software_tag

# To make Peewee async-compatible, we need to hack the peewee connection state
# See FastAPI/Peewee docs at https://fastapi.tiangolo.com/how-to/sql-databases-peewee/
Expand Down Expand Up @@ -69,18 +70,34 @@ def connect_db(db, orm: str = 'peewee'):

return db

async def get_pw_db(db_state=Depends(reset_db_state)):
""" Dependency to connect a database with peewee """
async def get_pw_db(db_state=Depends(reset_db_state), release: str | None = None):
""" Dependency to connect a database with peewee

from valis.main import settings
dependency inputs act as query parameters, release input comes from the
release qp dependency in routes/base.py
"""

from valis.main import settings
# connect to the db, yield None since we don't need the db in peewee
if settings.db_reset:
db = connect_db(pdb, orm='peewee')
else:
async with asyncio.Lock():
db = connect_db(pdb, orm='peewee')

# set the correct astra schema if needed
try:
vastra = get_software_tag(release, 'v_astra')
except AttributeError:
# case when release is None or invalid
# uses default set astra schema defined in sdssdb
pass
else:
# for dr19 or ipl3 set schema to 0.5.0 ; ipl4=0.8.0
vastra = "0.5.0" if vastra in ("0.5.0", "0.6.0") else vastra
schema = f"astra_{vastra.replace('.', '')}"
Comment on lines +97 to +98
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The version mapping logic vastra = \"0.5.0\" if vastra in (\"0.5.0\", \"0.6.0\") else vastra is duplicated in both get_pw_db and get_astra_target. Consider extracting this logic into a shared utility function to improve maintainability.

Copilot uses AI. Check for mistakes.
pdb.set_astra_schema(schema)

try:
yield db
finally:
Expand Down
6 changes: 3 additions & 3 deletions python/valis/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class AstraSource(PeeweeBase):
gaia_dr3_source_id: Optional[int] = None
tic_v8_id: Optional[int] = None
healpix: int = None
n_associated: int = None
n_associated: Optional[int] = None
n_neighborhood: int = None
sdss4_apogee_target1_flags: int = None
sdss4_apogee_target2_flags: int = None
Expand All @@ -170,8 +170,8 @@ class AstraSource(PeeweeBase):
n_apogee_visits: Optional[int] = None
l: float = None
b: float = None
ebv: float = None
e_ebv: float = None
ebv: Optional[float] = None
e_ebv: Optional[float] = None
gaia_v_rad: FloatNaN[float] = None
gaia_e_v_rad: FloatNaN[float] = None
g_mag: FloatNaN[float] = None
Expand Down
70 changes: 49 additions & 21 deletions python/valis/db/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

# all resuable queries go here

from contextlib import contextmanager
import itertools
import packaging
import uuid
from typing import Sequence, Union, Generator
from typing import Union, Generator
from enum import Enum

import astropy.units as u
Expand Down Expand Up @@ -73,6 +72,10 @@ def append_pipes(query: peewee.ModelSelect, table: str = 'stacked',
if table not in {'stacked', 'flat'}:
raise ValueError('table must be either "stacked" or "flat"')

# cannot create temp table if query is None
if query is None:
return query

# Run initial query as a temporary table.
temp = create_temporary_table(query, indices=['sdss_id'])

Expand Down Expand Up @@ -412,6 +415,7 @@ def get_targets_obs(release: str, obs: str, spectrograph: str) -> peewee.ModelSe

# test sdss ids
# 23326 - boss/astra
# 25739 in astra 0.8.0 but not 0.5.0 sources
# 3350466 - apogee/astra
# 54392544 - all true
# 10 - all false
Expand Down Expand Up @@ -464,44 +468,68 @@ def get_boss_target(sdss_id: int, release: str, fields: list = None,
return query


def get_apogee_target(sdss_id: int, release: str, fields: list = None):
""" temporary placeholder for apogee """
def get_apogee_target(sdss_id: int, release: str, fields: list = None) -> peewee.ModelSelect:
"""Get the Apogee target metadata for an sdss_id

Retrieves the apogee pipeline data from the apogee_drp.star table
for the given sdss_id and data release.

Parameters
----------
sdss_id : int
the input sdss_id
release : str
the data release to look up
fields : list, optional
a list of fields to retrieve from the database, by default None

Returns
-------
peewee.ModelSelect
the output query
"""
# get the relevant software tag
apred = get_software_tag(release, 'apred_vers')

# create apogee version conditions
if isinstance(apred, list):
vercond = apo.Star.apred_vers.in_(apred)
avsver = astra.ApogeeVisitSpectrum.apred.in_(apred)
else:
vercond = apo.Star.apred_vers == apred
avsver = astra.ApogeeVisitSpectrum.apred == apred

# check fields
fields = fields or [apo.Star]
if fields and isinstance(fields[0], str):
fields = (getattr(apo.Star, i) for i in fields)

# get the astra source for the sdss_id
s = get_astra_target(sdss_id, release)
if not s:
return

# get the astra apogee visit spectrum
a = s.first().apogee_visit_spectrum.where(avsver).first()
if not a:
return

# get the apogee star data
return apo.Star.select(*fields).where(apo.Star.pk == a.star_pk, vercond)
return apo.Star.select(*fields).where(apo.Star.sdss_id == sdss_id, vercond)

def get_astra_target(sdss_id: int, release: str, fields: list = None) -> peewee.ModelSelect:
"""Get the Astra target metadata for an sdss_id

Retrieves the astra source data from the astra.source table
for the given sdss_id and data release.

def get_astra_target(sdss_id: int, release: str, fields: list = None):
""" temporary placeholder for astra """
Parameters
----------
sdss_id : int
the input sdss_id
release : str
the data release to look up
fields : list, optional
a list of fields to retrieve from the database, by default None

Returns
-------
peewee.ModelSelect
the output query
"""
# check the astra version against the assigned schema
vastra = get_software_tag(release, 'v_astra')
if not vastra or vastra not in ("0.5.0", "0.6.0"):
print('astra only supports DR19 / IPL3 = version 0.5.0, 0.6.0')
vastra = "0.5.0" if vastra in ("0.5.0", "0.6.0") else vastra
if vastra is None or vastra.replace('.', '') not in astra.Source._meta.schema:
print(f"warning: astra version for current release {release} does not match assigned astra schema {astra.Source._meta.schema}")
return None

# check fields
Expand Down
Loading