Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8, 3.9, '3.10', 3.11]
python-version: [3.11, 3.12, 3.13]

steps:
- uses: actions/checkout@v4
Expand Down
61 changes: 54 additions & 7 deletions astrodbkit/astrodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@

import json
import os
import sqlite3
import shutil
import sqlite3

import numpy as np
import pandas as pd
import sqlalchemy.types as sqlalchemy_types
import yaml
from astropy.coordinates import SkyCoord
from astropy.table import Table as AstropyTable
from astropy.units.quantity import Quantity
from sqlalchemy import Table, and_, create_engine, event, or_, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.orm.query import Query
from sqlalchemy.schema import CreateSchema
from tqdm import tqdm

from . import FOREIGN_KEY, PRIMARY_TABLE, PRIMARY_TABLE_KEY, REFERENCE_TABLES
Expand Down Expand Up @@ -166,7 +168,7 @@ def load_connection(connection_string, sqlite_foreign=True, base=None, connectio
session = Session()

# Enable foreign key checks in SQLite
if "sqlite" in connection_string and sqlite_foreign:
if connection_string.startswith("sqlite") and sqlite_foreign:
set_sqlite()
# elif 'postgresql' in connection_string:
# # Set up schema in postgres (must be lower case?)
Expand All @@ -189,23 +191,60 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
cursor.close()


def create_database(connection_string, drop_tables=False):
def create_database(connection_string, drop_tables=False, felis_schema=None):
"""
Create a database from a schema that utilizes the `astrodbkit2.astrodb.Base` class.
Some databases, eg Postgres, must already exist but any tables should be dropped.
The default behavior is to assume that a schema with SQLAlchemy definitions has been imported prior to calling this function.
If instead, Felis is being used to define the schema, the path to the YAML file needs to be provided to the felis_schema parameter (as a string).

Parameters
----------
connection_string : str
Connection string to database
drop_tables : bool
Flag to drop existing tables. This is needed when the schema changes. (Default: False)
felis_schema : str
Path to schema yaml file
"""

session, base, engine = load_connection(connection_string, base=Base)
if drop_tables:
base.metadata.drop_all()
base.metadata.create_all(engine) # this explicitly creates the database
if felis_schema is not None:
# Felis loader requires felis_schema
from felis.datamodel import Schema
from felis.metadata import MetaDataBuilder

# Load and validate the felis-formatted schema
data = yaml.safe_load(open(felis_schema, "r"))
schema = Schema.model_validate(data)
schema_name = data["name"] # get schema_name from the felis schema file

# engine = create_engine(connection_string)
session, base, engine = load_connection(connection_string)

# Schema handling for various database types
if connection_string.startswith("sqlite"):
db_name = connection_string.split("/")[-1]
with engine.begin() as conn:
conn.execute(text(f"ATTACH '{db_name}' AS {schema_name}"))
elif connection_string.startswith("postgres"):
with engine.connect() as connection:
connection.execute(CreateSchema(schema_name, if_not_exists=True))
connection.commit()

# Drop tables, if requested
if drop_tables:
base.metadata.drop_all()

# Create the database
metadata = MetaDataBuilder(schema).build()
metadata.create_all(bind=engine)
base.metadata = metadata
else:
session, base, engine = load_connection(connection_string, base=Base)
if drop_tables:
base.metadata.drop_all()
base.metadata.create_all(engine) # this explicitly creates the database

return session, base, engine


Expand Down Expand Up @@ -276,6 +315,7 @@ def __init__(
column_type_overrides={},
sqlite_foreign=True,
connection_arguments={},
schema=None,
):
"""
Wrapper for database calls and utility functions
Expand All @@ -301,8 +341,15 @@ def __init__(
Flag to enable/disable use of foreign keys with SQLite. Default: True
connection_arguments : dict
Additional connection arguments, like {'check_same_thread': False}. Default: {}
schema : str
Helper for setting default PostgreSQL schema. Equivalent to connection_arguments={"options": f"-csearch_path={schema}"}
"""

# Helper logic to set default postgres schema, if specified
if connection_string.startswith("postgres") and schema is not None:
if connection_arguments.get("options") is None:
connection_arguments["options"] = f"-csearch_path={schema}"

if connection_string == "sqlite://":
self.session, self.base, self.engine = create_database(connection_string)
else:
Expand Down
Loading
Loading