diff --git a/.github/workflows/su6.yml b/.github/workflows/su6.yml index 9dc78c8..8295571 100644 --- a/.github/workflows/su6.yml +++ b/.github/workflows/su6.yml @@ -11,7 +11,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - uses: yezz123/setup-uv@v4 with: uv-venv: ".venv" @@ -25,7 +25,8 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: - python-version: '3.13' + python-version: '3.14' + allow-prereleases: true - uses: yezz123/setup-uv@v4 with: uv-venv: ".venv" diff --git a/.readthedocs.yml b/.readthedocs.yml index f445b75..b57dff2 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -3,7 +3,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.10" + python: "3.11" mkdocs: configuration: mkdocs.yml diff --git a/pyproject.toml b/pyproject.toml index 680c5bc..acfc3ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,12 +2,15 @@ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + [project] name = "TypeDAL" dynamic = ["version"] description = 'Typing support for PyDAL' readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" license-expression = "MIT" keywords = [] authors = [ @@ -16,20 +19,21 @@ authors = [ classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ - "pydal <= 20250228.1", # core + "pydal >= 20251012.3", # core "dill < 1", # caching - "configuraptor >= 1.26.2, < 2", # config + "configuraptor >= 1.27.1, < 2", # config "Configurable-JSON < 2", # json dumping "python-slugify < 9", - "legacy-cgi; python_version >= '3.13'" + "legacy-cgi; python_version >= '3.13'", + "python-dateutil < 3", ] [project.optional-dependencies] @@ -38,7 +42,7 @@ py4web = [ ] migrations = [ - "typer", + "typer >=0.18, <0.19", "tabulate", "pydal2sql>=1.2.0", "edwh-migrate>=0.8.0", @@ -48,7 +52,7 @@ migrations = [ all = [ "py4web", - "typer", + "typer >=0.18, <0.19", "tabulate", "pydal2sql[all]>=1.2.0", "edwh-migrate[full]>=0.8.0", @@ -114,7 +118,7 @@ badge = true mypy = "--disable-error-code misc" [tool.black] -target-version = ["py310"] +target-version = ["py313"] line-length = 120 # 'extend-exclude' excludes files or directories in addition to the defaults extend-exclude = ''' @@ -130,6 +134,7 @@ extend-exclude = ''' [tool.coverage.report] exclude_also = [ "if TYPE_CHECKING:", + "if t.TYPE_CHECKING:", "if typing.TYPE_CHECKING:", "except ImportError as e:", "except ImportError:", @@ -156,7 +161,7 @@ strict = true exclude = ["venv", ".bak"] [tool.ruff] -target-version = "py310" +target-version = "py313" line-length = 120 extend-exclude = ["*.bak/", "venv*/"] diff --git a/src/typedal/__init__.py b/src/typedal/__init__.py index 80a8301..1e3960c 100644 --- a/src/typedal/__init__.py +++ b/src/typedal/__init__.py @@ -2,16 +2,15 @@ TypeDAL Library. """ -from . 
import fields -from .core import ( - Relationship, - TypeDAL, - TypedField, - TypedRows, - TypedTable, - relationship, -) +from .core import TypeDAL +from .fields import TypedField from .helpers import sql_expression +from .query_builder import QueryBuilder +from .relationships import Relationship, relationship +from .rows import TypedRows +from .tables import TypedTable + +from . import fields # isort: skip try: from .for_py4web import DAL as P4W_DAL @@ -19,6 +18,7 @@ P4W_DAL = None # type: ignore __all__ = [ + "QueryBuilder", "Relationship", "TypeDAL", "TypedField", diff --git a/src/typedal/caching.py b/src/typedal/caching.py index 62e93c6..4add6e0 100644 --- a/src/typedal/caching.py +++ b/src/typedal/caching.py @@ -3,27 +3,28 @@ """ import contextlib +import datetime as dt import hashlib import json -import typing -from datetime import datetime, timedelta, timezone -from typing import Any, Iterable, Mapping, Optional, TypeVar +import typing as t import dill # nosec from pydal.objects import Field, Rows, Set -from .core import TypedField, TypedRows, TypedTable +from .fields import TypedField +from .rows import TypedRows +from .tables import TypedTable from .types import Query -if typing.TYPE_CHECKING: +if t.TYPE_CHECKING: from .core import TypeDAL -def get_now(tz: timezone = timezone.utc) -> datetime: +def get_now(tz: dt.timezone = dt.timezone.utc) -> dt.datetime: """ Get the default datetime, optionally in a specific timezone. """ - return datetime.now(tz) + return dt.datetime.now(tz) class _TypedalCache(TypedTable): @@ -33,8 +34,8 @@ class _TypedalCache(TypedTable): key: TypedField[str] data: TypedField[bytes] - cached_at = TypedField(datetime, default=get_now) - expires_at: TypedField[datetime | None] + cached_at = TypedField(dt.datetime, default=get_now) + expires_at: TypedField[dt.datetime | None] class _TypedalCacheDependency(TypedTable): @@ -47,7 +48,7 @@ class _TypedalCacheDependency(TypedTable): idx: TypedField[int] -def prepare(field: Any) -> str: +def prepare(field: t.Any) -> str: """ Prepare data to be used in a cache key. @@ -56,10 +57,10 @@ def prepare(field: Any) -> str: """ if isinstance(field, str): return field - elif isinstance(field, (dict, Mapping)): + elif isinstance(field, (dict, t.Mapping)): data = {str(k): prepare(v) for k, v in field.items()} return json.dumps(data, sort_keys=True) - elif isinstance(field, Iterable): + elif isinstance(field, t.Iterable): return ",".join(sorted([prepare(_) for _ in field])) elif isinstance(field, bool): return str(int(field)) @@ -67,7 +68,7 @@ def prepare(field: Any) -> str: return str(field) -def create_cache_key(*fields: Any) -> str: +def create_cache_key(*fields: t.Any) -> str: """ Turn any fields of data into a string. """ @@ -83,7 +84,7 @@ def hash_cache_key(cache_key: str | bytes) -> str: return h.hexdigest() -def create_and_hash_cache_key(*fields: Any) -> tuple[str, str]: +def create_and_hash_cache_key(*fields: t.Any) -> tuple[str, str]: """ Combine the input fields into one key and hash it with SHA 256. 
""" @@ -112,7 +113,7 @@ def _get_dependency_ids(rows: Rows, dependency_keys: list[tuple[Field, str]]) -> return dependencies -def _determine_dependencies_auto(_: TypedRows[Any], rows: Rows) -> DependencyTupleSet: +def _determine_dependencies_auto(_: TypedRows[t.Any], rows: Rows) -> DependencyTupleSet: dependency_keys = [] for field in rows.fields: if str(field).endswith(".id"): @@ -123,7 +124,7 @@ def _determine_dependencies_auto(_: TypedRows[Any], rows: Rows) -> DependencyTup return _get_dependency_ids(rows, dependency_keys) -def _determine_dependencies(instance: TypedRows[Any], rows: Rows, depends_on: list[Any]) -> DependencyTupleSet: +def _determine_dependencies(instance: TypedRows[t.Any], rows: Rows, depends_on: list[t.Any]) -> DependencyTupleSet: if not depends_on: return _determine_dependencies_auto(instance, rows) @@ -144,11 +145,11 @@ def _determine_dependencies(instance: TypedRows[Any], rows: Rows, depends_on: li return _get_dependency_ids(rows, dependency_keys) -def remove_cache(idx: int | Iterable[int], table: str) -> None: +def remove_cache(idx: int | t.Iterable[int], table: str) -> None: """ Remove any cache entries that are dependant on one or multiple indices of a table. """ - if not isinstance(idx, Iterable): + if not isinstance(idx, t.Iterable): idx = [idx] related = ( @@ -184,12 +185,14 @@ def _remove_cache(s: Set, tablename: str) -> None: remove_cache(indeces, tablename) -T_TypedTable = TypeVar("T_TypedTable", bound=TypedTable) +T_TypedTable = t.TypeVar("T_TypedTable", bound=TypedTable) def get_expire( - expires_at: Optional[datetime] = None, ttl: Optional[int | timedelta] = None, now: Optional[datetime] = None -) -> datetime | None: + expires_at: t.Optional[dt.datetime] = None, + ttl: t.Optional[int | dt.timedelta] = None, + now: t.Optional[dt.datetime] = None, +) -> dt.datetime | None: """ Based on an expires_at date or a ttl (in seconds or a time delta), determine the expire date. """ @@ -197,10 +200,10 @@ def get_expire( if expires_at and ttl: raise ValueError("Please only supply an `expired at` date or a `ttl` in seconds!") - elif isinstance(ttl, timedelta): + elif isinstance(ttl, dt.timedelta): return now + ttl elif ttl: - return now + timedelta(seconds=ttl) + return now + dt.timedelta(seconds=ttl) elif expires_at: return expires_at @@ -210,8 +213,8 @@ def get_expire( def save_to_cache( instance: TypedRows[T_TypedTable], rows: Rows, - expires_at: Optional[datetime] = None, - ttl: Optional[int | timedelta] = None, + expires_at: t.Optional[dt.datetime] = None, + ttl: t.Optional[int | dt.timedelta] = None, ) -> TypedRows[T_TypedTable]: """ Save a typedrows result to the database, and save dependencies from rows. @@ -237,13 +240,13 @@ def save_to_cache( return instance -def _load_from_cache(key: str, db: "TypeDAL") -> Any | None: +def _load_from_cache(key: str, db: "TypeDAL") -> t.Any | None: if not (row := _TypedalCache.where(key=key).first()): return None now = get_now() - expires = row.expires_at.replace(tzinfo=timezone.utc) if row.expires_at else None + expires = row.expires_at.replace(tzinfo=dt.timezone.utc) if row.expires_at else None if expires and now >= expires: row.delete_record() @@ -261,7 +264,7 @@ def _load_from_cache(key: str, db: "TypeDAL") -> Any | None: return inst -def load_from_cache(key: str, db: "TypeDAL") -> Any | None: +def load_from_cache(key: str, db: "TypeDAL") -> t.Any | None: """ If 'key' matches a non-expired row in the database, try to load the dill. 
@@ -302,10 +305,10 @@ def _expired_and_valid_query() -> tuple[str, str]: return expired_items, valid_items -T = typing.TypeVar("T") -Stats = typing.TypedDict("Stats", {"total": T, "valid": T, "expired": T}) +T = t.TypeVar("T") +Stats = t.TypedDict("Stats", {"total": T, "valid": T, "expired": T}) -RowStats = typing.TypedDict( +RowStats = t.TypedDict( "RowStats", { "Dependent Cache Entries": int, @@ -338,7 +341,7 @@ def row_stats(db: "TypeDAL", table: str, row_id: str) -> Stats[RowStats]: } -TableStats = typing.TypedDict( +TableStats = t.TypedDict( "TableStats", { "Dependent Cache Entries": int, @@ -371,7 +374,7 @@ def table_stats(db: "TypeDAL", table: str) -> Stats[TableStats]: } -GenericStats = typing.TypedDict( +GenericStats = t.TypedDict( "GenericStats", { "entries": int, diff --git a/src/typedal/config.py b/src/typedal/config.py index bd68cc7..26bf2d8 100644 --- a/src/typedal/config.py +++ b/src/typedal/config.py @@ -4,11 +4,10 @@ import os import re -import typing +import typing as t import warnings from collections import defaultdict from pathlib import Path -from typing import Any, Optional import tomli from configuraptor import TypedConfig, alias @@ -17,7 +16,7 @@ from .types import AnyDict -if typing.TYPE_CHECKING: +if t.TYPE_CHECKING: from edwh_migrate import Config as MigrateConfig from pydal2sql.typer_support import Config as P2SConfig @@ -41,15 +40,15 @@ class TypeDALConfig(TypedConfig): output: str = "" noop: bool = False magic: bool = True - tables: Optional[list[str]] = None + tables: t.Optional[list[str]] = None function: str = "define_tables" # edwh-migrate: # migrate uri = database - database_to_restore: Optional[str] - migrate_cat_command: Optional[str] - schema_version: Optional[str] - redis_host: Optional[str] + database_to_restore: t.Optional[str] + migrate_cat_command: t.Optional[str] + schema_version: t.Optional[str] + redis_host: t.Optional[str] migrate_table: str = "typedal_implemented_features" flag_location: str create_flag_location: bool = True @@ -148,7 +147,7 @@ def _load_toml(path: str | bool | Path | None = True) -> tuple[str, AnyDict]: with open(toml_path, "rb") as f: data = tomli.load(f) - return str(toml_path) or "", typing.cast(AnyDict, data["tool"]["typedal"]) + return str(toml_path) or "", t.cast(AnyDict, data["tool"]["typedal"]) except Exception as e: warnings.warn(f"Could not load typedal config toml: {e}", source=e) return str(toml_path) or "", {} @@ -194,7 +193,7 @@ def get_db_for_alias(db_name: str) -> str: return DB_ALIASES.get(db_name, db_name) -DEFAULTS: dict[str, Any | typing.Callable[[AnyDict], Any]] = { +DEFAULTS: dict[str, t.Any | t.Callable[[AnyDict], t.Any]] = { "database": lambda data: data.get("db_uri") or "sqlite:memory", "dialect": lambda data: ( get_db_for_alias(data["database"].split(":")[0]) if ":" in data["database"] else data.get("db_type") @@ -208,7 +207,7 @@ def get_db_for_alias(db_name: str) -> str: } -def _fill_defaults(data: AnyDict, prop: str, fallback: Any = None) -> None: +def _fill_defaults(data: AnyDict, prop: str, fallback: t.Any = None) -> None: default = DEFAULTS.get(prop, fallback) if callable(default): default = default(data) @@ -223,7 +222,7 @@ def fill_defaults(data: AnyDict, prop: str) -> None: _fill_defaults(data, prop) -TRANSFORMS: dict[str, typing.Callable[[AnyDict], Any]] = { +TRANSFORMS: dict[str, t.Callable[[AnyDict], t.Any]] = { "database": lambda data: ( data["database"] if (":" in data["database"] or not data.get("dialect")) @@ -264,7 +263,7 @@ def expand_posix_vars(posix_expr: str, context: dict[str, 
str]) -> str: # Regular expression to match "${VAR:default}" pattern pattern = r"\$\{([^}]+)\}" - def replace_var(match: re.Match[Any]) -> str: + def replace_var(match: re.Match[t.Any]) -> str: var_with_default = match.group(1) var_name, default_value = var_with_default.split(":") if ":" in var_with_default else (var_with_default, "") return env.get(var_name.lower(), default_value) @@ -325,10 +324,10 @@ def expand_env_vars_into_toml_values(toml: AnyDict, env: AnyDict) -> None: def load_config( - connection_name: Optional[str] = None, + connection_name: t.Optional[str] = None, _use_pyproject: bool | str | None = True, _use_env: bool | str | None = True, - **fallback: Any, + **fallback: t.Any, ) -> TypeDALConfig: """ Combines multiple sources of config into one config instance. @@ -338,7 +337,7 @@ # combine and fill with fallback values # load typedal config or fail toml_path, toml = _load_toml(_use_pyproject) - dotenv_path, dotenv = _load_dotenv(_use_env) + _dotenv_path, dotenv = _load_dotenv(_use_env) expand_env_vars_into_toml_values(toml, dotenv) diff --git a/src/typedal/constants.py b/src/typedal/constants.py new file mode 100644 index 0000000..97791a0 --- /dev/null +++ b/src/typedal/constants.py @@ -0,0 +1,25 @@ +""" +Constant values. +""" + +import datetime as dt +import typing as t +from decimal import Decimal + +from .types import T_annotation + +JOIN_OPTIONS = t.Literal["left", "inner", None] +DEFAULT_JOIN_OPTION: JOIN_OPTIONS = "left" + +BASIC_MAPPINGS: dict[T_annotation, str] = { + str: "string", + int: "integer", + bool: "boolean", + bytes: "blob", + float: "double", + object: "json", + Decimal: "decimal(10,2)", + dt.date: "date", + dt.time: "time", + dt.datetime: "datetime", +} diff --git a/src/typedal/core.py b/src/typedal/core.py index f0865f2..6f67653 100644 --- a/src/typedal/core.py +++ b/src/typedal/core.py @@ -4,387 +4,138 @@ from __future__ import annotations -import contextlib -import copy -import csv -import datetime as dt -import functools -import inspect -import json -import math -import re import sys -import types -import typing +import typing as t -import uuid import warnings -from collections import defaultdict -from decimal import Decimal from pathlib import Path -from typing import Any, Optional, Type +from typing import Optional import pydal -from pydal._globals import DEFAULT - -# from pydal.objects import Field as _Field -# from pydal.objects import Query as _Query -from pydal.objects import Row - -# from pydal.objects import Table as _Table -from typing_extensions import Self, Unpack from .config import TypeDALConfig, load_config from .helpers import ( - DummyQuery, - all_annotations, - all_dict, - as_lambda, - classproperty, - extract_type_optional, - filter_out, - instanciate, - is_union, - looks_like, - mktable, - origin_is_subclass, + SYSTEM_SUPPORTS_TEMPLATES, + default_representer, + sql_escape_template, sql_expression, to_snake, - unwrap_type, ) -from .serializers import as_json -from .types import ( - AnyDict, - CacheMetadata, - Expression, - Field, - FieldSettings, - Metadata, - OpRow, - OrderBy, - PaginateDict, - Pagination, - Query, - Reference, - Rows, - SelectKwargs, - Set, - Table, - Validator, - _Types, -) - -# use typing.cast(type, ...)
to make mypy happy with unions -T_annotation = Type[Any] | types.UnionType -T_Query = typing.Union["Table", Query, bool, None, "TypedTable", Type["TypedTable"], Expression] -T_Value = typing.TypeVar("T_Value") # actual type of the Field (via Generic) -T_MetaInstance = typing.TypeVar("T_MetaInstance", bound="TypedTable") # bound="TypedTable"; bound="TableMeta" -T = typing.TypeVar("T") - -BASIC_MAPPINGS: dict[T_annotation, str] = { - str: "string", - int: "integer", - bool: "boolean", - bytes: "blob", - float: "double", - object: "json", - Decimal: "decimal(10,2)", - dt.date: "date", - dt.time: "time", - dt.datetime: "datetime", -} - - -def is_typed_field(cls: Any) -> typing.TypeGuard["TypedField[Any]"]: - """ - Is `cls` an instance or subclass of TypedField? - - Deprecated - """ - return isinstance(cls, TypedField) or ( - isinstance(typing.get_origin(cls), type) and issubclass(typing.get_origin(cls), TypedField) - ) +from .types import Field, T, Template # type: ignore +try: + # python 3.14+ + from annotationlib import ForwardRef +except ImportError: # pragma: no cover + # python 3.13- + from typing import ForwardRef -JOIN_OPTIONS = typing.Literal["left", "inner", None] -DEFAULT_JOIN_OPTION: JOIN_OPTIONS = "left" +if t.TYPE_CHECKING: + from .fields import TypedField + from .types import AnyDict, Expression, T_Query, Table -# table-ish paramter: -P_Table = typing.Union[Type["TypedTable"], pydal.objects.Table] -Condition: typing.TypeAlias = typing.Optional[ - typing.Callable[ - # self, other -> Query - [P_Table, P_Table], - Query | bool, - ] -] +# note: these functions can not be moved to a different file, +# because then they will have different globals and it breaks! -OnQuery: typing.TypeAlias = typing.Optional[ - typing.Callable[ - # self, other -> list of .on statements - [P_Table, P_Table], - list[Expression], - ] -] -# To_Type = typing.TypeVar("To_Type", type[Any], Type[Any], str) -To_Type = typing.TypeVar("To_Type") - - -class Relationship(typing.Generic[To_Type]): +def evaluate_forward_reference_312(fw_ref: ForwardRef, namespace: dict[str, type]) -> type: # pragma: no cover """ - Define a relationship to another table. - """ - - _type: Type[To_Type] - table: Type["TypedTable"] | type | str - condition: Condition - condition_and: Condition - on: OnQuery - multiple: bool - join: JOIN_OPTIONS - - def __init__( - self, - _type: Type[To_Type], - condition: Condition = None, - join: JOIN_OPTIONS = None, - on: OnQuery = None, - condition_and: Condition = None, - ): - """ - Should not be called directly, use relationship() instead! - """ - if condition and on: - warnings.warn(f"Relation | Both specified! {condition=} {on=} {_type=}") - raise ValueError("Please specify either a condition or an 'on' statement for this relationship!") - - self._type = _type - self.condition = condition - self.join = "left" if on else join # .on is always left join! - self.on = on - self.condition_and = condition_and - - if args := typing.get_args(_type): - self.table = unwrap_type(args[0]) - self.multiple = True - else: - self.table = typing.cast(type[TypedTable], _type) - self.multiple = False - - if isinstance(self.table, str): - self.table = TypeDAL.to_snake(self.table) - - def clone(self, **update: Any) -> "Relationship[To_Type]": - """ - Create a copy of the relationship, possibly updated. 
- """ - return self.__class__( - update.get("_type") or self._type, - update.get("condition") or self.condition, - update.get("join") or self.join, - update.get("on") or self.on, - update.get("condition_and") or self.condition_and, - ) - - def __repr__(self) -> str: - """ - Representation of the relationship. - """ - if callback := self.condition or self.on: - src_code = inspect.getsource(callback).strip() - - if c_and := self.condition_and: - and_code = inspect.getsource(c_and).strip() - src_code += " AND " + and_code - else: - cls_name = self._type if isinstance(self._type, str) else self._type.__name__ - src_code = f"to {cls_name} (missing condition)" - - join = f":{self.join}" if self.join else "" - return f"" - - def get_table(self, db: "TypeDAL") -> Type["TypedTable"]: - """ - Get the table this relationship is bound to. - """ - table = self.table # can be a string because db wasn't available yet - - if isinstance(table, str): - if mapped := db._class_map.get(table): - # yay - return mapped - - # boo, fall back to untyped table but pretend it is typed: - return typing.cast(Type["TypedTable"], db[table]) # eh close enough! - - return table - - def get_table_name(self) -> str: - """ - Get the name of the table this relationship is bound to. - """ - if isinstance(self.table, str): - return self.table - - if isinstance(self.table, pydal.objects.Table): - return str(self.table) - - # else: typed table - try: - table = self.table._ensure_table_defined() if issubclass(self.table, TypedTable) else self.table - except Exception: # pragma: no cover - table = self.table - - return str(table) - - def __get__(self, instance: Any, owner: Any) -> "typing.Optional[list[Any]] | Relationship[To_Type]": - """ - Relationship is a descriptor class, which can be returned from a class but not an instance. - - For an instance, using .join() will replace the Relationship with the actual data. - If you forgot to join, a warning will be shown and empty data will be returned. - """ - if not instance: - # relationship queried on class, that's allowed - return self + Extract the original type from a forward reference string. - warnings.warn( - "Trying to get data from a relationship object! Did you forget to join it?", - category=RuntimeWarning, - ) - if self.multiple: - return [] - else: - return None - - -def relationship( - _type: typing.Type[To_Type], - condition: Condition = None, - join: JOIN_OPTIONS = None, - on: OnQuery = None, -) -> To_Type: + Variant for python 3.12 and below """ - Define a relationship to another table, when its id is not stored in the current table. - - Example: - class User(TypedTable): - name: str - - posts = relationship(list["Post"], condition=lambda self, post: self.id == post.author, join='left') - - class Post(TypedTable): - title: str - author: User - - User.join("posts").first() # User instance with list[Post] in .posts - - Here, Post stores the User ID, but `relationship(list["Post"])` still allows you to get the user's posts. - In this case, the join strategy is set to LEFT so users without posts are also still selected. + return t.cast( + type, + fw_ref._evaluate( + localns=locals(), + globalns=globals() | namespace, + recursive_guard=frozenset(), + ), + ) - For complex queries with a pivot table, a `on` can be set insteaad of `condition`: - class User(TypedTable): - ... 
- tags = relationship(list["Tag"], on=lambda self, tag: [ - Tagged.on(Tagged.entity == entity.gid), - Tag.on((Tagged.tag == tag.id)), - ]) +def evaluate_forward_reference_313(fw_ref: ForwardRef, namespace: dict[str, type]) -> type: # pragma: no cover + """ + Extract the original type from a forward reference string. - If you'd try to capture this in a single 'condition', pydal would create a cross join which is much less efficient. + Variant for python 3.13 """ - return typing.cast( - # note: The descriptor `Relationship[To_Type]` is more correct, but pycharm doesn't really get that. - # so for ease of use, just cast to the refered type for now! - # e.g. x = relationship(Author) -> x: Author - To_Type, - Relationship(_type, condition, join, on), + return t.cast( + type, + fw_ref._evaluate( + localns=locals(), + globalns=globals() | namespace, + recursive_guard=frozenset(), + type_params=(), # suggested since 3.13 (warning) and not supported before. Mandatory after 3.15! + ), ) -T_Field: typing.TypeAlias = typing.Union["TypedField[Any]", "Table", Type["TypedTable"]] +def evaluate_forward_reference_314(fw_ref: ForwardRef, namespace: dict[str, type]) -> type: # pragma: no cover + """ + Extract the original type from a forward reference string. + Variant for python 3.14 (and hopefully above) + """ + return t.cast( + type, + fw_ref.evaluate( + locals=locals(), + globals=globals() | namespace, + type_params=(), + ), ) -def _generate_relationship_condition(_: Type["TypedTable"], key: str, field: T_Field) -> Condition: - origin = typing.get_origin(field) +def evaluate_forward_reference( + fw_ref: ForwardRef, + namespace: dict[str, type] | None = None, +) -> type: # pragma: no cover + """ + Extract the original type from a forward reference string. - # else: generic - if origin is list: - # field = typing.get_args(field)[0] # actual field - # return lambda _self, _other: cls[key].contains(field) + Automatically chooses strategy based on current Python version. + """ + if sys.version_info.minor < 13: + return evaluate_forward_reference_312(fw_ref, namespace=namespace or {}) - return lambda _self, _other: _self[key].contains(_other.id) + elif sys.version_info.minor == 13: + return evaluate_forward_reference_313(fw_ref, namespace=namespace or {}) else: - # normal reference - # return lambda _self, _other: cls[key] == field.id - return lambda _self, _other: _self[key] == _other.id + return evaluate_forward_reference_314(fw_ref, namespace=namespace or {}) -def to_relationship( - cls: Type["TypedTable"] | type[Any], - key: str, - field: T_Field, -) -> typing.Optional[Relationship[Any]]: +def resolve_annotation_313(ftype: str) -> type: # pragma: no cover """ - Used to automatically create relationship instance for reference fields. - - Example: - class MyTable(TypedTable): - reference: OtherTable - - `reference` contains the id of an Other Table row. - MyTable.relationships should have 'reference' as a relationship, so `MyTable.join('reference')` should work. + Resolve an annotation that's in string representation. - This function will automatically perform this logic (called in db.define): - to_relationship(MyTable, 'reference', OtherTable) -> Relationship[OtherTable] - - Also works for list:reference (list[OtherTable]) and TypedField[OtherTable].
+ Variant for Python 3.13 """ - if looks_like(field, TypedField): - # typing.get_args works for list[str] but not for TypedField[role] :( - if args := typing.get_args(field): - # TypedField[SomeType] -> SomeType - field = args[0] - elif hasattr(field, "_type"): - # TypedField(SomeType) -> SomeType - field = typing.cast(T_Field, field._type) - else: # pragma: no cover - # weird - return None - - field, optional = extract_type_optional(field) + fw_ref: ForwardRef = t.get_args(t.Type[ftype])[0] + return evaluate_forward_reference(fw_ref) - try: - condition = _generate_relationship_condition(cls, key, field) - except Exception as e: # pragma: no cover - warnings.warn("Could not generate Relationship condition", source=e) - condition = None - if not condition: # pragma: no cover - # something went wrong, not a valid relationship - warnings.warn(f"Invalid relationship for {cls.__name__}.{key}: {field}") - return None - - join = "left" if optional or typing.get_origin(field) is list else "inner" +def resolve_annotation_314(ftype: str) -> type: # pragma: no cover + """ + Resolve an annotation that's in string representation. - return Relationship(typing.cast(type[TypedTable], field), condition, typing.cast(JOIN_OPTIONS, join)) + Variant for Python 3.14 + using annotationlib + """ + fw_ref = ForwardRef(ftype) + return evaluate_forward_reference(fw_ref) -def evaluate_forward_reference(fw_ref: typing.ForwardRef) -> type: - """ - Extract the original type from a forward reference string. +def resolve_annotation(ftype: str) -> type: # pragma: no cover """ - kwargs = dict( - localns=locals(), - globalns=globals(), - recursive_guard=frozenset(), - ) - if sys.version_info >= (3, 13): # pragma: no cover - # suggested since 3.13 (warning) and not supported before. Mandatory after 1.15! - kwargs["type_params"] = () + Resolve an annotation that's in string representation. - return fw_ref._evaluate(**kwargs) # type: ignore + Automatically chooses strategy based on current Python version. + """ + if sys.version_info.major != 3: + raise EnvironmentError("Only python 3 is supported.") + elif sys.version_info.minor <= 13: + return resolve_annotation_313(ftype) + else: + return resolve_annotation_314(ftype) class TypeDAL(pydal.DAL): # type: ignore @@ -393,6 +144,7 @@ class TypeDAL(pydal.DAL): # type: ignore """ _config: TypeDALConfig + _builder: TableDefinitionBuilder def __init__( self, @@ -414,7 +166,7 @@ def __init__( debug: bool = False, lazy_tables: bool = False, db_uid: Optional[str] = None, - after_connection: typing.Callable[..., Any] = None, + after_connection: t.Callable[..., t.Any] = None, tables: Optional[list[str]] = None, ignore_field_case: bool = True, entity_quoting: bool = True, @@ -443,6 +195,7 @@ def __init__( self._config = config self.db = self + self._builder = TableDefinitionBuilder(self) if config.folder: Path(config.folder).mkdir(exist_ok=True) @@ -477,7 +230,7 @@ def __init__( self.try_define(_TypedalCache) self.try_define(_TypedalCacheDependency) - def try_define(self, model: Type[T], verbose: bool = False) -> Type[T]: + def try_define(self, model: t.Type[T], verbose: bool = False) -> t.Type[T]: """ Try to define a model with migrate or fall back to fake migrate. 
""" @@ -495,125 +248,13 @@ def try_define(self, model: Type[T], verbose: bool = False) -> Type[T]: # try again: return self.define(model, migrate=True, fake_migrate=True, redefine=True) - default_kwargs: typing.ClassVar[AnyDict] = { + default_kwargs: t.ClassVar[AnyDict] = { # fields are 'required' (notnull) by default: "notnull": True, } - # maps table name to typedal class, for resolving future references - _class_map: typing.ClassVar[dict[str, Type["TypedTable"]]] = {} - - def _define(self, cls: Type[T], **kwargs: Any) -> Type[T]: - # todo: new relationship item added should also invalidate (previously unrelated) cache result - - # todo: option to enable/disable cache dependency behavior: - # - don't set _before_update and _before_delete - # - don't add TypedalCacheDependency entry - # - don't invalidate other item on new row of this type - - # when __future__.annotations is implemented, cls.__annotations__ will not work anymore as below. - # proper way to handle this would be (but gives error right now due to Table implementing magic methods): - # typing.get_type_hints(cls, globalns=None, localns=None) - # -> ERR e.g. `pytest -svxk cli` -> name 'BestFriend' is not defined - - # dirty way (with evil eval): - # [eval(v) for k, v in cls.__annotations__.items()] - # this however also stops working when variables outside this scope or even references to other - # objects are used. So for now, this package will NOT work when from __future__ import annotations is used, - # and might break in the future, when this annotations behavior is enabled by default. - - # non-annotated variables have to be passed to define_table as kwargs - full_dict = all_dict(cls) # includes properties from parents (e.g. useful for mixins) - - tablename = self.to_snake(cls.__name__) - # grab annotations of cls and it's parents: - annotations = all_annotations(cls) - # extend with `prop = TypedField()` 'annotations': - annotations |= {k: typing.cast(type, v) for k, v in full_dict.items() if is_typed_field(v)} - # remove internal stuff: - annotations = {k: v for k, v in annotations.items() if not k.startswith("_")} - - typedfields: dict[str, TypedField[Any]] = { - k: instanciate(v, True) for k, v in annotations.items() if is_typed_field(v) - } - - relationships: dict[str, type[Relationship[Any]]] = filter_out(annotations, Relationship) - - fields = {fname: self._to_field(fname, ftype) for fname, ftype in annotations.items()} - - # ! dont' use full_dict here: - other_kwargs = kwargs | { - k: v for k, v in cls.__dict__.items() if k not in annotations and not k.startswith("_") - } # other_kwargs was previously used to pass kwargs to typedal, but use @define(**kwargs) for that. - # now it's only used to extract relationships from the object. - # other properties of the class (incl methods) should not be touched - - # for key in typedfields.keys() - full_dict.keys(): - # # typed fields that don't haven't been added to the object yet - # setattr(cls, key, typedfields[key]) - - for key, field in typedfields.items(): - # clone every property so it can be re-used across mixins: - clone = copy.copy(field) - setattr(cls, key, clone) - typedfields[key] = clone - - # start with base classes and overwrite with current class: - relationships = filter_out(full_dict, Relationship) | relationships | filter_out(other_kwargs, Relationship) - - # DEPRECATED: Relationship as annotation is currently not supported! 
- # ensure they are all instances and - # not mix of instances (`= relationship()`) and classes (`: Relationship[...]`): - # relationships = { - # k: v if isinstance(v, Relationship) else to_relationship(cls, k, v) for k, v in relationships.items() - # } - - # keys of implicit references (also relationships): - reference_field_keys = [ - k for k, v in fields.items() if str(v.type).split(" ")[0] in ("list:reference", "reference") - ] - - # add implicit relationships: - # User; list[User]; TypedField[User]; TypedField[list[User]]; TypedField(User); TypedField(list[User]) - relationships |= { - k: new_relationship - for k in reference_field_keys - if k not in relationships and (new_relationship := to_relationship(cls, k, annotations[k])) - } - - # fixme: list[Reference] is recognized as relationship, - # TypedField(list[Reference]) is NOT recognized!!! - - cache_dependency = self._config.caching and kwargs.pop("cache_dependency", True) - - table: Table = self.define_table(tablename, *fields.values(), **kwargs) - - for name, typed_field in typedfields.items(): - field = fields[name] - typed_field.bind(field, table) - - if issubclass(cls, TypedTable): - cls.__set_internals__( - db=self, - table=table, - # by now, all relationships should be instances! - relationships=typing.cast(dict[str, Relationship[Any]], relationships), - ) - # map both name and rname: - self._class_map[str(table)] = cls - self._class_map[table._rname] = cls - cls.__on_define__(self) - else: - warnings.warn("db.define used without inheriting TypedTable. This could lead to strange problems!") - - if not tablename.startswith("typedal_") and cache_dependency: - table._before_update.append(lambda s, _: _remove_cache(s, tablename)) - table._before_delete.append(lambda s: _remove_cache(s, tablename)) - - return cls - - @typing.overload - def define(self, maybe_cls: None = None, **kwargs: Any) -> typing.Callable[[Type[T]], Type[T]]: + @t.overload + def define(self, maybe_cls: None = None, **kwargs: t.Any) -> t.Callable[[t.Type[T]], t.Type[T]]: """ Typing Overload for define without a class. @@ -621,8 +262,8 @@ def define(self, maybe_cls: None = None, **kwargs: Any) -> typing.Callable[[Type class MyTable(TypedTable): ... """ - @typing.overload - def define(self, maybe_cls: Type[T], **kwargs: Any) -> Type[T]: + @t.overload + def define(self, maybe_cls: t.Type[T], **kwargs: t.Any) -> t.Type[T]: """ Typing Overload for define with a class. @@ -630,7 +271,11 @@ def define(self, maybe_cls: Type[T], **kwargs: Any) -> Type[T]: class MyTable(TypedTable): ... """ - def define(self, maybe_cls: Type[T] | None = None, **kwargs: Any) -> Type[T] | typing.Callable[[Type[T]], Type[T]]: + def define( + self, + maybe_cls: t.Type[T] | None = None, + **kwargs: t.Any, + ) -> t.Type[T] | t.Callable[[t.Type[T]], t.Type[T]]: """ Can be used as a decorator on a class that inherits `TypedTable`, \ or as a regular method if you need to define your classes before you have access to a 'db' instance. @@ -653,39 +298,15 @@ class Article(TypedTable): the result of pydal.define_table """ - def wrapper(cls: Type[T]) -> Type[T]: - return self._define(cls, **kwargs) + def wrapper(cls: t.Type[T]) -> t.Type[T]: + return self._builder.define(cls, **kwargs) if maybe_cls: return wrapper(maybe_cls) return wrapper - # def drop(self, table_name: str) -> None: - # """ - # Remove a table by name (both on the database level and the typedal level). 
- # """ - # # drop calls TypedTable.drop() and removes it from the `_class_map` - # if cls := self._class_map.pop(table_name, None): - # cls.drop() - - # def drop_all(self, max_retries: int = None) -> None: - # """ - # Remove all tables and keep doing so until everything is gone! - # """ - # retries = 0 - # if max_retries is None: - # max_retries = len(self.tables) - # - # while self.tables: - # retries += 1 - # for table in self.tables: - # self.drop(table) - # - # if retries > max_retries: - # raise RuntimeError("Could not delete all tables") - - def __call__(self, *_args: T_Query, **kwargs: Any) -> "TypedSet": + def __call__(self, *_args: T_Query, **kwargs: t.Any) -> "TypedSet": """ A db instance can be called directly to perform a query. @@ -703,11 +324,11 @@ def __call__(self, *_args: T_Query, **kwargs: Any) -> "TypedSet": if isinstance(cls, type) and issubclass(type(cls), type) and issubclass(cls, TypedTable): # table defined without @db.define decorator! - _cls: Type[TypedTable] = cls + _cls: t.Type[TypedTable] = cls args[0] = _cls.id != None _set = super().__call__(*args, **kwargs) - return typing.cast(TypedSet, _set) + return t.cast(TypedSet, _set) def __getitem__(self, key: str) -> "Table": """ @@ -718,9 +339,9 @@ def __getitem__(self, key: str) -> "Table": Example: db['users'] -> user """ - return typing.cast(Table, super().__getitem__(str(key))) + return t.cast(Table, super().__getitem__(str(key))) - def find_model(self, table_name: str) -> Type["TypedTable"] | None: + def find_model(self, table_name: str) -> t.Type["TypedTable"] | None: """ Retrieves a mapped table class by its name. @@ -735,96 +356,12 @@ def find_model(self, table_name: str) -> Type["TypedTable"] | None: Returns: The mapped table class if it exists, otherwise None. """ - return self._class_map.get(table_name, None) - - @classmethod - def _build_field(cls, name: str, _type: str, **kw: Any) -> Field: - # return Field(name, _type, **{**cls.default_kwargs, **kw}) - kw_combined = cls.default_kwargs | kw - return Field(name, _type, **kw_combined) - - @classmethod - def _annotation_to_pydal_fieldtype( - cls, - _ftype: T_annotation, - mut_kw: typing.MutableMapping[str, Any], - ) -> Optional[str]: - # ftype can be a union or type. typing.cast is sometimes used to tell mypy when it's not a union. - ftype = typing.cast(type, _ftype) # cast from Type to type to make mypy happy) - - if isinstance(ftype, str): - # extract type from string - fw_ref: typing.ForwardRef = typing.get_args(Type[ftype])[0] - ftype = evaluate_forward_reference(fw_ref) - - if mapping := BASIC_MAPPINGS.get(ftype): - # basi types - return mapping - elif isinstance(ftype, pydal.objects.Table): - # db.table - return f"reference {ftype._tablename}" - elif issubclass(type(ftype), type) and issubclass(ftype, TypedTable): - # SomeTable - snakename = cls.to_snake(ftype.__name__) - return f"reference {snakename}" - elif isinstance(ftype, TypedField): - # FieldType(type, ...) 
- return ftype._to_field(mut_kw) - elif origin_is_subclass(ftype, TypedField): - # TypedField[int] - return cls._annotation_to_pydal_fieldtype(typing.get_args(ftype)[0], mut_kw) - elif isinstance(ftype, types.GenericAlias) and typing.get_origin(ftype) in (list, TypedField): - # list[str] -> str -> string -> list:string - _child_type = typing.get_args(ftype)[0] - _child_type = cls._annotation_to_pydal_fieldtype(_child_type, mut_kw) - return f"list:{_child_type}" - elif is_union(ftype): - # str | int -> UnionType - # typing.Union[str | int] -> typing._UnionGenericAlias - - # Optional[type] == type | None - - match typing.get_args(ftype): - case (_child_type, _Types.NONETYPE) | (_Types.NONETYPE, _child_type): - # good union of Nullable - - # if a field is optional, it is nullable: - mut_kw["notnull"] = False - return cls._annotation_to_pydal_fieldtype(_child_type, mut_kw) - case _: - # two types is not supported by the db! - return None - else: - return None - - @classmethod - def _to_field(cls, fname: str, ftype: type, **kw: Any) -> Field: - """ - Convert a annotation into a pydal Field. - - Args: - fname: name of the property - ftype: annotation of the property - kw: when using TypedField or a function returning it (e.g. StringField), - keyword args can be used to pass any other settings you would normally to a pydal Field - - -> pydal.Field(fname, ftype, **kw) - - Example: - class MyTable: - fname: ftype - id: int - name: str - reference: Table - other: TypedField(str, default="John Doe") # default will be in kwargs - """ - fname = cls.to_snake(fname) + return self._builder.class_map.get(table_name, None) - # note: 'kw' is updated in `_annotation_to_pydal_fieldtype` by the kwargs provided to the TypedField(...) - if converted_type := cls._annotation_to_pydal_fieldtype(ftype, kw): - return cls._build_field(fname, converted_type, **kw) - else: - raise NotImplementedError(f"Unsupported type {ftype}/{type(ftype)}") + @property + def _class_map(self) -> dict[str, t.Type["TypedTable"]]: + # alias for backward-compatibility + return self._builder.class_map @staticmethod def to_snake(camel: str) -> str: @@ -833,18 +370,71 @@ def to_snake(camel: str) -> str: """ return to_snake(camel) + def executesql( + self, + query: str | Template, + placeholders: t.Iterable[str] | dict[str, str] | None = None, + as_dict: bool = False, + fields: t.Iterable[Field | TypedField[t.Any]] | None = None, + colnames: t.Iterable[str] | None = None, + as_ordered_dict: bool = False, + ) -> list[t.Any]: + """ + Executes a raw SQL statement or a TypeDAL template query. + + If `query` is provided as a `Template` and the system supports template + rendering, it will be processed with `sql_escape_template` before being + executed. Otherwise, the query is passed to the underlying DAL as-is. + + Args: + query (str | Template): The SQL query to execute, either a plain + string or a `Template` (created via the `t""` syntax). + placeholders (Iterable[str] | dict[str, str] | None, optional): + Parameters to substitute into the SQL statement. Can be a sequence + (for positional parameters) or a dictionary (for named parameters). + Usually not applicable when using a t-string, since template + expressions handle interpolation directly. + as_dict (bool, optional): If True, return rows as dictionaries keyed by + column name. Defaults to False. + fields (Iterable[Field | TypedField] | None, optional): Explicit set of + fields to map results onto. Defaults to None. 
+ colnames (Iterable[str] | None, optional): Explicit column names to use + in the result set. Defaults to None. + as_ordered_dict (bool, optional): If True, return rows as `OrderedDict`s + preserving column order. Defaults to False. + + Returns: + list[t.Any]: The query result set. Typically a list of tuples if + `as_dict` and `as_ordered_dict` are False, or a list of dict-like + objects if those flags are enabled. + """ + if SYSTEM_SUPPORTS_TEMPLATES and isinstance(query, Template): # pragma: no cover + query = sql_escape_template(self, query) + + rows: list[t.Any] = super().executesql( + query, + placeholders=placeholders, + as_dict=as_dict, + fields=fields, + colnames=colnames, + as_ordered_dict=as_ordered_dict, + ) + + return rows + def sql_expression( self, - sql_fragment: str, - *raw_args: Any, + sql_fragment: str | Template, + *raw_args: t.Any, output_type: str | None = None, - **raw_kwargs: Any, + **raw_kwargs: t.Any, ) -> Expression: """ Creates a pydal Expression object representing a raw SQL fragment. Args: sql_fragment: The raw SQL fragment. + In python 3.14+, this can also be a t-string. In that case, don't pass other args or kwargs. *raw_args: Arguments to be interpolated into the SQL fragment. output_type: The expected output type of the expression. **raw_kwargs: Keyword arguments to be interpolated into the SQL fragment. @@ -855,2598 +445,16 @@ def sql_expression( return sql_expression(self, sql_fragment, *raw_args, output_type=output_type, **raw_kwargs) -def default_representer(field: TypedField[T], value: T, table: Type[TypedTable]) -> str: - """ - Simply call field.represent on the value. - """ - if represent := getattr(field, "represent", None): - return str(represent(value, table)) - else: - return repr(value) - - TypeDAL.representers.setdefault("rows_render", default_representer) -P = typing.ParamSpec("P") -R = typing.TypeVar("R") - - -def reorder_fields( - table: pydal.objects.Table, - fields: typing.Iterable[str | Field | TypedField[Any]], - keep_others: bool = True, -) -> None: - """ - Reorder fields of a pydal table. - - Args: - table: The pydal table object (e.g., db.mytable). - fields: List of field names (str) or Field objects in desired order. - keep_others (bool): - - True (default): keep other fields at the end, in their original order. - - False: remove other fields (only keep what's specified). - """ - # Normalize input to field names - desired = [f.name if isinstance(f, (TypedField, Field, pydal.objects.Field)) else str(f) for f in fields] - - new_order = [f for f in desired if f in table._fields] - - if keep_others: - # Start with desired fields, then append the rest - new_order.extend(f for f in table._fields if f not in desired) - - table._fields = new_order - - -class TableMeta(type): - """ - This metaclass contains functionality on table classes, that doesn't exist on its instances. - - Example: - class MyTable(TypedTable): - some_field: TypedField[int] - - MyTable.update_or_insert(...) # should work - - MyTable.some_field # -> Field, can be used to query etc. - - row = MyTable.first() # returns instance of MyTable - - # row.update_or_insert(...) # shouldn't work! 
- - row.some_field # -> int, with actual data - - """ - - # set up by db.define: - # _db: TypeDAL | None = None - # _table: Table | None = None - _db: TypeDAL | None = None - _table: Table | None = None - _relationships: dict[str, Relationship[Any]] | None = None - - ######################### - # TypeDAL custom logic: # - ######################### - - def __set_internals__(self, db: pydal.DAL, table: Table, relationships: dict[str, Relationship[Any]]) -> None: - """ - Store the related database and pydal table for later usage. - """ - self._db = db - self._table = table - self._relationships = relationships - - def __getattr__(self, col: str) -> Optional[Field]: - """ - Magic method used by TypedTableMeta to get a database field with dot notation on a class. - - Example: - SomeTypedTable.col -> db.table.col (via TypedTableMeta.__getattr__) - - """ - if self._table: - return getattr(self._table, col, None) - - return None - - def _ensure_table_defined(self) -> Table: - if not self._table: - raise EnvironmentError("@define or db.define is not called on this class yet!") - return self._table - - def __iter__(self) -> typing.Generator[Field, None, None]: - """ - Loop through the columns of this model. - """ - table = self._ensure_table_defined() - yield from iter(table) - - def __getitem__(self, item: str) -> Field: - """ - Allow dict notation to get a column of this table (-> Field instance). - """ - table = self._ensure_table_defined() - return table[item] - - def __str__(self) -> str: - """ - Normally, just returns the underlying table name, but with a fallback if the model is unbound. - """ - if self._table: - return str(self._table) - else: - return f"" - - def from_row(self: Type[T_MetaInstance], row: pydal.objects.Row) -> T_MetaInstance: - """ - Create a model instance from a pydal row. - """ - return self(row) - - def all(self: Type[T_MetaInstance]) -> "TypedRows[T_MetaInstance]": - """ - Return all rows for this model. - """ - return self.collect() - - def get_relationships(self) -> dict[str, Relationship[Any]]: - """ - Return the registered relationships of the current model. - """ - return self._relationships or {} - - ########################## - # TypeDAL Modified Logic # - ########################## - - def insert(self: Type[T_MetaInstance], **fields: Any) -> T_MetaInstance: - """ - This is only called when db.define is not used as a decorator. - - cls.__table functions as 'self' - - Args: - **fields: anything you want to insert in the database - - Returns: the ID of the new row. - - """ - table = self._ensure_table_defined() - - result = table.insert(**fields) - # it already is an int but mypy doesn't understand that - return self(result) - - def _insert(self, **fields: Any) -> str: - table = self._ensure_table_defined() - - return str(table._insert(**fields)) - - def bulk_insert(self: Type[T_MetaInstance], items: list[AnyDict]) -> "TypedRows[T_MetaInstance]": - """ - Insert multiple rows, returns a TypedRows set of new instances. - """ - table = self._ensure_table_defined() - result = table.bulk_insert(items) - return self.where(lambda row: row.id.belongs(result)).collect() - - def update_or_insert( - self: Type[T_MetaInstance], - query: T_Query | AnyDict = DEFAULT, - **values: Any, - ) -> T_MetaInstance: - """ - Update a row if query matches, else insert a new one. - - Returns the created or updated instance. 
- """ - table = self._ensure_table_defined() - - if query is DEFAULT: - record = table(**values) - elif isinstance(query, dict): - record = table(**query) - else: - record = table(query) - - if not record: - return self.insert(**values) - - record.update_record(**values) - return self(record) - - def validate_and_insert( - self: Type[T_MetaInstance], - **fields: Any, - ) -> tuple[Optional[T_MetaInstance], Optional[dict[str, str]]]: - """ - Validate input data and then insert a row. - - Returns a tuple of (the created instance, a dict of errors). - """ - table = self._ensure_table_defined() - result = table.validate_and_insert(**fields) - if row_id := result.get("id"): - return self(row_id), None - else: - return None, result.get("errors") - - def validate_and_update( - self: Type[T_MetaInstance], - query: Query, - **fields: Any, - ) -> tuple[Optional[T_MetaInstance], Optional[dict[str, str]]]: - """ - Validate input data and then update max 1 row. - - Returns a tuple of (the updated instance, a dict of errors). - """ - table = self._ensure_table_defined() - - result = table.validate_and_update(query, **fields) - - if errors := result.get("errors"): - return None, errors - elif row_id := result.get("id"): - return self(row_id), None - else: # pragma: no cover - # update on query without result (shouldnt happen) - return None, None - - def validate_and_update_or_insert( - self: Type[T_MetaInstance], - query: Query, - **fields: Any, - ) -> tuple[Optional[T_MetaInstance], Optional[dict[str, str]]]: - """ - Validate input data and then update_and_insert (on max 1 row). - - Returns a tuple of (the updated/created instance, a dict of errors). - """ - table = self._ensure_table_defined() - result = table.validate_and_update_or_insert(query, **fields) - - if errors := result.get("errors"): - return None, errors - elif row_id := result.get("id"): - return self(row_id), None - else: # pragma: no cover - # update on query without result (shouldnt happen) - return None, None - - def select(self: Type[T_MetaInstance], *a: Any, **kw: Any) -> "QueryBuilder[T_MetaInstance]": - """ - See QueryBuilder.select! - """ - return QueryBuilder(self).select(*a, **kw) - - def column(self: Type[T_MetaInstance], field: "TypedField[T] | T", **options: Unpack[SelectKwargs]) -> list[T]: - """ - Get all values in a specific column. - - Shortcut for `.select(field).execute().column(field)`. - """ - return QueryBuilder(self).select(field, **options).execute().column(field) - - def paginate(self: Type[T_MetaInstance], limit: int, page: int = 1) -> "PaginatedRows[T_MetaInstance]": - """ - See QueryBuilder.paginate! - """ - return QueryBuilder(self).paginate(limit=limit, page=page) - - def chunk(self: Type[T_MetaInstance], chunk_size: int) -> typing.Generator["TypedRows[T_MetaInstance]", Any, None]: - """ - See QueryBuilder.chunk! - """ - return QueryBuilder(self).chunk(chunk_size) - - def where(self: Type[T_MetaInstance], *a: Any, **kw: Any) -> "QueryBuilder[T_MetaInstance]": - """ - See QueryBuilder.where! - """ - return QueryBuilder(self).where(*a, **kw) - - def orderby(self: Type[T_MetaInstance], *fields: OrderBy) -> "QueryBuilder[T_MetaInstance]": - """ - See QueryBuilder.orderby! - """ - return QueryBuilder(self).orderby(*fields) - - def cache(self: Type[T_MetaInstance], *deps: Any, **kwargs: Any) -> "QueryBuilder[T_MetaInstance]": - """ - See QueryBuilder.cache! - """ - return QueryBuilder(self).cache(*deps, **kwargs) - - def count(self: Type[T_MetaInstance]) -> int: - """ - See QueryBuilder.count! 
- """ - return QueryBuilder(self).count() - - def exists(self: Type[T_MetaInstance]) -> bool: - """ - See QueryBuilder.exists! - """ - return QueryBuilder(self).exists() - - def first(self: Type[T_MetaInstance]) -> T_MetaInstance | None: - """ - See QueryBuilder.first! - """ - return QueryBuilder(self).first() - - def first_or_fail(self: Type[T_MetaInstance]) -> T_MetaInstance: - """ - See QueryBuilder.first_or_fail! - """ - return QueryBuilder(self).first_or_fail() - - def join( - self: Type[T_MetaInstance], - *fields: str | Type["TypedTable"], - method: JOIN_OPTIONS = None, - on: OnQuery | list[Expression] | Expression = None, - condition: Condition = None, - condition_and: Condition = None, - ) -> "QueryBuilder[T_MetaInstance]": - """ - See QueryBuilder.join! - """ - return QueryBuilder(self).join(*fields, on=on, condition=condition, method=method, condition_and=condition_and) - - def collect(self: Type[T_MetaInstance], verbose: bool = False) -> "TypedRows[T_MetaInstance]": - """ - See QueryBuilder.collect! - """ - return QueryBuilder(self).collect(verbose=verbose) - - @property - def ALL(cls) -> pydal.objects.SQLALL: - """ - Select all fields for this table. - """ - table = cls._ensure_table_defined() - - return table.ALL - - ########################## - # TypeDAL Shadowed Logic # - ########################## - fields: list[str] - - # other table methods: +# note: these imports exist at the bottom of this file to prevent circular import issues: - def truncate(self, mode: str = "") -> None: - """ - Remove all data and reset index. - """ - table = self._ensure_table_defined() - table.truncate(mode) - - def drop(self, mode: str = "") -> None: - """ - Remove the underlying table. - """ - table = self._ensure_table_defined() - table.drop(mode) - - def create_index(self, name: str, *fields: Field | str, **kwargs: Any) -> bool: - """ - Add an index on some columns of this table. - """ - table = self._ensure_table_defined() - result = table.create_index(name, *fields, **kwargs) - return typing.cast(bool, result) - - def drop_index(self, name: str, if_exists: bool = False) -> bool: - """ - Remove an index from this table. - """ - table = self._ensure_table_defined() - result = table.drop_index(name, if_exists) - return typing.cast(bool, result) - - def import_from_csv_file( - self, - csvfile: typing.TextIO, - id_map: dict[str, str] = None, - null: Any = "", - unique: str = "uuid", - id_offset: dict[str, int] = None, # id_offset used only when id_map is None - transform: typing.Callable[[dict[Any, Any]], dict[Any, Any]] = None, - validate: bool = False, - encoding: str = "utf-8", - delimiter: str = ",", - quotechar: str = '"', - quoting: int = csv.QUOTE_MINIMAL, - restore: bool = False, - **kwargs: Any, - ) -> None: - """ - Load a csv file into the database. - """ - table = self._ensure_table_defined() - table.import_from_csv_file( - csvfile, - id_map=id_map, - null=null, - unique=unique, - id_offset=id_offset, - transform=transform, - validate=validate, - encoding=encoding, - delimiter=delimiter, - quotechar=quotechar, - quoting=quoting, - restore=restore, - **kwargs, - ) - - def on(self, query: Query | bool) -> Expression: - """ - Shadow Table.on. - - Used for joins. 
- - See Also: - http://web2py.com/books/default/chapter/29/06/the-database-abstraction-layer?search=export_to_csv_file#One-to-many-relation - """ - table = self._ensure_table_defined() - return typing.cast(Expression, table.on(query)) - - def with_alias(self: Type[T_MetaInstance], alias: str) -> Type[T_MetaInstance]: - """ - Shadow Table.with_alias. +from .fields import * # noqa: E402 F403 # isort: skip ; to fill globals() scope +from .define import TableDefinitionBuilder # noqa: E402 +from .rows import TypedSet # noqa: E402 +from .tables import TypedTable # noqa: E402 - Useful for joins when joining the same table multiple times. - - See Also: - http://web2py.com/books/default/chapter/29/06/the-database-abstraction-layer#One-to-many-relation - """ - table = self._ensure_table_defined() - return typing.cast(Type[T_MetaInstance], table.with_alias(alias)) - - def unique_alias(self: Type[T_MetaInstance]) -> Type[T_MetaInstance]: - """ - Generates a unique alias for this table. - - Useful for joins when joining the same table multiple times - and you don't want to keep track of aliases yourself. - """ - key = f"{self.__name__.lower()}_{hash(uuid.uuid4())}" - return self.with_alias(key) - - # hooks: - def _hook_once( - cls: Type[T_MetaInstance], - hooks: list[typing.Callable[P, R]], - fn: typing.Callable[P, R], - ) -> Type[T_MetaInstance]: - @functools.wraps(fn) - def wraps(*a: P.args, **kw: P.kwargs) -> R: - try: - return fn(*a, **kw) - finally: - hooks.remove(wraps) - - hooks.append(wraps) - return cls - - def before_insert( - cls: Type[T_MetaInstance], - fn: typing.Callable[[T_MetaInstance], Optional[bool]] | typing.Callable[[OpRow], Optional[bool]], - ) -> Type[T_MetaInstance]: - """ - Add a before insert hook. - """ - if fn not in cls._before_insert: - cls._before_insert.append(fn) - return cls - - def before_insert_once( - cls: Type[T_MetaInstance], - fn: typing.Callable[[T_MetaInstance], Optional[bool]] | typing.Callable[[OpRow], Optional[bool]], - ) -> Type[T_MetaInstance]: - """ - Add a before insert hook that only fires once and then removes itself. - """ - return cls._hook_once(cls._before_insert, fn) # type: ignore - - def after_insert( - cls: Type[T_MetaInstance], - fn: ( - typing.Callable[[T_MetaInstance, Reference], Optional[bool]] - | typing.Callable[[OpRow, Reference], Optional[bool]] - ), - ) -> Type[T_MetaInstance]: - """ - Add an after insert hook. - """ - if fn not in cls._after_insert: - cls._after_insert.append(fn) - return cls - - def after_insert_once( - cls: Type[T_MetaInstance], - fn: ( - typing.Callable[[T_MetaInstance, Reference], Optional[bool]] - | typing.Callable[[OpRow, Reference], Optional[bool]] - ), - ) -> Type[T_MetaInstance]: - """ - Add an after insert hook that only fires once and then removes itself. - """ - return cls._hook_once(cls._after_insert, fn) # type: ignore - - def before_update( - cls: Type[T_MetaInstance], - fn: typing.Callable[[Set, T_MetaInstance], Optional[bool]] | typing.Callable[[Set, OpRow], Optional[bool]], - ) -> Type[T_MetaInstance]: - """ - Add a before update hook. - """ - if fn not in cls._before_update: - cls._before_update.append(fn) - return cls - - def before_update_once( - cls, - fn: typing.Callable[[Set, T_MetaInstance], Optional[bool]] | typing.Callable[[Set, OpRow], Optional[bool]], - ) -> Type[T_MetaInstance]: - """ - Add a before update hook that only fires once and then removes itself. 
- """ - return cls._hook_once(cls._before_update, fn) # type: ignore - - def after_update( - cls: Type[T_MetaInstance], - fn: typing.Callable[[Set, T_MetaInstance], Optional[bool]] | typing.Callable[[Set, OpRow], Optional[bool]], - ) -> Type[T_MetaInstance]: - """ - Add an after update hook. - """ - if fn not in cls._after_update: - cls._after_update.append(fn) - return cls - - def after_update_once( - cls: Type[T_MetaInstance], - fn: typing.Callable[[Set, T_MetaInstance], Optional[bool]] | typing.Callable[[Set, OpRow], Optional[bool]], - ) -> Type[T_MetaInstance]: - """ - Add an after update hook that only fires once and then removes itself. - """ - return cls._hook_once(cls._after_update, fn) # type: ignore - - def before_delete(cls: Type[T_MetaInstance], fn: typing.Callable[[Set], Optional[bool]]) -> Type[T_MetaInstance]: - """ - Add a before delete hook. - """ - if fn not in cls._before_delete: - cls._before_delete.append(fn) - return cls - - def before_delete_once( - cls: Type[T_MetaInstance], - fn: typing.Callable[[Set], Optional[bool]], - ) -> Type[T_MetaInstance]: - """ - Add a before delete hook that only fires once and then removes itself. - """ - return cls._hook_once(cls._before_delete, fn) - - def after_delete(cls: Type[T_MetaInstance], fn: typing.Callable[[Set], Optional[bool]]) -> Type[T_MetaInstance]: - """ - Add an after delete hook. - """ - if fn not in cls._after_delete: - cls._after_delete.append(fn) - return cls - - def after_delete_once( - cls: Type[T_MetaInstance], - fn: typing.Callable[[Set], Optional[bool]], - ) -> Type[T_MetaInstance]: - """ - Add an after delete hook that only fires once and then removes itself. - """ - return cls._hook_once(cls._after_delete, fn) - - def reorder_fields(cls, *fields: str | Field | TypedField[Any], keep_others: bool = True) -> None: - """ - Reorder fields of a typedal table. - - Args: - fields: List of field names (str) or Field objects in desired order. - keep_others (bool): - - True (default): keep other fields at the end, in their original order. - - False: remove other fields (only keep what's specified). - """ - return reorder_fields(cls._table, fields, keep_others=keep_others) - - -class TypedField(Expression, typing.Generic[T_Value]): # pragma: no cover - """ - Typed version of pydal.Field, which will be converted to a normal Field in the background. - """ - - # will be set by .bind on db.define - name = "" - _db: Optional[pydal.DAL] = None - _rname: Optional[str] = None - _table: Optional[Table] = None - _field: Optional[Field] = None - - _type: T_annotation - kwargs: Any - - requires: Validator | typing.Iterable[Validator] - - # NOTE: for the logic of converting a TypedField into a pydal Field, see TypeDAL._to_field - - def __init__( - self, - _type: Type[T_Value] | types.UnionType = str, # type: ignore - /, - **settings: Unpack[FieldSettings], - ) -> None: - """ - Typed version of pydal.Field, which will be converted to a normal Field in the background. - - Provide the Python type for this field as the first positional argument - and any other settings to Field() as keyword parameters. - """ - self._type = _type - self.kwargs = settings - # super().__init__() - - @typing.overload - def __get__(self, instance: T_MetaInstance, owner: Type[T_MetaInstance]) -> T_Value: # pragma: no cover - """ - row.field -> (actual data). - """ - - @typing.overload - def __get__(self, instance: None, owner: "Type[TypedTable]") -> "TypedField[T_Value]": # pragma: no cover - """ - Table.field -> Field. 
- """ - - def __get__( - self, - instance: T_MetaInstance | None, - owner: Type[T_MetaInstance], - ) -> typing.Union[T_Value, "TypedField[T_Value]"]: - """ - Since this class is a Descriptor field, \ - it returns something else depending on if it's called on a class or instance. - - (this is mostly for mypy/typing) - """ - if instance: - # this is only reached in a very specific case: - # an instance of the object was created with a specific set of fields selected (excluding the current one) - # in that case, no value was stored in the owner -> return None (since the field was not selected) - return typing.cast(T_Value, None) # cast as T_Value so mypy understands it for selected fields - else: - # getting as class -> return actual field so pydal understands it when using in query etc. - return typing.cast(TypedField[T_Value], self._field) # pretend it's still typed for IDE support - - def __str__(self) -> str: - """ - String representation of a Typed Field. - - If `type` is set explicitly (e.g. TypedField(str, type="text")), that type is used: `TypedField.text`, - otherwise the type annotation is used (e.g. TypedField(str) -> TypedField.str) - """ - return str(self._field) if self._field else "" - - def __repr__(self) -> str: - """ - More detailed string representation of a Typed Field. - - Uses __str__ and adds the provided extra options (kwargs) in the representation. - """ - s = self.__str__() - - if "type" in self.kwargs: - # manual type in kwargs supplied - t = self.kwargs["type"] - elif issubclass(type, type(self._type)): - # normal type, str.__name__ = 'str' - t = getattr(self._type, "__name__", str(self._type)) - elif t_args := typing.get_args(self._type): - # list[str] -> 'str' - t = t_args[0].__name__ - else: # pragma: no cover - # fallback - something else, may not even happen, I'm not sure - t = self._type - - s = f"TypedField[{t}].{s}" if s else f"TypedField[{t}]" - - kw = self.kwargs.copy() - kw.pop("type", None) - return f"<{s} with options {kw}>" - - def _to_field(self, extra_kwargs: typing.MutableMapping[str, Any]) -> Optional[str]: - """ - Convert a Typed Field instance to a pydal.Field. - - Actual logic in TypeDAL._to_field but this function creates the pydal type name and updates the kwarg settings. - """ - other_kwargs = self.kwargs.copy() - extra_kwargs.update(other_kwargs) # <- modifies and overwrites the default kwargs with user-specified ones - return extra_kwargs.pop("type", False) or TypeDAL._annotation_to_pydal_fieldtype(self._type, extra_kwargs) - - def bind(self, field: pydal.objects.Field, table: pydal.objects.Table) -> None: - """ - Bind the right db/table/field info to this class, so queries can be made using `Class.field == ...`. - """ - self._table = table - self._field = field - - def __getattr__(self, key: str) -> Any: - """ - If the regular getattribute does not work, try to get info from the related Field. - """ - with contextlib.suppress(AttributeError): - return super().__getattribute__(key) - - # try on actual field: - return getattr(self._field, key) - - def __eq__(self, other: Any) -> Query: - """ - Performing == on a Field will result in a Query. - """ - return typing.cast(Query, self._field == other) - - def __ne__(self, other: Any) -> Query: - """ - Performing != on a Field will result in a Query. - """ - return typing.cast(Query, self._field != other) - - def __gt__(self, other: Any) -> Query: - """ - Performing > on a Field will result in a Query. 
- """ - return typing.cast(Query, self._field > other) - - def __lt__(self, other: Any) -> Query: - """ - Performing < on a Field will result in a Query. - """ - return typing.cast(Query, self._field < other) - - def __ge__(self, other: Any) -> Query: - """ - Performing >= on a Field will result in a Query. - """ - return typing.cast(Query, self._field >= other) - - def __le__(self, other: Any) -> Query: - """ - Performing <= on a Field will result in a Query. - """ - return typing.cast(Query, self._field <= other) - - def __hash__(self) -> int: - """ - Shadow Field.__hash__. - """ - return hash(self._field) - - def __invert__(self) -> Expression: - """ - Performing ~ on a Field will result in an Expression. - """ - if not self._field: # pragma: no cover - raise ValueError("Unbound Field can not be inverted!") - - return typing.cast(Expression, ~self._field) - - def lower(self) -> Expression: - """ - For string-fields: compare lowercased values. - """ - if not self._field: # pragma: no cover - raise ValueError("Unbound Field can not be lowered!") - - return typing.cast(Expression, self._field.lower()) - - # ... etc - - -class _TypedTable: - """ - This class is a final shared parent between TypedTable and Mixins. - - This needs to exist because otherwise the __on_define__ of Mixins are not executed. - Notably, this class exists at a level ABOVE the `metaclass=TableMeta`, - because otherwise typing gets confused when Mixins are used and multiple types could satisfy - generic 'T subclass of TypedTable' - -> Setting 'TypedTable' as the parent for Mixin does not work at runtime (and works semi at type check time) - """ - - id: "TypedField[int]" - - _before_insert: list[typing.Callable[[Self], Optional[bool]] | typing.Callable[[OpRow], Optional[bool]]] - _after_insert: list[ - typing.Callable[[Self, Reference], Optional[bool]] | typing.Callable[[OpRow, Reference], Optional[bool]] - ] - _before_update: list[typing.Callable[[Set, Self], Optional[bool]] | typing.Callable[[Set, OpRow], Optional[bool]]] - _after_update: list[typing.Callable[[Set, Self], Optional[bool]] | typing.Callable[[Set, OpRow], Optional[bool]]] - _before_delete: list[typing.Callable[[Set], Optional[bool]]] - _after_delete: list[typing.Callable[[Set], Optional[bool]]] - - @classmethod - def __on_define__(cls, db: TypeDAL) -> None: - """ - Method that can be implemented by tables to do an action after db.define is completed. - - This can be useful if you need to add something like requires=IS_NOT_IN_DB(db, "table.field"), - where you need a reference to the current database, which may not exist yet when defining the model. - """ - - @classproperty - def _hooks(cls) -> dict[str, list[typing.Callable[..., Optional[bool]]]]: - return { - "before_insert": cls._before_insert, - "after_insert": cls._after_insert, - "before_update": cls._before_update, - "after_update": cls._after_update, - "before_delete": cls._before_delete, - "after_delete": cls._after_delete, - } - - -class TypedTable(_TypedTable, metaclass=TableMeta): - """ - Enhanded modeling system on top of pydal's Table that adds typing and additional functionality. - """ - - # set up by 'new': - _row: Row | None = None - _rows: tuple[Row, ...] 
-
-    _with: list[str]
-
-    def _setup_instance_methods(self) -> None:
-        self.as_dict = self._as_dict  # type: ignore
-        self.__json__ = self.as_json = self._as_json  # type: ignore
-        # self.as_yaml = self._as_yaml  # type: ignore
-        self.as_xml = self._as_xml  # type: ignore
-
-        self.update = self._update  # type: ignore
-
-        self.delete_record = self._delete_record  # type: ignore
-        self.update_record = self._update_record  # type: ignore
-
-    def __new__(
-        cls,
-        row_or_id: typing.Union[Row, Query, pydal.objects.Set, int, str, None, "TypedTable"] = None,
-        **filters: Any,
-    ) -> Self:
-        """
-        Create a TypedTable model instance from an existing row, an id, or a query.
-
-        Examples:
-            MyTable(1)
-            MyTable(id=1)
-            MyTable(MyTable.id == 1)
-        """
-        table = cls._ensure_table_defined()
-        inst = super().__new__(cls)
-
-        if isinstance(row_or_id, TypedTable):
-            # existing typed table instance!
-            return typing.cast(Self, row_or_id)
-
-        elif isinstance(row_or_id, pydal.objects.Row):
-            row = row_or_id
-        elif row_or_id is not None:
-            row = table(row_or_id, **filters)
-        elif filters:
-            row = table(**filters)
-        else:
-            # dummy object
-            return inst
-
-        if not row:
-            return None  # type: ignore
-
-        inst._row = row
-
-        if hasattr(row, "id"):
-            inst.__dict__.update(row)
-        else:
-            # deal with _extra (and possibly others?)
-            # Row <{actual: {}, _extra: ...}>
-            inst.__dict__.update(row[str(cls)])
-
-        inst._setup_instance_methods()
-        return inst
-
-    def __iter__(self) -> typing.Generator[Any, None, None]:
-        """
-        Allows looping through the columns.
-        """
-        row = self._ensure_matching_row()
-        yield from iter(row)
-
-    def __getitem__(self, item: str) -> Any:
-        """
-        Allows dictionary notation to get columns.
-        """
-        if item in self.__dict__:
-            return self.__dict__.get(item)
-
-        # fallback to lookup in row
-        if self._row:
-            return self._row[item]
-
-        # nothing found!
-        raise KeyError(item)
-
-    def __getattr__(self, item: str) -> Any:
-        """
-        Allows dot notation to get columns.
-        """
-        if value := self.get(item):
-            return value
-
-        raise AttributeError(item)
-
-    def keys(self) -> list[str]:
-        """
-        Return the combination of row + relationship keys.
-
-        Used by dict(row).
-        """
-        return list(self._row.keys() if self._row else ()) + getattr(self, "_with", [])
-
-    def get(self, item: str, default: Any = None) -> Any:
-        """
-        Try to get a column from this instance, else return default.
-        """
-        try:
-            return self.__getitem__(item)
-        except KeyError:
-            return default
-
-    def __setitem__(self, key: str, value: Any) -> None:
-        """
-        Data can be updated via both dot and dict notation.
-        """
-        return setattr(self, key, value)
-
-    def __int__(self) -> int:
-        """
-        Calling int on a model instance will return its id.
-        """
-        return getattr(self, "id", 0)
-
-    def __bool__(self) -> bool:
-        """
-        If the instance has an underlying row with data, it is truthy.
-        """
-        return bool(getattr(self, "_row", False))
-
-    def _ensure_matching_row(self) -> Row:
-        if not getattr(self, "_row", None):
-            raise EnvironmentError("Trying to access a non-existent row. Maybe it was deleted or not yet initialized?")
-        return self._row
-
-    def __repr__(self) -> str:
-        """
-        String representation of the model instance.
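
As the `__new__` docstring above illustrates, an instance can be constructed from an id, keyword filters, a query, or an existing row; the dunder methods then make it behave like both a row and a mapping. A brief sketch (hypothetical `Article` model with a `title` field):

```python
article = Article(1)        # by id; equivalent: Article(id=1) or Article(Article.id == 1)
if article:                 # truthy only if a matching row was found
    print(int(article))     # the row id
    print(article.title)    # dot notation...
    print(article["title"])  # ...and dict notation both work
    print(dict(article))    # keys() + __getitem__ make dict() work too
```
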
- """ - model_name = self.__class__.__name__ - model_data = {} - - if self._row: - model_data = self._row.as_json() - - details = model_name - details += f"({model_data})" - - if relationships := getattr(self, "_with", []): - details += f" + {relationships}" - - return f"<{details}>" - - # serialization - # underscore variants work for class instances (set up by _setup_instance_methods) - - @classmethod - def as_dict(cls, flat: bool = False, sanitize: bool = True) -> AnyDict: - """ - Dump the object to a plain dict. - - Can be used as both a class or instance method: - - dumps the table info if it's a class - - dumps the row info if it's an instance (see _as_dict) - """ - table = cls._ensure_table_defined() - result = table.as_dict(flat, sanitize) - return typing.cast(AnyDict, result) - - @classmethod - def as_json(cls, sanitize: bool = True, indent: Optional[int] = None, **kwargs: Any) -> str: - """ - Dump the object to json. - - Can be used as both a class or instance method: - - dumps the table info if it's a class - - dumps the row info if it's an instance (see _as_json) - """ - data = cls.as_dict(sanitize=sanitize) - return as_json.encode(data, indent=indent, **kwargs) - - @classmethod - def as_xml(cls, sanitize: bool = True) -> str: # pragma: no cover - """ - Dump the object to xml. - - Can be used as both a class or instance method: - - dumps the table info if it's a class - - dumps the row info if it's an instance (see _as_xml) - """ - table = cls._ensure_table_defined() - return typing.cast(str, table.as_xml(sanitize)) - - @classmethod - def as_yaml(cls, sanitize: bool = True) -> str: - """ - Dump the object to yaml. - - Can be used as both a class or instance method: - - dumps the table info if it's a class - - dumps the row info if it's an instance (see _as_yaml) - """ - table = cls._ensure_table_defined() - return typing.cast(str, table.as_yaml(sanitize)) - - def _as_dict( - self, - datetime_to_str: bool = False, - custom_types: typing.Iterable[type] | type | None = None, - ) -> AnyDict: - row = self._ensure_matching_row() - - result = row.as_dict(datetime_to_str=datetime_to_str, custom_types=custom_types) - - def asdict_method(obj: Any) -> Any: # pragma: no cover - if hasattr(obj, "_as_dict"): # typedal - return obj._as_dict() - elif hasattr(obj, "as_dict"): # pydal - return obj.as_dict() - else: # something else?? - return obj.__dict__ - - if _with := getattr(self, "_with", None): - for relationship in _with: - data = self.get(relationship) - - if isinstance(data, list): - data = [asdict_method(_) for _ in data] - elif data: - data = asdict_method(data) - - result[relationship] = data - - return typing.cast(AnyDict, result) - - def _as_json( - self, - default: typing.Callable[[Any], Any] = None, - indent: Optional[int] = None, - **kwargs: Any, - ) -> str: - data = self._as_dict() - return as_json.encode(data, default=default, indent=indent, **kwargs) - - def _as_xml(self, sanitize: bool = True) -> str: # pragma: no cover - row = self._ensure_matching_row() - return typing.cast(str, row.as_xml(sanitize)) - - # def _as_yaml(self, sanitize: bool = True) -> str: - # row = self._ensure_matching_row() - # return typing.cast(str, row.as_yaml(sanitize)) - - def __setattr__(self, key: str, value: Any) -> None: - """ - When setting a property on a Typed Table model instance, also update the underlying row. 
- """ - if self._row and key in self._row.__dict__ and not callable(value): - # enables `row.key = value; row.update_record()` - self._row[key] = value - - super().__setattr__(key, value) - - @classmethod - def update(cls: Type[T_MetaInstance], query: Query, **fields: Any) -> T_MetaInstance | None: - """ - Update one record. - - Example: - MyTable.update(MyTable.id == 1, name="NewName") -> MyTable - """ - # todo: update multiple? - if record := cls(query): - return record.update_record(**fields) - else: - return None - - def _update(self: T_MetaInstance, **fields: Any) -> T_MetaInstance: - row = self._ensure_matching_row() - row.update(**fields) - self.__dict__.update(**fields) - return self - - def _update_record(self: T_MetaInstance, **fields: Any) -> T_MetaInstance: - row = self._ensure_matching_row() - new_row = row.update_record(**fields) - self.update(**new_row) - return self - - def update_record(self: T_MetaInstance, **fields: Any) -> T_MetaInstance: # pragma: no cover - """ - Here as a placeholder for _update_record. - - Will be replaced on instance creation! - """ - return self._update_record(**fields) - - def _delete_record(self) -> int: - """ - Actual logic in `pydal.helpers.classes.RecordDeleter`. - """ - row = self._ensure_matching_row() - result = row.delete_record() - self.__dict__ = {} # empty self, since row is no more. - self._row = None # just to be sure - self._setup_instance_methods() - # ^ instance methods might've been deleted by emptying dict, - # but we still want .as_dict to show an error, not the table's as_dict. - return typing.cast(int, result) - - def delete_record(self) -> int: # pragma: no cover - """ - Here as a placeholder for _delete_record. - - Will be replaced on instance creation! - """ - return self._delete_record() - - # __del__ is also called on the end of a scope so don't remove records on every del!! - - # pickling: - - def __getstate__(self) -> AnyDict: - """ - State to save when pickling. - - Prevents db connection from being pickled. - Similar to as_dict but without changing the data of the relationships (dill does that recursively) - """ - row = self._ensure_matching_row() - result: AnyDict = row.as_dict() - - if _with := getattr(self, "_with", None): - result["_with"] = _with - for relationship in _with: - data = self.get(relationship) - - result[relationship] = data - - result["_row"] = self._row.as_json() if self._row else "" - return result - - def __setstate__(self, state: AnyDict) -> None: - """ - Used by dill when loading from a bytestring. - """ - # as_dict also includes table info, so dump as json to only get the actual row data - # then create a new (more empty) row object: - state["_row"] = Row(json.loads(state["_row"])) - self.__dict__ |= state - - @classmethod - def _sql(cls) -> str: - """ - Generate SQL Schema for this table via pydal2sql (if 'migrations' extra is installed). - """ - try: - import pydal2sql - except ImportError as e: # pragma: no cover - raise RuntimeError("Can not generate SQL without the 'migration' extra or `pydal2sql` installed!") from e - - return pydal2sql.generate_sql(cls) - - def render(self, fields: list[Field] = None, compact: bool = False) -> Self: - """ - Renders a copy of the object with potentially modified values. - - Args: - fields: A list of fields to render. Defaults to all representable fields in the table. - compact: Whether to return only the value of the first field if there is only one field. - - Returns: - A copy of the object with potentially modified values. 
- """ - row = copy.deepcopy(self) - keys = list(row) - if not fields: - fields = [self._table[f] for f in self._table._fields] - fields = [f for f in fields if isinstance(f, Field) and f.represent] - - for field in fields: - if field._table == self._table: - row[field.name] = self._db.represent( - "rows_render", - field, - row[field.name], - row, - ) - # else: relationship, different logic: - - for relation_name in getattr(row, "_with", []): - if relation := self._relationships.get(relation_name): - relation_table = relation.table - if isinstance(relation_table, str): - relation_table = self._db[relation_table] - - relation_row = row[relation_name] - - if isinstance(relation_row, list): - # list of rows - combined = [] - - for related_og in relation_row: - related = copy.deepcopy(related_og) - for fieldname in related: - field = relation_table[fieldname] - related[field.name] = self._db.represent( - "rows_render", - field, - related[field.name], - related, - ) - combined.append(related) - - row[relation_name] = combined - else: - # 1 row - for fieldname in relation_row: - field = relation_table[fieldname] - row[relation_name][fieldname] = self._db.represent( - "rows_render", - field, - relation_row[field.name], - relation_row, - ) - - if compact and len(keys) == 1 and keys[0] != "_extra": # pragma: no cover - return typing.cast(Self, row[keys[0]]) - return row - - -# backwards compat: -TypedRow = TypedTable - - -class TypedRows(typing.Collection[T_MetaInstance], Rows): - """ - Slighly enhaned and typed functionality on top of pydal Rows (the result of a select). - """ - - records: dict[int, T_MetaInstance] - # _rows: Rows - model: Type[T_MetaInstance] - metadata: Metadata - - # pseudo-properties: actually stored in _rows - db: TypeDAL - colnames: list[str] - fields: list[Field] - colnames_fields: list[Field] - response: list[tuple[Any, ...]] - - def __init__( - self, - rows: Rows, - model: Type[T_MetaInstance], - records: dict[int, T_MetaInstance] = None, - metadata: Metadata = None, - raw: dict[int, list[Row]] = None, - ) -> None: - """ - Should not be called manually! - - Normally, the `records` from an existing `Rows` object are used - but these can be overwritten with a `records` dict. - `metadata` can be any (un)structured data - `model` is a Typed Table class - """ - - def _get_id(row: Row) -> int: - """ - Try to find the id field in a row. - - If _extra exists, the row changes: - - """ - if idx := getattr(row, "id", None): - return typing.cast(int, idx) - elif main := getattr(row, str(model), None): - return typing.cast(int, main.id) - else: # pragma: no cover - raise NotImplementedError(f"`id` could not be found for {row}") - - records = records or {_get_id(row): model(row) for row in rows} - raw = raw or {} - - for idx, entity in records.items(): - entity._rows = tuple(raw.get(idx, [])) - - super().__init__(rows.db, records, rows.colnames, rows.compact, rows.response, rows.fields) - self.model = model - self.metadata = metadata or {} - self.colnames = rows.colnames - - def __len__(self) -> int: - """ - Return the count of rows. - """ - return len(self.records) - - def __iter__(self) -> typing.Iterator[T_MetaInstance]: - """ - Loop through the rows. - """ - yield from self.records.values() - - def __contains__(self, ind: Any) -> bool: - """ - Check if an id exists in this result set. - """ - return ind in self.records - - def first(self) -> T_MetaInstance | None: - """ - Get the row with the lowest id. 
- """ - if not self.records: - return None - - return next(iter(self)) - - def last(self) -> T_MetaInstance | None: - """ - Get the row with the highest id. - """ - if not self.records: - return None - - max_id = max(self.records.keys()) - return self[max_id] - - def find( - self, - f: typing.Callable[[T_MetaInstance], Query], - limitby: tuple[int, int] = None, - ) -> "TypedRows[T_MetaInstance]": - """ - Returns a new Rows object, a subset of the original object, filtered by the function `f`. - """ - if not self.records: - return self.__class__(self, self.model, {}) - - records = {} - if limitby: - _min, _max = limitby - else: - _min, _max = 0, len(self) - count = 0 - for i, row in self.records.items(): - if f(row): - if _min <= count: - records[i] = row - count += 1 - if count == _max: - break - - return self.__class__(self, self.model, records) - - def exclude(self, f: typing.Callable[[T_MetaInstance], Query]) -> "TypedRows[T_MetaInstance]": - """ - Removes elements from the calling Rows object, filtered by the function `f`, \ - and returns a new Rows object containing the removed elements. - """ - if not self.records: - return self.__class__(self, self.model, {}) - removed = {} - to_remove = [] - for i in self.records: - row = self[i] - if f(row): - removed[i] = self.records[i] - to_remove.append(i) - - [self.records.pop(i) for i in to_remove] - - return self.__class__( - self, - self.model, - removed, - ) - - def sort(self, f: typing.Callable[[T_MetaInstance], Any], reverse: bool = False) -> list[T_MetaInstance]: - """ - Returns a list of sorted elements (not sorted in place). - """ - return [r for (r, s) in sorted(zip(self.records.values(), self), key=lambda r: f(r[1]), reverse=reverse)] - - def __str__(self) -> str: - """ - Simple string representation. - """ - return f"" - - def __repr__(self) -> str: - """ - Print a table on repr(). - """ - data = self.as_dict() - try: - headers = list(next(iter(data.values())).keys()) - except StopIteration: - headers = [] - - return mktable(data, headers) - - def group_by_value( - self, - *fields: "str | Field | TypedField[T]", - one_result: bool = False, - **kwargs: Any, - ) -> dict[T, list[T_MetaInstance]]: - """ - Group the rows by a specific field (which will be the dict key). - """ - kwargs["one_result"] = one_result - result = super().group_by_value(*fields, **kwargs) - return typing.cast(dict[T, list[T_MetaInstance]], result) - - def as_csv(self) -> str: - """ - Dump the data to csv. - """ - return typing.cast(str, super().as_csv()) - - def as_dict( - self, - key: str | Field | None = None, - compact: bool = False, - storage_to_dict: bool = False, - datetime_to_str: bool = False, - custom_types: list[type] | None = None, - ) -> dict[int, AnyDict]: - """ - Get the data in a dict of dicts. - """ - if any([key, compact, storage_to_dict, datetime_to_str, custom_types]): - # functionality not guaranteed - if isinstance(key, Field): - key = key.name - - return typing.cast( - dict[int, AnyDict], - super().as_dict( - key or "id", - compact, - storage_to_dict, - datetime_to_str, - custom_types, - ), - ) - - return {k: v.as_dict() for k, v in self.records.items()} - - def as_json(self, default: typing.Callable[[Any], Any] = None, indent: Optional[int] = None, **kwargs: Any) -> str: - """ - Turn the data into a dict and then dump to JSON. 
- """ - data = self.as_list() - - return as_json.encode(data, default=default, indent=indent, **kwargs) - - def json(self, default: typing.Callable[[Any], Any] = None, indent: Optional[int] = None, **kwargs: Any) -> str: - """ - Turn the data into a dict and then dump to JSON. - """ - return self.as_json(default=default, indent=indent, **kwargs) - - def as_list( - self, - compact: bool = False, - storage_to_dict: bool = False, - datetime_to_str: bool = False, - custom_types: list[type] = None, - ) -> list[AnyDict]: - """ - Get the data in a list of dicts. - """ - if any([compact, storage_to_dict, datetime_to_str, custom_types]): - return typing.cast(list[AnyDict], super().as_list(compact, storage_to_dict, datetime_to_str, custom_types)) - - return [_.as_dict() for _ in self.records.values()] - - def __getitem__(self, item: int) -> T_MetaInstance: - """ - You can get a specific row by ID from a typedrows by using rows[idx] notation. - - Since pydal's implementation differs (they expect a list instead of a dict with id keys), - using rows[0] will return the first row, regardless of its id. - """ - try: - return self.records[item] - except KeyError as e: - if item == 0 and (row := self.first()): - # special case: pydal internals think Rows.records is a list, not a dict - return row - - raise e - - def get(self, item: int) -> typing.Optional[T_MetaInstance]: - """ - Get a row by ID, or receive None if it isn't in this result set. - """ - return self.records.get(item) - - def update(self, **new_values: Any) -> bool: - """ - Update the current rows in the database with new_values. - """ - # cast to make mypy understand .id is a TypedField and not an int! - table = typing.cast(Type[TypedTable], self.model._ensure_table_defined()) - - ids = set(self.column("id")) - query = table.id.belongs(ids) - return bool(self.db(query).update(**new_values)) - - def delete(self) -> bool: - """ - Delete the currently selected rows from the database. - """ - # cast to make mypy understand .id is a TypedField and not an int! - table = typing.cast(Type[TypedTable], self.model._ensure_table_defined()) - - ids = set(self.column("id")) - query = table.id.belongs(ids) - return bool(self.db(query).delete()) - - def join( - self, - field: "Field | TypedField[Any]", - name: str = None, - constraint: Query = None, - fields: list[str | Field] = None, - orderby: Optional[str | Field] = None, - ) -> T_MetaInstance: - """ - This can be used to JOIN with some relationships after the initial select. - - Using the querybuilder's .join() method is prefered! - """ - result = super().join(field, name, constraint, fields or [], orderby) - return typing.cast(T_MetaInstance, result) - - def export_to_csv_file( - self, - ofile: typing.TextIO, - null: Any = "", - delimiter: str = ",", - quotechar: str = '"', - quoting: int = csv.QUOTE_MINIMAL, - represent: bool = False, - colnames: list[str] = None, - write_colnames: bool = True, - *args: Any, - **kwargs: Any, - ) -> None: - """ - Shadow export_to_csv_file from Rows, but with typing. 
-
-        See http://web2py.com/books/default/chapter/29/06/the-database-abstraction-layer?search=export_to_csv_file#Exporting-and-importing-data
-        """
-        super().export_to_csv_file(
-            ofile,
-            null,
-            *args,
-            delimiter=delimiter,
-            quotechar=quotechar,
-            quoting=quoting,
-            represent=represent,
-            colnames=colnames or self.colnames,
-            write_colnames=write_colnames,
-            **kwargs,
-        )
-
-    @classmethod
-    def from_rows(
-        cls,
-        rows: Rows,
-        model: Type[T_MetaInstance],
-        metadata: Metadata = None,
-    ) -> "TypedRows[T_MetaInstance]":
-        """
-        Internal method to convert a Rows object to a TypedRows.
-        """
-        return cls(rows, model, metadata=metadata)
-
-    def __getstate__(self) -> AnyDict:
-        """
-        Used by dill to dump to bytes (exclude db connection etc).
-        """
-        return {
-            "metadata": json.dumps(self.metadata, default=str),
-            "records": self.records,
-            "model": str(self.model._table),
-            "colnames": self.colnames,
-        }
-
-    def __setstate__(self, state: AnyDict) -> None:
-        """
-        Used by dill when loading from a bytestring.
-        """
-        state["metadata"] = json.loads(state["metadata"])
-        self.__dict__.update(state)
-        # db etc. set after undill by caching.py
-
-    def render(
-        self, i: int | None = None, fields: list[Field] | None = None
-    ) -> typing.Generator[T_MetaInstance, None, None]:
-        """
-        Takes an index and returns a copy of the indexed row with values \
-        transformed via the "represent" attributes of the associated fields.
-
-        Args:
-            i: index. If not specified, a generator is returned for iteration
-                over all the rows.
-            fields: a list of fields to transform (if None, all fields with
-                "represent" attributes will be transformed)
-        """
-        if i is None:
-            # difference: uses .keys() instead of index
-            return (self.render(i, fields=fields) for i in self.records)
-
-        if not self.db.has_representer("rows_render"):  # pragma: no cover
-            raise RuntimeError(
-                "Rows.render() needs a `rows_render` representer in DAL instance",
-            )
-
-        row = self.records[i]
-        return row.render(fields, compact=self.compact)
-
-
-from .caching import (  # noqa: E402
-    _remove_cache,
+from .caching import (  # isort: skip # noqa: E402
     _TypedalCache,
     _TypedalCacheDependency,
-    create_and_hash_cache_key,
-    get_expire,
-    load_from_cache,
-    save_to_cache,
 )
-
-
-def normalize_table_keys(row: Row, pattern: re.Pattern[str] = re.compile(r"^([a-zA-Z_]+)_(\d{5,})$")) -> Row:
-    """
-    Normalize table keys in a PyDAL Row object by stripping numeric hash suffixes from table names, \
-    only if the suffix is 5 or more digits.
-
-    For example:
-        Row({'articles_12345': {...}}) -> Row({'articles': {...}})
-        Row({'articles_123': {...}}) -> unchanged
-
-    Returns:
-        Row: A new Row object with normalized keys.
-    """
-    new_data: dict[str, Any] = {}
-    for key, value in row.items():
-        if match := pattern.match(key):
-            base, _suffix = match.groups()
-            normalized_key = base
-            new_data[normalized_key] = value
-        else:
-            new_data[key] = value
-    return Row(new_data)
-
-
-class QueryBuilder(typing.Generic[T_MetaInstance]):
-    """
-    Abstraction on top of pydal's query system.
-    """
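
Each builder method returns a new `QueryBuilder`, so calls chain without mutating the original; per the `.where()` docs further down, stacked calls are ANDed while multiple arguments to a single call are ORed. A sketch (hypothetical `Article` model with a `tags` relationship):

```python
builder = (
    Article.where(lambda a: a.views > 100)  # stacked .where() calls are ANDed
    .where(lambda a: a.id > 0)
    .join("tags")                           # pull in a defined relationship
)
rows = builder.collect()                    # nothing hits the database until here
```
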
- """ - - model: Type[T_MetaInstance] - query: Query - select_args: list[Any] - select_kwargs: SelectKwargs - relationships: dict[str, Relationship[Any]] - metadata: Metadata - - def __init__( - self, - model: Type[T_MetaInstance], - add_query: Optional[Query] = None, - select_args: Optional[list[Any]] = None, - select_kwargs: Optional[SelectKwargs] = None, - relationships: dict[str, Relationship[Any]] = None, - metadata: Metadata = None, - ): - """ - Normally, you wouldn't manually initialize a QueryBuilder but start using a method on a TypedTable. - - Example: - MyTable.where(...) -> QueryBuilder[MyTable] - """ - self.model = model - table = model._ensure_table_defined() - default_query = typing.cast(Query, table.id > 0) - self.query = add_query or default_query - self.select_args = select_args or [] - self.select_kwargs = select_kwargs or {} - self.relationships = relationships or {} - self.metadata = metadata or {} - - def __str__(self) -> str: - """ - Simple string representation for the query builder. - """ - return f"QueryBuilder for {self.model}" - - def __repr__(self) -> str: - """ - Advanced string representation for the query builder. - """ - return ( - f"" - ) - - def __bool__(self) -> bool: - """ - Querybuilder is truthy if it has any conditions. - """ - table = self.model._ensure_table_defined() - default_query = typing.cast(Query, table.id > 0) - return any( - [ - self.query != default_query, - self.select_args, - self.select_kwargs, - self.relationships, - self.metadata, - ], - ) - - def _extend( - self, - add_query: Optional[Query] = None, - overwrite_query: Optional[Query] = None, - select_args: Optional[list[Any]] = None, - select_kwargs: Optional[SelectKwargs] = None, - relationships: dict[str, Relationship[Any]] = None, - metadata: Metadata = None, - ) -> "QueryBuilder[T_MetaInstance]": - return QueryBuilder( - self.model, - (add_query & self.query) if add_query else overwrite_query or self.query, - (self.select_args + select_args) if select_args else self.select_args, - (self.select_kwargs | select_kwargs) if select_kwargs else self.select_kwargs, - (self.relationships | relationships) if relationships else self.relationships, - (self.metadata | (metadata or {})) if metadata else self.metadata, - ) - - def select(self, *fields: Any, **options: Unpack[SelectKwargs]) -> "QueryBuilder[T_MetaInstance]": - """ - Fields: database columns by name ('id'), by field reference (table.id) or other (e.g. table.ALL). - - Options: - paraphrased from the web2py pydal docs, - For more info, see http://www.web2py.com/books/default/chapter/29/06/the-database-abstraction-layer#orderby-groupby-limitby-distinct-having-orderby_on_limitby-join-left-cache - - orderby: field(s) to order by. Supported: - table.name - sort by name, ascending - ~table.name - sort by name, descending - - sort randomly - table.name|table.id - sort by two fields (first name, then id) - - groupby, having: together with orderby: - groupby can be a field (e.g. table.name) to group records by - having can be a query, only those `having` the condition are grouped - - limitby: tuple of min and max. When using the query builder, .paginate(limit, page) is recommended. - distinct: bool/field. Only select rows that differ - orderby_on_limitby (bool, default: True): by default, an implicit orderby is added when doing limitby. - join: othertable.on(query) - do an INNER JOIN. Using TypeDAL relationships with .join() is recommended! - left: othertable.on(query) - do a LEFT JOIN. 
-            cache: cache the query result to speed up repeated queries; e.g. (cache=(cache.ram, 3600), cacheable=True)
-        """
-        return self._extend(select_args=list(fields), select_kwargs=options)
-
-    def orderby(self, *fields: OrderBy) -> "QueryBuilder[T_MetaInstance]":
-        """
-        Order the query results by specified fields.
-
-        Args:
-            fields: field(s) to order by. Supported:
-                table.name - sort by name, ascending
-                ~table.name - sort by name, descending
-                '<random>' - sort randomly
-                table.name|table.id - sort by two fields (first name, then id)
-
-        Returns:
-            QueryBuilder: A new QueryBuilder instance with the ordering applied.
-        """
-        return self.select(orderby=fields)
-
-    def where(
-        self,
-        *queries_or_lambdas: Query | typing.Callable[[Type[T_MetaInstance]], Query] | dict[str, Any],
-        **filters: Any,
-    ) -> "QueryBuilder[T_MetaInstance]":
-        """
-        Extend the builder's query.
-
-        Can be used in multiple ways:
-        .where(Query) -> with a direct query such as `Table.id == 5`
-        .where(lambda table: table.id == 5) -> with a query via a lambda
-        .where(id=5) -> via keyword arguments
-
-        When using multiple where's, they will be ANDed:
-        .where(lambda table: table.id == 5).where(lambda table: table.id == 6) == (table.id == 5) & (table.id == 6)
-        When passing multiple queries to a single .where, they will be ORed:
-        .where(lambda table: table.id == 5, lambda table: table.id == 6) == (table.id == 5) | (table.id == 6)
-        """
-        new_query = self.query
-        table = self.model._ensure_table_defined()
-
-        queries_or_lambdas = (
-            *queries_or_lambdas,
-            filters,
-        )
-
-        subquery = typing.cast(Query, DummyQuery())
-        for query_part in queries_or_lambdas:
-            if isinstance(query_part, (Field, pydal.objects.Field)) or is_typed_field(query_part):
-                subquery |= typing.cast(Query, query_part != None)
-            elif isinstance(query_part, (pydal.objects.Query, Expression, pydal.objects.Expression)):
-                subquery |= typing.cast(Query, query_part)
-            elif callable(query_part):
-                if result := query_part(self.model):
-                    subquery |= result
-            elif isinstance(query_part, dict):
-                subsubquery = DummyQuery()
-                for field, value in query_part.items():
-                    subsubquery &= table[field] == value
-                if subsubquery:
-                    subquery |= subsubquery
-            else:
-                raise ValueError(f"Unexpected query type ({type(query_part)}).")
-
-        if subquery:
-            new_query &= subquery
-
-        return self._extend(overwrite_query=new_query)
-
-    def join(
-        self,
-        *fields: str | Type[TypedTable],
-        method: JOIN_OPTIONS = None,
-        on: OnQuery | list[Expression] | Expression = None,
-        condition: Condition = None,
-        condition_and: Condition = None,
-    ) -> "QueryBuilder[T_MetaInstance]":
-        """
-        Include relationship fields in the result.
-
-        `fields` can be names of Relationships on the current model.
-        If no fields are passed, all will be used.
-
-        By default, the `method` defined in the relationship is used.
-        This can be overwritten with the `method` keyword argument (left or inner).
-
-        `condition_and` can be used to add extra conditions to an inner join.
-        """
-        # todo: allow limiting amount of related rows returned for join?
-        # todo: it would be nice if 'fields' could be an actual relationship
-        #  (Article.tags = list[Tag]) and you could change the .condition and .on
-        #  this could deprecate condition_and
-
-        relationships = self.model.get_relationships()
-
-        if condition and on:
-            raise ValueError("condition and on can not be used together!")
-        elif condition:
-            if len(fields) != 1:
-                raise ValueError("join(field, condition=...) can only be used with exactly one field!")
-
-            if isinstance(condition, pydal.objects.Query):
-                condition = as_lambda(condition)
-
-            to_field = typing.cast(Type[TypedTable], fields[0])
-            relationships = {
-                str(to_field): Relationship(to_field, condition=condition, join=method, condition_and=condition_and)
-            }
-        elif on:
-            if len(fields) != 1:
-                raise ValueError("join(field, on=...) can only be used with exactly one field!")
-
-            if isinstance(on, pydal.objects.Expression):
-                on = [on]
-
-            if isinstance(on, list):
-                on = as_lambda(on)
-
-            to_field = typing.cast(Type[TypedTable], fields[0])
-            relationships = {str(to_field): Relationship(to_field, on=on, join=method, condition_and=condition_and)}
-
-        else:
-            if fields:
-                # join on every relationship
-                relationships = {str(k): relationships[str(k)].clone(condition_and=condition_and) for k in fields}
-
-            if method:
-                relationships = {
-                    str(k): r.clone(join=method, condition_and=condition_and) for k, r in relationships.items()
-                }
-
-        return self._extend(relationships=relationships)
-
-    def cache(
-        self,
-        *deps: Any,
-        expires_at: Optional[dt.datetime] = None,
-        ttl: Optional[int | dt.timedelta] = None,
-    ) -> "QueryBuilder[T_MetaInstance]":
-        """
-        Enable caching for this query, so repeated calls are loaded from a dill-cached row \
-        instead of executing the SQL and collecting matching rows again.
-        """
-        existing = self.metadata.get("cache", {})
-
-        metadata: Metadata = {}
-
-        cache_meta = typing.cast(
-            CacheMetadata,
-            self.metadata.get("cache", {})
-            | {
-                "enabled": True,
-                "depends_on": existing.get("depends_on", []) + [str(_) for _ in deps],
-                "expires_at": get_expire(expires_at=expires_at, ttl=ttl),
-            },
-        )
-
-        metadata["cache"] = cache_meta
-        return self._extend(metadata=metadata)
-
-    def _get_db(self) -> TypeDAL:
-        if db := self.model._db:
-            return db
-        else:  # pragma: no cover
-            raise EnvironmentError("@define or db.define is not called on this class yet!")
-
-    def _select_arg_convert(self, arg: Any) -> Any:
-        # typedfields are not really used at runtime anymore, but leave it in for safety:
-        if isinstance(arg, TypedField):  # pragma: no cover
-            arg = arg._field
-
-        return arg
-
-    def delete(self) -> list[int]:
-        """
-        Based on the current query, delete rows and return a list of deleted IDs.
-        """
-        db = self._get_db()
-        removed_ids = [_.id for _ in db(self.query).select("id")]
-        if db(self.query).delete():
-            # success!
-            return removed_ids
-
-        return []
-
-    def _delete(self) -> str:
-        db = self._get_db()
-        return str(db(self.query)._delete())
-
-    def update(self, **fields: Any) -> list[int]:
-        """
-        Based on the current query, update `fields` and return a list of updated IDs.
-        """
-        # todo: limit?
-        db = self._get_db()
-        updated_ids = db(self.query).select("id").column("id")
-        if db(self.query).update(**fields):
-            # success!
-            return updated_ids
-
-        return []
-
-    def _update(self, **fields: Any) -> str:
-        db = self._get_db()
-        return str(db(self.query)._update(**fields))
-
-    def _before_query(self, mut_metadata: Metadata, add_id: bool = True) -> tuple[Query, list[Any], SelectKwargs]:
-        select_args = [self._select_arg_convert(_) for _ in self.select_args] or [self.model.ALL]
-        select_kwargs = self.select_kwargs.copy()
-        query = self.query
-        model = self.model
-        mut_metadata["query"] = query
-        # require at least the id of the main table:
-        select_fields = ", ".join([str(_) for _ in select_args])
-        tablename = str(model)
-
-        if add_id and f"{tablename}.id" not in select_fields:
-            # fields of other selected, but required ID is missing.
-            select_args.append(model.id)
-
-        if self.relationships:
-            query, select_args = self._handle_relationships_pre_select(query, select_args, select_kwargs, mut_metadata)
-
-        return query, select_args, select_kwargs
-
-    def to_sql(self, add_id: bool = False) -> str:
-        """
-        Generate the SQL for the built query.
-        """
-        db = self._get_db()
-
-        query, select_args, select_kwargs = self._before_query({}, add_id=add_id)
-
-        return str(db(query)._select(*select_args, **select_kwargs))
-
-    def _collect(self) -> str:
-        """
-        Alias for to_sql, pydal-like syntax.
-        """
-        return self.to_sql()
-
-    def _collect_cached(self, metadata: Metadata) -> "TypedRows[T_MetaInstance] | None":
-        expires_at = metadata["cache"].get("expires_at")
-        metadata["cache"] |= {
-            # the key is partly dependent on the cache metadata, but not on these:
-            "key": None,
-            "status": None,
-            "cached_at": None,
-            "expires_at": None,
-        }
-
-        _, key = create_and_hash_cache_key(
-            self.model,
-            metadata,
-            self.query,
-            self.select_args,
-            self.select_kwargs,
-            self.relationships.keys(),
-        )
-
-        # re-set after creating key:
-        metadata["cache"]["expires_at"] = expires_at
-        metadata["cache"]["key"] = key
-
-        return load_from_cache(key, self._get_db())
-
-    def execute(self, add_id: bool = False) -> Rows:
-        """
-        Raw version of .collect which only executes the SQL, without performing any magic afterwards.
-        """
-        db = self._get_db()
-        metadata = typing.cast(Metadata, self.metadata.copy())
-
-        query, select_args, select_kwargs = self._before_query(metadata, add_id=add_id)
-
-        return db(query).select(*select_args, **select_kwargs)
-
-    def collect(
-        self,
-        verbose: bool = False,
-        _to: Type["TypedRows[Any]"] = None,
-        add_id: bool = True,
-    ) -> "TypedRows[T_MetaInstance]":
-        """
-        Execute the built query and turn it into model instances, while handling relationships.
-        """
-        if _to is None:
-            _to = TypedRows
-
-        db = self._get_db()
-        metadata = typing.cast(Metadata, self.metadata.copy())
-
-        if metadata.get("cache", {}).get("enabled") and (result := self._collect_cached(metadata)):
-            return result
-
-        query, select_args, select_kwargs = self._before_query(metadata, add_id=add_id)
-
-        metadata["sql"] = db(query)._select(*select_args, **select_kwargs)
-
-        if verbose:  # pragma: no cover
-            print(metadata["sql"])
-
-        rows: Rows = db(query).select(*select_args, **select_kwargs)
-
-        metadata["final_query"] = str(query)
-        metadata["final_args"] = [str(_) for _ in select_args]
-        metadata["final_kwargs"] = select_kwargs
-
-        if verbose:  # pragma: no cover
-            print(rows)
-
-        if not self.relationships:
-            # easy
-            typed_rows = _to.from_rows(rows, self.model, metadata=metadata)
-
-        else:
-            # harder: try to match rows to the objects they belong to.
-            # assume a structure of {'table': <row data>} per row;
-            # if that's not the case, fall back to the default behavior.
-            typed_rows = self._collect_with_relationships(rows, metadata=metadata, _to=_to)
-
-        # only saves if requested in metadata:
-        return save_to_cache(typed_rows, rows)
-
-    @typing.overload
-    def column(self, field: TypedField[T], **options: Unpack[SelectKwargs]) -> list[T]:
-        """
-        If a TypedField is passed, the output type can be safely determined.
-        """
-
-    @typing.overload
-    def column(self, field: T, **options: Unpack[SelectKwargs]) -> list[T]:
-        """
-        Otherwise, the output type is loosely determined (assumes `field: type` or Any).
-        """
-
-    def column(self, field: TypedField[T] | T, **options: Unpack[SelectKwargs]) -> list[T]:
-        """
-        Get all values in a specific column.
-
-        Shortcut for `.select(field).execute().column(field)`.
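
`.column()` packages the select-execute-extract dance described here into one call; a sketch (hypothetical model and field):

```python
# Equivalent to: Article.where(...).select(Article.title).execute().column(Article.title)
titles: list[str] = Article.where(lambda a: a.views > 0).column(Article.title)
```
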
- """ - return self.select(field, **options).execute().column(field) - - def _handle_relationships_pre_select( - self, - query: Query, - select_args: list[Any], - select_kwargs: SelectKwargs, - metadata: Metadata, - ) -> tuple[Query, list[Any]]: - db = self._get_db() - model = self.model - - metadata["relationships"] = set(self.relationships.keys()) - - join = [] - for key, relation in self.relationships.items(): - if not relation.condition or relation.join != "inner": - continue - - other = relation.get_table(db) - other = other.with_alias(f"{key}_{hash(relation)}") - condition = relation.condition(model, other) - if callable(relation.condition_and): - condition &= relation.condition_and(model, other) - - join.append(other.on(condition)) - - if limitby := select_kwargs.pop("limitby", ()): - # if limitby + relationships: - # 1. get IDs of main table entries that match 'query' - # 2. change query to .belongs(id) - # 3. add joins etc - - kwargs: SelectKwargs = select_kwargs | {"limitby": limitby} - # if orderby := select_kwargs.get("orderby"): - # kwargs["orderby"] = orderby - - if join: - kwargs["join"] = join - - ids = db(query)._select(model.id, **kwargs) - query = model.id.belongs(ids) - metadata["ids"] = ids - - if join: - select_kwargs["join"] = join - - left = [] - - for key, relation in self.relationships.items(): - other = relation.get_table(db) - method: JOIN_OPTIONS = relation.join or DEFAULT_JOIN_OPTION - - select_fields = ", ".join([str(_) for _ in select_args]) - pre_alias = str(other) - - if f"{other}." not in select_fields: - # no fields of other selected. add .ALL: - select_args.append(other.ALL) - elif f"{other}.id" not in select_fields: - # fields of other selected, but required ID is missing. - select_args.append(other.id) - - if relation.on: - # if it has a .on, it's always a left join! - on = relation.on(model, other) - if not isinstance(on, list): # pragma: no cover - on = [on] - - on = [ - _ - for _ in on - # only allow Expressions (query and such): - if isinstance(_, pydal.objects.Expression) - ] - left.extend(on) - elif method == "left": - # .on not given, generate it: - other = other.with_alias(f"{key}_{hash(relation)}") - condition = typing.cast(Query, relation.condition(model, other)) - if callable(relation.condition_and): - condition &= relation.condition_and(model, other) - left.append(other.on(condition)) - else: - # else: inner join (handled earlier) - other = other.with_alias(f"{key}_{hash(relation)}") # only for replace - - # if no fields of 'other' are included, add other.ALL - # else: only add other.id if missing - select_fields = ", ".join([str(_) for _ in select_args]) - - post_alias = str(other).split(" AS ")[-1] - if pre_alias != post_alias: - # replace .select's with aliased: - select_fields = select_fields.replace( - f"{pre_alias}.", - f"{post_alias}.", - ) - - select_args = select_fields.split(", ") - - select_kwargs["left"] = left - return query, select_args - - def _collect_with_relationships( - self, - rows: Rows, - metadata: Metadata, - _to: Type["TypedRows[Any]"], - ) -> "TypedRows[T_MetaInstance]": - """ - Transform the raw rows into Typed Table model instances. 
- """ - db = self._get_db() - main_table = self.model._ensure_table_defined() - - # id: Model - records = {} - - # id: [Row] - raw_per_id = defaultdict(list) - - seen_relations: dict[str, set[str]] = defaultdict(set) # main id -> set of col + id for relation - - for row in rows: - main = row[main_table] - main_id = main.id - - raw_per_id[main_id].append(normalize_table_keys(row)) - - if main_id not in records: - records[main_id] = self.model(main) - records[main_id]._with = list(self.relationships.keys()) - - # setup up all relationship defaults (once) - for col, relationship in self.relationships.items(): - records[main_id][col] = [] if relationship.multiple else None - - # now add other relationship data - for column, relation in self.relationships.items(): - relationship_column = f"{column}_{hash(relation)}" - - # relationship_column works for aliases with the same target column. - # if col + relationship not in the row, just use the regular name. - - relation_data = ( - row[relationship_column] if relationship_column in row else row[relation.get_table_name()] - ) - - if relation_data.id is None: - # always skip None ids - continue - - if f"{column}-{relation_data.id}" in seen_relations[main_id]: - # speed up duplicates - continue - else: - seen_relations[main_id].add(f"{column}-{relation_data.id}") - - relation_table = relation.get_table(db) - # hopefully an instance of a typed table and a regular row otherwise: - instance = relation_table(relation_data) if looks_like(relation_table, TypedTable) else relation_data - - if relation.multiple: - # create list of T - if not isinstance(records[main_id].get(column), list): # pragma: no cover - # should already be set up before! - setattr(records[main_id], column, []) - - records[main_id][column].append(instance) - else: - # create single T - records[main_id][column] = instance - - return _to(rows, self.model, records, metadata=metadata, raw=raw_per_id) - - def collect_or_fail(self, exception: typing.Optional[Exception] = None) -> "TypedRows[T_MetaInstance]": - """ - Call .collect() and raise an error if nothing found. - - Basically unwraps Optional type. - """ - if result := self.collect(): - return result - - if not exception: - exception = ValueError("Nothing found!") - - raise exception - - def __iter__(self) -> typing.Generator[T_MetaInstance, None, None]: - """ - You can start iterating a Query Builder object before calling collect, for ease of use. - """ - yield from self.collect() - - def __count(self, db: TypeDAL, distinct: typing.Optional[bool] = None) -> Query: - # internal, shared logic between .count and ._count - model = self.model - query = self.query - for key, relation in self.relationships.items(): - if (not relation.condition or relation.join != "inner") and not distinct: - continue - - other = relation.get_table(db) - if not distinct: - # todo: can this lead to other issues? - other = other.with_alias(f"{key}_{hash(relation)}") - query &= relation.condition(model, other) - - return query - - def count(self, distinct: typing.Optional[bool] = None) -> int: - """ - Return the amount of rows matching the current query. - """ - db = self._get_db() - query = self.__count(db, distinct=distinct) - - return db(query).count(distinct) - - def _count(self, distinct: typing.Optional[bool] = None) -> str: - """ - Return the SQL for .count(). 
- """ - db = self._get_db() - query = self.__count(db, distinct=distinct) - - return typing.cast(str, db(query)._count(distinct)) - - def exists(self) -> bool: - """ - Determines if any records exist matching the current query. - - Returns True if one or more records exist; otherwise, False. - - Returns: - bool: A boolean indicating whether any records exist. - """ - return bool(self.count()) - - def __paginate( - self, - limit: int, - page: int = 1, - ) -> "QueryBuilder[T_MetaInstance]": - available = self.count() - - _from = limit * (page - 1) - _to = (limit * page) if limit else available - - metadata: Metadata = {} - - metadata["pagination"] = { - "limit": limit, - "current_page": page, - "max_page": math.ceil(available / limit) if limit else 1, - "rows": available, - "min_max": (_from, _to), - } - - return self._extend(select_kwargs={"limitby": (_from, _to)}, metadata=metadata) - - def paginate(self, limit: int, page: int = 1, verbose: bool = False) -> "PaginatedRows[T_MetaInstance]": - """ - Paginate transforms the more readable `page` and `limit` to pydals internal limit and offset. - - Note: when using relationships, this limit is only applied to the 'main' table and any number of extra rows \ - can be loaded with relationship data! - """ - builder = self.__paginate(limit, page) - - rows = typing.cast(PaginatedRows[T_MetaInstance], builder.collect(verbose=verbose, _to=PaginatedRows)) - - rows._query_builder = builder - return rows - - def _paginate( - self, - limit: int, - page: int = 1, - ) -> str: - builder = self.__paginate(limit, page) - return builder._collect() - - def chunk(self, chunk_size: int) -> typing.Generator["TypedRows[T_MetaInstance]", Any, None]: - """ - Generator that yields rows from a paginated source in chunks. - - This function retrieves rows from a paginated data source in chunks of the - specified `chunk_size` and yields them as TypedRows. - - Example: - ``` - for chunk_of_rows in Table.where(SomeTable.id > 5).chunk(100): - for row in chunk_of_rows: - # Process each row within the chunk. - pass - ``` - """ - page = 1 - - while rows := self.__paginate(chunk_size, page).collect(): - yield rows - page += 1 - - def first(self, verbose: bool = False) -> T_MetaInstance | None: - """ - Get the first row matching the currently built query. - - Also adds paginate, since it would be a waste to select more rows than needed. - """ - if row := self.paginate(page=1, limit=1, verbose=verbose).first(): - return self.model.from_row(row) - else: - return None - - def _first(self) -> str: - return self._paginate(page=1, limit=1) - - def first_or_fail(self, exception: typing.Optional[BaseException] = None, verbose: bool = False) -> T_MetaInstance: - """ - Call .first() and raise an error if nothing found. - - Basically unwraps Optional type. - """ - if inst := self.first(verbose=verbose): - return inst - - if not exception: - exception = ValueError("Nothing found!") - - raise exception - - -S = typing.TypeVar("S") - - -class PaginatedRows(TypedRows[T_MetaInstance]): - """ - Extension on top of rows that is used when calling .paginate() instead of .collect(). - """ - - _query_builder: QueryBuilder[T_MetaInstance] - - @property - def data(self) -> list[T_MetaInstance]: - """ - Get the underlying data. - """ - return list(self.records.values()) - - @property - def pagination(self) -> Pagination: - """ - Get all page info. 
-        """
-        pagination_data = self.metadata["pagination"]
-
-        has_next_page = pagination_data["current_page"] < pagination_data["max_page"]
-        has_prev_page = pagination_data["current_page"] > 1
-        return {
-            "total_items": pagination_data["rows"],
-            "current_page": pagination_data["current_page"],
-            "per_page": pagination_data["limit"],
-            "total_pages": pagination_data["max_page"],
-            "has_next_page": has_next_page,
-            "has_prev_page": has_prev_page,
-            "next_page": pagination_data["current_page"] + 1 if has_next_page else None,
-            "prev_page": pagination_data["current_page"] - 1 if has_prev_page else None,
-        }
-
-    def next(self) -> Self:
-        """
-        Get the next page.
-        """
-        data = self.metadata["pagination"]
-        if data["current_page"] >= data["max_page"]:
-            raise StopIteration("Final Page")
-
-        return self._query_builder.paginate(limit=data["limit"], page=data["current_page"] + 1)
-
-    def previous(self) -> Self:
-        """
-        Get the previous page.
-        """
-        data = self.metadata["pagination"]
-        if data["current_page"] <= 1:
-            raise StopIteration("First Page")
-
-        return self._query_builder.paginate(limit=data["limit"], page=data["current_page"] - 1)
-
-    def as_dict(self, *_: Any, **__: Any) -> PaginateDict:  # type: ignore
-        """
-        Convert to a dictionary with pagination info and original data.
-
-        All arguments are ignored!
-        """
-        return {"data": super().as_dict(), "pagination": self.pagination}
-
-
-class TypedSet(pydal.objects.Set):  # type: ignore # pragma: no cover
-    """
-    Used to make pydal Set more typed.
-
-    This class is not actually used, only 'cast' by TypeDAL.__call__
-    """
-
-    def count(self, distinct: typing.Optional[bool] = None, cache: AnyDict = None) -> int:
-        """
-        Count returns an int.
-        """
-        result = super().count(distinct, cache)
-        return typing.cast(int, result)
-
-    def select(self, *fields: Any, **attributes: Any) -> TypedRows[T_MetaInstance]:
-        """
-        Select returns a TypedRows of a user defined table.
-
-        Example:
-            result: TypedRows[MyTable] = db(MyTable.id > 0).select()
-
-            for row in result:
-                reveal_type(row)  # MyTable
-        """
-        rows = super().select(*fields, **attributes)
-        return typing.cast(TypedRows[T_MetaInstance], rows)
diff --git a/src/typedal/define.py b/src/typedal/define.py
new file mode 100644
index 0000000..42f0f54
--- /dev/null
+++ b/src/typedal/define.py
@@ -0,0 +1,188 @@
+"""
+Separates the table definition code from core DAL code.
+
+Since otherwise helper methods would clutter up the TypeDAL class.
+"""
+
+from __future__ import annotations
+
+import copy
+import types
+import typing as t
+import warnings
+
+import pydal
+
+from .constants import BASIC_MAPPINGS
+from .core import TypeDAL, evaluate_forward_reference, resolve_annotation
+from .fields import TypedField, is_typed_field
+from .helpers import (
+    all_annotations,
+    all_dict,
+    filter_out,
+    instanciate,
+    is_union,
+    origin_is_subclass,
+    to_snake,
+)
+from .relationships import Relationship, to_relationship
+from .tables import TypedTable
+from .types import (
+    Field,
+    T,
+    T_annotation,
+    Table,
+    _Types,
+)
+
+try:
+    # python 3.14+
+    from annotationlib import ForwardRef
+except ImportError:  # pragma: no cover
+    # python 3.13-
+    from typing import ForwardRef
+
+
+class TableDefinitionBuilder:
+    """Handles the conversion of TypedTable classes to pydal tables."""
+
+    def __init__(self, db: "TypeDAL"):
+        """
+        Before, the `class_map` was a singleton on the pydal class; now it's per database.
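+
+        Example (illustrative sketch; assumes a connected TypeDAL instance `db`
+        and a hypothetical `Author(TypedTable)` model):
+
+            builder = TableDefinitionBuilder(db)
+            builder.define(Author)       # defines the `author` table on db
+            builder.class_map["author"]  # -> Author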
+        """
+        self.db = db
+        self.class_map: dict[str, t.Type["TypedTable"]] = {}
+
+    def define(self, cls: t.Type[T], **kwargs: t.Any) -> t.Type[T]:
+        """Build and register a table from a TypedTable class."""
+        full_dict = all_dict(cls)
+        tablename = to_snake(cls.__name__)
+        annotations = all_annotations(cls)
+        annotations |= {k: t.cast(type, v) for k, v in full_dict.items() if is_typed_field(v)}
+        annotations = {k: v for k, v in annotations.items() if not k.startswith("_")}
+
+        typedfields: dict[str, TypedField[t.Any]] = {
+            k: instanciate(v, True) for k, v in annotations.items() if is_typed_field(v)
+        }
+
+        relationships: dict[str, type[Relationship[t.Any]]] = filter_out(annotations, Relationship)
+        fields = {fname: self.to_field(fname, ftype) for fname, ftype in annotations.items()}
+
+        other_kwargs = kwargs | {
+            k: v for k, v in cls.__dict__.items() if k not in annotations and not k.startswith("_")
+        }
+
+        for key, field in typedfields.items():
+            clone = copy.copy(field)
+            setattr(cls, key, clone)
+            typedfields[key] = clone
+
+        relationships = filter_out(full_dict, Relationship) | relationships | filter_out(other_kwargs, Relationship)
+
+        reference_field_keys = [
+            k for k, v in fields.items() if str(v.type).split(" ")[0] in ("list:reference", "reference")
+        ]
+
+        relationships |= {
+            k: new_relationship
+            for k in reference_field_keys
+            if k not in relationships and (new_relationship := to_relationship(cls, k, annotations[k]))
+        }
+
+        cache_dependency = self.db._config.caching and kwargs.pop("cache_dependency", True)
+        table: Table = self.db.define_table(tablename, *fields.values(), **kwargs)
+
+        for name, typed_field in typedfields.items():
+            field = fields[name]
+            typed_field.bind(field, table)
+
+        if issubclass(cls, TypedTable):
+            cls.__set_internals__(
+                db=self.db,
+                table=table,
+                relationships=t.cast(dict[str, Relationship[t.Any]], relationships),
+            )
+            self.class_map[str(table)] = cls
+            self.class_map[table._rname] = cls
+            cls.__on_define__(self.db)
+        else:
+            warnings.warn("db.define used without inheriting TypedTable. This could lead to strange problems!")
+
+        if not tablename.startswith("typedal_") and cache_dependency:
+            from .caching import _remove_cache
+
+            table._before_update.append(lambda s, _: _remove_cache(s, tablename))
+            table._before_delete.append(lambda s: _remove_cache(s, tablename))
+
+        return cls
+
+    def to_field(self, fname: str, ftype: type, **kw: t.Any) -> Field:
+        """Convert annotation to pydal Field."""
+        fname = to_snake(fname)
+        if converted_type := self.annotation_to_pydal_fieldtype(ftype, kw):
+            return self.build_field(fname, converted_type, **kw)
+        else:
+            raise NotImplementedError(f"Unsupported type {ftype}/{type(ftype)}")
+
+    def annotation_to_pydal_fieldtype(
+        self,
+        ftype_annotation: T_annotation,
+        mut_kw: t.MutableMapping[str, t.Any],
+    ) -> t.Optional[str]:
+        """Convert Python type annotation to pydal field type string."""
+        ftype = t.cast(type, ftype_annotation)  # cast from Type to type to make mypy happy
+
+        if isinstance(ftype, str):
+            # extract type from string
+            ftype = resolve_annotation(ftype)
+
+        if isinstance(ftype, ForwardRef):
+            known_classes = {table.__name__: table for table in self.class_map.values()}
+
+            ftype = evaluate_forward_reference(ftype, namespace=known_classes)
+
+        if mapping := BASIC_MAPPINGS.get(ftype):
+            # basic types
+            return mapping
+        elif isinstance(ftype, pydal.objects.Table):
+            # db.table
+            return f"reference {ftype._tablename}"
+        elif issubclass(type(ftype), type) and issubclass(ftype, TypedTable):
+            # SomeTable
+            snakename = to_snake(ftype.__name__)
+            return f"reference {snakename}"
+        elif isinstance(ftype, TypedField):
+            # FieldType(type, ...)
+            return ftype._to_field(mut_kw, self)
+        elif origin_is_subclass(ftype, TypedField):
+            # TypedField[int]
+            return self.annotation_to_pydal_fieldtype(t.get_args(ftype)[0], mut_kw)
+        elif isinstance(ftype, types.GenericAlias) and t.get_origin(ftype) in (list, TypedField):  # type: ignore
+            # list[str] -> str -> string -> list:string
+            _child_type = t.get_args(ftype)[0]
+            _child_type = self.annotation_to_pydal_fieldtype(_child_type, mut_kw)
+            return f"list:{_child_type}"
+        elif is_union(ftype):
+            # str | int -> UnionType
+            # typing.Union[str | int] -> typing._UnionGenericAlias
+
+            # Optional[type] == type | None
+
+            match t.get_args(ftype):
+                case (_child_type, _Types.NONETYPE) | (_Types.NONETYPE, _child_type):
+                    # good union of Nullable
+
+                    # if a field is optional, it is nullable:
+                    mut_kw["notnull"] = False
+                    return self.annotation_to_pydal_fieldtype(_child_type, mut_kw)
+                case _:
+                    # two types is not supported by the db!
+                    return None
+        else:
+            return None
+
+    @classmethod
+    def build_field(cls, name: str, field_type: str, **kw: t.Any) -> Field:
+        """Create a pydal Field with default kwargs."""
+        kw_combined = TypeDAL.default_kwargs | kw
+        return Field(name, field_type, **kw_combined)
diff --git a/src/typedal/fields.py b/src/typedal/fields.py
index ac268fd..c5a9691 100644
--- a/src/typedal/fields.py
+++ b/src/typedal/fields.py
@@ -2,27 +2,248 @@ This file contains available Field types.
""" +from __future__ import annotations + import ast +import contextlib import datetime as dt import decimal -import typing +import types +import typing as t import uuid +import pydal from pydal.helpers.classes import SQLCustomType from pydal.objects import Table -from typing_extensions import Unpack -from .core import TypeDAL, TypedField, TypedTable -from .types import FieldSettings +from .core import TypeDAL +from .types import ( + Expression, + Field, + FieldSettings, + Query, + T_annotation, + T_MetaInstance, + T_subclass, + T_Value, + Validator, +) -T = typing.TypeVar("T", bound=typing.Any) +if t.TYPE_CHECKING: + # will be imported for real later: + from .tables import TypedTable ## general +class TypedField(Expression, t.Generic[T_Value]): # pragma: no cover + """ + Typed version of pydal.Field, which will be converted to a normal Field in the background. + """ + + # will be set by .bind on db.define + name = "" + _db: t.Optional[pydal.DAL] = None + _rname: t.Optional[str] = None + _table: t.Optional[Table] = None + _field: t.Optional[Field] = None + + _type: T_annotation + kwargs: t.Any + + requires: Validator | t.Iterable[Validator] + + # NOTE: for the logic of converting a TypedField into a pydal Field, see TypeDAL._to_field + + def __init__( + self, + _type: t.Type[T_Value] | types.UnionType = str, # type: ignore + /, + **settings: t.Unpack[FieldSettings], + ) -> None: + """ + Typed version of pydal.Field, which will be converted to a normal Field in the background. + + Provide the Python type for this field as the first positional argument + and any other settings to Field() as keyword parameters. + """ + self._type = _type + self.kwargs = settings + # super().__init__() + + @t.overload + def __get__(self, instance: T_MetaInstance, owner: t.Type[T_MetaInstance]) -> T_Value: # pragma: no cover + """ + row.field -> (actual data). + """ + + @t.overload + def __get__(self, instance: None, owner: "t.Type[TypedTable]") -> "TypedField[T_Value]": # pragma: no cover + """ + Table.field -> Field. + """ + + def __get__( + self, + instance: T_MetaInstance | None, + owner: t.Type[T_MetaInstance], + ) -> t.Union[T_Value, "TypedField[T_Value]"]: + """ + Since this class is a Descriptor field, \ + it returns something else depending on if it's called on a class or instance. + + (this is mostly for mypy/typing) + """ + if instance: + # this is only reached in a very specific case: + # an instance of the object was created with a specific set of fields selected (excluding the current one) + # in that case, no value was stored in the owner -> return None (since the field was not selected) + return t.cast(T_Value, None) # cast as T_Value so mypy understands it for selected fields + else: + # getting as class -> return actual field so pydal understands it when using in query etc. + return t.cast(TypedField[T_Value], self._field) # pretend it's still typed for IDE support + + def __str__(self) -> str: + """ + String representation of a Typed Field. + + If `type` is set explicitly (e.g. TypedField(str, type="text")), that type is used: `TypedField.text`, + otherwise the type annotation is used (e.g. TypedField(str) -> TypedField.str) + """ + return str(self._field) if self._field else "" + + def __repr__(self) -> str: + """ + More detailed string representation of a Typed Field. + + Uses __str__ and adds the provided extra options (kwargs) in the representation. 
+ """ + string_value = self.__str__() + + if "type" in self.kwargs: + # manual type in kwargs supplied + typename = self.kwargs["type"] + elif issubclass(type, type(self._type)): + # normal type, str.__name__ = 'str' + typename = getattr(self._type, "__name__", str(self._type)) + elif t_args := t.get_args(self._type): + # list[str] -> 'str' + typename = t_args[0].__name__ + else: # pragma: no cover + # fallback - something else, may not even happen, I'm not sure + typename = self._type + + string_value = f"TypedField[{typename}].{string_value}" if string_value else f"TypedField[{typename}]" + + kw = self.kwargs.copy() + kw.pop("type", None) + return f"<{string_value} with options {kw}>" + + def _to_field(self, extra_kwargs: t.MutableMapping[str, t.Any], builder: TableDefinitionBuilder) -> t.Optional[str]: + """ + Convert a Typed Field instance to a pydal.Field. + + Actual logic in TypeDAL._to_field but this function creates the pydal type name and updates the kwarg settings. + """ + other_kwargs = self.kwargs.copy() + extra_kwargs.update(other_kwargs) # <- modifies and overwrites the default kwargs with user-specified ones + return extra_kwargs.pop("type", False) or builder.annotation_to_pydal_fieldtype( + self._type, + extra_kwargs, + ) + + def bind(self, field: pydal.objects.Field, table: pydal.objects.Table) -> None: + """ + Bind the right db/table/field info to this class, so queries can be made using `Class.field == ...`. + """ + self._table = table + self._field = field + + def __getattr__(self, key: str) -> t.Any: + """ + If the regular getattribute does not work, try to get info from the related Field. + """ + with contextlib.suppress(AttributeError): + return super().__getattribute__(key) + + # try on actual field: + return getattr(self._field, key) + + def __eq__(self, other: t.Any) -> Query: + """ + Performing == on a Field will result in a Query. + """ + return t.cast(Query, self._field == other) + + def __ne__(self, other: t.Any) -> Query: + """ + Performing != on a Field will result in a Query. + """ + return t.cast(Query, self._field != other) + + def __gt__(self, other: t.Any) -> Query: + """ + Performing > on a Field will result in a Query. + """ + return t.cast(Query, self._field > other) + + def __lt__(self, other: t.Any) -> Query: + """ + Performing < on a Field will result in a Query. + """ + return t.cast(Query, self._field < other) + + def __ge__(self, other: t.Any) -> Query: + """ + Performing >= on a Field will result in a Query. + """ + return t.cast(Query, self._field >= other) + + def __le__(self, other: t.Any) -> Query: + """ + Performing <= on a Field will result in a Query. + """ + return t.cast(Query, self._field <= other) + + def __hash__(self) -> int: + """ + Shadow Field.__hash__. + """ + return hash(self._field) + + def __invert__(self) -> Expression: + """ + Performing ~ on a Field will result in an Expression. + """ + if not self._field: # pragma: no cover + raise ValueError("Unbound Field can not be inverted!") + + return t.cast(Expression, ~self._field) + + def lower(self) -> Expression: + """ + For string-fields: compare lowercased values. + """ + if not self._field: # pragma: no cover + raise ValueError("Unbound Field can not be lowered!") + + return t.cast(Expression, self._field.lower()) + + +def is_typed_field(cls: t.Any) -> t.TypeGuard["TypedField[t.Any]"]: + """ + Is `cls` an instance or subclass of TypedField? 
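+
+    Example (illustrative):
+
+        is_typed_field(TypedField(str))  # True (instance)
+        is_typed_field(TypedField[int])  # True (generic alias)
+        is_typed_field(str)              # False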
+ + Deprecated + """ + return isinstance(cls, TypedField) or ( + isinstance(t.get_origin(cls), type) and issubclass(t.get_origin(cls), TypedField) + ) + + ## specific -def StringField(**kw: Unpack[FieldSettings]) -> TypedField[str]: +def StringField(**kw: t.Unpack[FieldSettings]) -> TypedField[str]: """ Pydal type is string, Python type is str. """ @@ -33,7 +254,7 @@ def StringField(**kw: Unpack[FieldSettings]) -> TypedField[str]: String = StringField -def TextField(**kw: Unpack[FieldSettings]) -> TypedField[str]: +def TextField(**kw: t.Unpack[FieldSettings]) -> TypedField[str]: """ Pydal type is text, Python type is str. """ @@ -44,7 +265,7 @@ def TextField(**kw: Unpack[FieldSettings]) -> TypedField[str]: Text = TextField -def BlobField(**kw: Unpack[FieldSettings]) -> TypedField[bytes]: +def BlobField(**kw: t.Unpack[FieldSettings]) -> TypedField[bytes]: """ Pydal type is blob, Python type is bytes. """ @@ -55,7 +276,7 @@ def BlobField(**kw: Unpack[FieldSettings]) -> TypedField[bytes]: Blob = BlobField -def BooleanField(**kw: Unpack[FieldSettings]) -> TypedField[bool]: +def BooleanField(**kw: t.Unpack[FieldSettings]) -> TypedField[bool]: """ Pydal type is boolean, Python type is bool. """ @@ -66,7 +287,7 @@ def BooleanField(**kw: Unpack[FieldSettings]) -> TypedField[bool]: Boolean = BooleanField -def IntegerField(**kw: Unpack[FieldSettings]) -> TypedField[int]: +def IntegerField(**kw: t.Unpack[FieldSettings]) -> TypedField[int]: """ Pydal type is integer, Python type is int. """ @@ -77,7 +298,7 @@ def IntegerField(**kw: Unpack[FieldSettings]) -> TypedField[int]: Integer = IntegerField -def DoubleField(**kw: Unpack[FieldSettings]) -> TypedField[float]: +def DoubleField(**kw: t.Unpack[FieldSettings]) -> TypedField[float]: """ Pydal type is double, Python type is float. """ @@ -88,7 +309,7 @@ def DoubleField(**kw: Unpack[FieldSettings]) -> TypedField[float]: Double = DoubleField -def DecimalField(n: int, m: int, **kw: Unpack[FieldSettings]) -> TypedField[decimal.Decimal]: +def DecimalField(n: int, m: int, **kw: t.Unpack[FieldSettings]) -> TypedField[decimal.Decimal]: """ Pydal type is decimal, Python type is Decimal. """ @@ -99,7 +320,7 @@ def DecimalField(n: int, m: int, **kw: Unpack[FieldSettings]) -> TypedField[deci Decimal = DecimalField -def DateField(**kw: Unpack[FieldSettings]) -> TypedField[dt.date]: +def DateField(**kw: t.Unpack[FieldSettings]) -> TypedField[dt.date]: """ Pydal type is date, Python type is datetime.date. """ @@ -110,7 +331,7 @@ def DateField(**kw: Unpack[FieldSettings]) -> TypedField[dt.date]: Date = DateField -def TimeField(**kw: Unpack[FieldSettings]) -> TypedField[dt.time]: +def TimeField(**kw: t.Unpack[FieldSettings]) -> TypedField[dt.time]: """ Pydal type is time, Python type is datetime.time. """ @@ -121,7 +342,7 @@ def TimeField(**kw: Unpack[FieldSettings]) -> TypedField[dt.time]: Time = TimeField -def DatetimeField(**kw: Unpack[FieldSettings]) -> TypedField[dt.datetime]: +def DatetimeField(**kw: t.Unpack[FieldSettings]) -> TypedField[dt.datetime]: """ Pydal type is datetime, Python type is datetime.datetime. """ @@ -132,7 +353,7 @@ def DatetimeField(**kw: Unpack[FieldSettings]) -> TypedField[dt.datetime]: Datetime = DatetimeField -def PasswordField(**kw: Unpack[FieldSettings]) -> TypedField[str]: +def PasswordField(**kw: t.Unpack[FieldSettings]) -> TypedField[str]: """ Pydal type is password, Python type is str. 
""" @@ -143,7 +364,7 @@ def PasswordField(**kw: Unpack[FieldSettings]) -> TypedField[str]: Password = PasswordField -def UploadField(**kw: Unpack[FieldSettings]) -> TypedField[str]: +def UploadField(**kw: t.Unpack[FieldSettings]) -> TypedField[str]: """ Pydal type is upload, Python type is str. """ @@ -153,11 +374,10 @@ def UploadField(**kw: Unpack[FieldSettings]) -> TypedField[str]: Upload = UploadField -T_subclass = typing.TypeVar("T_subclass", TypedTable, Table) - def ReferenceField( - other_table: str | typing.Type[TypedTable] | TypedTable | Table | T_subclass, **kw: Unpack[FieldSettings] + other_table: str | t.Type[TypedTable] | TypedTable | Table | T_subclass, + **kw: t.Unpack[FieldSettings], ) -> TypedField[int]: """ Pydal type is reference, Python type is int (id). @@ -180,7 +400,7 @@ def ReferenceField( Reference = ReferenceField -def ListStringField(**kw: Unpack[FieldSettings]) -> TypedField[list[str]]: +def ListStringField(**kw: t.Unpack[FieldSettings]) -> TypedField[list[str]]: """ Pydal type is list:string, Python type is list of str. """ @@ -191,7 +411,7 @@ def ListStringField(**kw: Unpack[FieldSettings]) -> TypedField[list[str]]: ListString = ListStringField -def ListIntegerField(**kw: Unpack[FieldSettings]) -> TypedField[list[int]]: +def ListIntegerField(**kw: t.Unpack[FieldSettings]) -> TypedField[list[int]]: """ Pydal type is list:integer, Python type is list of int. """ @@ -202,7 +422,7 @@ def ListIntegerField(**kw: Unpack[FieldSettings]) -> TypedField[list[int]]: ListInteger = ListIntegerField -def ListReferenceField(other_table: str, **kw: Unpack[FieldSettings]) -> TypedField[list[int]]: +def ListReferenceField(other_table: str, **kw: t.Unpack[FieldSettings]) -> TypedField[list[int]]: """ Pydal type is list:reference, Python type is list of int (id). """ @@ -213,7 +433,7 @@ def ListReferenceField(other_table: str, **kw: Unpack[FieldSettings]) -> TypedFi ListReference = ListReferenceField -def JSONField(**kw: Unpack[FieldSettings]) -> TypedField[object]: +def JSONField(**kw: t.Unpack[FieldSettings]) -> TypedField[object]: """ Pydal type is json, Python type is object (can be anything JSON-encodable). """ @@ -221,7 +441,7 @@ def JSONField(**kw: Unpack[FieldSettings]) -> TypedField[object]: return TypedField(object, **kw) -def BigintField(**kw: Unpack[FieldSettings]) -> TypedField[int]: +def BigintField(**kw: t.Unpack[FieldSettings]) -> TypedField[int]: """ Pydal type is bigint, Python type is int. """ @@ -241,7 +461,7 @@ def BigintField(**kw: Unpack[FieldSettings]) -> TypedField[int]: ) -def TimestampField(**kw: Unpack[FieldSettings]) -> TypedField[dt.datetime]: +def TimestampField(**kw: t.Unpack[FieldSettings]) -> TypedField[dt.datetime]: """ Database type is timestamp, Python type is datetime. @@ -275,7 +495,7 @@ def safe_decode_native_point(value: str | None) -> tuple[float, ...]: try: parsed = ast.literal_eval(value) - return typing.cast(tuple[float, ...], parsed) + return t.cast(tuple[float, ...], parsed) except ValueError: # pragma: no cover # should not happen when inserted with `safe_encode_native_point` but you never know return () @@ -328,7 +548,7 @@ def safe_encode_native_point(value: tuple[str, str] | tuple[float, float] | str) ) -def PointField(**kw: Unpack[FieldSettings]) -> TypedField[tuple[float, float]]: +def PointField(**kw: t.Unpack[FieldSettings]) -> TypedField[tuple[float, float]]: """ Database type is point, Python type is tuple[float, float]. 
""" @@ -344,9 +564,14 @@ def PointField(**kw: Unpack[FieldSettings]) -> TypedField[tuple[float, float]]: ) -def UUIDField(**kw: Unpack[FieldSettings]) -> TypedField[uuid.UUID]: +def UUIDField(**kw: t.Unpack[FieldSettings]) -> TypedField[uuid.UUID]: """ Database type is uuid, Python type is UUID. """ kw["type"] = NativeUUIDField return TypedField(uuid.UUID, **kw) + + +# note: import at the end to prevent circular imports: +from .define import TableDefinitionBuilder # noqa: E402 +from .tables import TypedTable # noqa: E402 diff --git a/src/typedal/for_web2py.py b/src/typedal/for_web2py.py index b72300b..0cd9fd8 100644 --- a/src/typedal/for_web2py.py +++ b/src/typedal/for_web2py.py @@ -6,7 +6,7 @@ from pydal.validators import IS_NOT_IN_DB -from .core import TypeDAL, TypedField, TypedTable +from . import TypeDAL, TypedField, TypedTable from .fields import TextField from .web2py_py4web_shared import AuthUser diff --git a/src/typedal/helpers.py b/src/typedal/helpers.py index e5bcf6d..5ba6031 100644 --- a/src/typedal/helpers.py +++ b/src/typedal/helpers.py @@ -7,19 +7,25 @@ import datetime as dt import fnmatch import io +import re +import sys import types -import typing +import typing as t from collections import ChainMap -from typing import Any from pydal import DAL -from .types import AnyDict, Expression, Field, Table +from .types import AnyDict, Expression, Field, Row, T, Table, Template # type: ignore -if typing.TYPE_CHECKING: - from . import TypeDAL, TypedField, TypedTable +try: + import annotationlib +except ImportError: # pragma: no cover + annotationlib = None + +if t.TYPE_CHECKING: + from string.templatelib import Interpolation -T = typing.TypeVar("T") + from . import TypeDAL, TypedField, TypedTable def is_union(some_type: type | types.UnionType) -> bool: @@ -27,19 +33,34 @@ def is_union(some_type: type | types.UnionType) -> bool: Check if a type is some type of Union. Args: - some_type: types.UnionType = type(int | str); typing.Union = typing.Union[int, str] + some_type: types.UnionType = type(int | str); t.Union = t.Union[int, str] """ - return typing.get_origin(some_type) in (types.UnionType, typing.Union) + return t.get_origin(some_type) in (types.UnionType, t.Union) -def reversed_mro(cls: type) -> typing.Iterable[type]: +def reversed_mro(cls: type) -> t.Iterable[type]: """ Get the Method Resolution Order (mro) for a class, in reverse order to be used with ChainMap. """ return reversed(getattr(cls, "__mro__", [])) +def _cls_annotations(c: type) -> dict[str, type]: # pragma: no cover + """ + Functions to get the annotations of a class (excl inherited, use _all_annotations for that). + + Uses `annotationlib` if available (since 3.14) and if so, resolves forward references immediately. 
+    """
+    if annotationlib:
+        return t.cast(
+            dict[str, type],
+            annotationlib.get_annotations(c, format=annotationlib.Format.VALUE, eval_str=True),
+        )
+    else:
+        return getattr(c, "__annotations__", {})
+
+
 def _all_annotations(cls: type) -> ChainMap[str, type]:
     """
     Returns a dictionary-like ChainMap that includes annotations for all \
@@ -47,7 +68,7 @@ def _all_annotations(cls: type) -> ChainMap[str, type]:
     """
     # chainmap reverses the iterable, so reverse again beforehand to keep order normally:
-    return ChainMap(*(c.__annotations__ for c in reversed_mro(cls) if "__annotations__" in c.__dict__))
+    return ChainMap(*(_cls_annotations(c) for c in reversed_mro(cls)))


 def all_dict(cls: type) -> AnyDict:
@@ -57,9 +78,9 @@ def all_dict(cls: type) -> AnyDict:
     return dict(ChainMap(*(c.__dict__ for c in reversed_mro(cls))))  # type: ignore


-def all_annotations(cls: type, _except: typing.Optional[typing.Iterable[str]] = None) -> dict[str, type]:
+def all_annotations(cls: type, _except: t.Optional[t.Iterable[str]] = None) -> dict[str, type]:
     """
-    Wrapper around `_all_annotations` that filters away any keys in _except.
+    Wrapper around `_all_annotations` that filters away any keys in _except.

     It also flattens the ChainMap to a regular dict.
     """
@@ -70,7 +91,7 @@
     return {k: v for k, v in _all.items() if k not in _except}


-def instanciate(cls: typing.Type[T] | T, with_args: bool = False) -> T:
+def instanciate(cls: t.Type[T] | T, with_args: bool = False) -> T:
     """
     Create an instance of T (if it is a class).

@@ -80,20 +101,20 @@
     If with_args: spread the generic args into the class creation
     (needed for e.g. TypedField(str), but not for list[str])
     """
-    if inner_cls := typing.get_origin(cls):
+    if inner_cls := t.get_origin(cls):
         if not with_args:
-            return typing.cast(T, inner_cls())
+            return t.cast(T, inner_cls())

-        args = typing.get_args(cls)
-        return typing.cast(T, inner_cls(*args))
+        args = t.get_args(cls)
+        return t.cast(T, inner_cls(*args))

     if isinstance(cls, type):
-        return typing.cast(T, cls())
+        return t.cast(T, cls())

     return cls


-def origin_is_subclass(obj: Any, _type: type) -> bool:
+def origin_is_subclass(obj: t.Any, _type: type) -> bool:
     """
     Check if the origin of a generic is a subclass of _type.

@@ -101,15 +122,13 @@
     origin_is_subclass(list[str], list) -> True
     """
     return bool(
-        typing.get_origin(obj)
-        and isinstance(typing.get_origin(obj), type)
-        and issubclass(typing.get_origin(obj), _type),
+        t.get_origin(obj) and isinstance(t.get_origin(obj), type) and issubclass(t.get_origin(obj), _type),
     )


 def mktable(
-    data: dict[Any, Any],
-    header: typing.Optional[typing.Iterable[str] | range] = None,
+    data: dict[t.Any, t.Any],
+    header: t.Optional[t.Iterable[str] | range] = None,
     skip_first: bool = True,
 ) -> str:
     """
@@ -154,11 +173,11 @@
     return output.getvalue()


-K = typing.TypeVar("K")
-V = typing.TypeVar("V")
+K = t.TypeVar("K")
+V = t.TypeVar("V")


-def looks_like(v: Any, _type: type[Any]) -> bool:
+def looks_like(v: t.Any, _type: type[t.Any]) -> bool:
     """
     Returns true if v or v's class is of type _type, including if it is a generic.
@@ -186,19 +205,19 @@ def unwrap_type(_type: type) -> type: Example: list[list[str]] -> str """ - while args := typing.get_args(_type): + while args := t.get_args(_type): _type = args[0] return _type -@typing.overload +@t.overload def extract_type_optional(annotation: T) -> tuple[T, bool]: """ T -> T is not exactly right because you'll get the inner type, but mypy seems happy with this. """ -@typing.overload +@t.overload def extract_type_optional(annotation: None) -> tuple[None, bool]: """ None leads to None, False. @@ -212,10 +231,10 @@ def extract_type_optional(annotation: T | None) -> tuple[T | None, bool]: if annotation is None: return None, False - if origin := typing.get_origin(annotation): - args = typing.get_args(annotation) + if origin := t.get_origin(annotation): + args = t.get_args(annotation) - if origin in (typing.Union, types.UnionType, typing.Optional) and args: + if origin in (t.Union, types.UnionType, t.Optional) and args: # remove None: return next(_ for _ in args if _ and _ != types.NoneType and not isinstance(_, types.NoneType)), True @@ -256,7 +275,7 @@ def __bool__(self) -> bool: return False -def as_lambda(value: T) -> typing.Callable[..., T]: +def as_lambda(value: T) -> t.Callable[..., T]: """ Wrap value in a callable. """ @@ -289,21 +308,21 @@ def get_db(table: "TypedTable | Table") -> "DAL": """ Get the underlying DAL instance for a pydal or typedal table. """ - return typing.cast("DAL", table._db) + return t.cast("DAL", table._db) def get_table(table: "TypedTable | Table") -> "Table": """ Get the underlying pydal table for a typedal table. """ - return typing.cast("Table", table._table) + return t.cast("Table", table._table) -def get_field(field: "TypedField[typing.Any] | Field") -> "Field": +def get_field(field: "TypedField[t.Any] | Field") -> "Field": """ Get the underlying pydal field from a typedal field. """ - return typing.cast( + return t.cast( "Field", field, # Table.field already is a Field, but cast to make sure the editor knows this too. ) @@ -314,7 +333,7 @@ class classproperty: Combination of @classmethod and @property. """ - def __init__(self, fget: typing.Callable[..., typing.Any]) -> None: + def __init__(self, fget: t.Callable[..., t.Any]) -> None: """ Initialize the classproperty. @@ -323,7 +342,7 @@ def __init__(self, fget: typing.Callable[..., typing.Any]) -> None: """ self.fget = fget - def __get__(self, obj: typing.Any, owner: typing.Type[T]) -> typing.Any: + def __get__(self, obj: t.Any, owner: t.Type[T]) -> t.Any: """ Retrieve the property value. @@ -337,7 +356,7 @@ def __get__(self, obj: typing.Any, owner: typing.Type[T]) -> typing.Any: return self.fget(owner) -def smarter_adapt(db: TypeDAL, placeholder: Any) -> str: +def smarter_adapt(db: TypeDAL, placeholder: t.Any) -> str: """ Smarter adaptation of placeholder to quote if needed. @@ -349,7 +368,7 @@ def smarter_adapt(db: TypeDAL, placeholder: Any) -> str: Quoted placeholder if needed, except for numbers (smart_adapt logic) or fields/tables (use already quoted rname). 
""" - return typing.cast( + return t.cast( str, getattr(placeholder, "sql_shortref", None) # for tables or getattr(placeholder, "sqlsafe", None) # for fields @@ -357,26 +376,123 @@ def smarter_adapt(db: TypeDAL, placeholder: Any) -> str: ) -def sql_escape(db: TypeDAL, sql_fragment: str, *raw_args: Any, **raw_kwargs: Any) -> str: +# https://docs.python.org/3.14/library/string.templatelib.html +SYSTEM_SUPPORTS_TEMPLATES = sys.version_info > (3, 14) + + +def process_tstring(template: Template, operation: t.Callable[["Interpolation"], str]) -> str: # pragma: no cover """ - Generates escaped SQL fragments with placeholders. + Process a Template string by applying an operation to each interpolation. + + This function iterates through a Template object, which contains both string literals + and Interpolation objects. String literals are preserved as-is, while Interpolation + objects are transformed using the provided operation function. Args: - db: Database object. - sql_fragment: SQL fragment with placeholders. - *raw_args: Positional arguments to be escaped. - **raw_kwargs: Keyword arguments to be escaped. + template: A Template object containing mixed string literals and Interpolation objects. + operation: A callable that takes an Interpolation object and returns a string. + This function will be applied to each interpolated value in the template. + + Returns: + str: The processed string with all interpolations replaced by the results of + applying the operation function. + + Example: + Basic f-string functionality can be implemented as: + + >>> def fstring_operation(interpolation): + ... return str(interpolation.value) + >>> value = "test" + >>> template = t"{value = }" # Template string literal + >>> result = process_tstring(template, fstring_operation) + >>> print(result) # "value = test" + + Note: + This is a generic template processor. The specific behavior depends entirely + on the operation function provided. + """ + return "".join(part if isinstance(part, str) else operation(part) for part in template) + + +def sql_escape_template(db: TypeDAL, sql_fragment: Template) -> str: # pragma: no cover + r""" + Safely escape a Template string for SQL execution using database-specific escaping. + + This function processes a Template string (t-string) by escaping all interpolated + values using the database adapter's escape mechanism, preventing SQL injection + attacks while maintaining the structure of the SQL query. + + Args: + db: TypeDAL database connection object that provides the adapter for escaping. + sql_fragment: A Template object (t-string) containing SQL with interpolated values. + The interpolated values will be automatically escaped. + + Returns: + str: SQL string with all interpolated values properly escaped for safe execution. + + Example: + >>> user_input = "'; DROP TABLE users; --" + >>> query = t"SELECT * FROM users WHERE name = {user_input}" + >>> safe_query = sql_escape_template(db, query) + >>> print(safe_query) # "SELECT * FROM users WHERE name = '\'; DROP TABLE users; --'" + + Security: + This function is essential for preventing SQL injection attacks when using + user-provided data in SQL queries. All interpolated values are escaped + according to the database adapter's rules. + + Note: + Only available in Python 3.14+ when SYSTEM_SUPPORTS_TEMPLATES is True. + For earlier Python versions, use sql_escape() with string formatting. 
+ """ + return process_tstring(sql_fragment, lambda part: smarter_adapt(db, part.value)) + + +def sql_escape(db: TypeDAL, sql_fragment: str | Template, *raw_args: t.Any, **raw_kwargs: t.Any) -> str: + """ + Generate escaped SQL fragments with safely substituted placeholders. + + This function provides secure SQL string construction by escaping all provided + arguments using the database adapter's escaping mechanism. It supports both + traditional string formatting (Python < 3.14) and Template strings (Python 3.14+). + + Args: + db: TypeDAL database connection object that provides the adapter for escaping. + sql_fragment: SQL fragment with placeholders (%s for positional, %(name)s for named). + In Python 3.14+, this can also be a Template (t-string) with + interpolated values that will be automatically escaped. + *raw_args: Positional arguments to be escaped and substituted for %s placeholders. + Only use with string fragments, not Template objects. + **raw_kwargs: Keyword arguments to be escaped and substituted for %(name)s placeholders. + Only use with string fragments, not Template objects. Returns: - Escaped SQL fragment with placeholders replaced with escaped values. + str: SQL fragment with all placeholders replaced by properly escaped values. Raises: - ValueError: If both args and kwargs are provided. + ValueError: If both positional and keyword arguments are provided simultaneously. + + Examples: + Positional arguments: + >>> safe_sql = sql_escape(db, "SELECT * FROM users WHERE id = %s", user_id) + + Keyword arguments: + >>> safe_sql = sql_escape(db, "SELECT * FROM users WHERE name = %(name)s", name=username) + + Template strings (Python 3.14+): + >>> safe_sql = sql_escape(db, t"SELECT * FROM users WHERE id = {user_id}") + + Security: + All arguments are escaped using the database adapter's escaping rules to prevent + SQL injection attacks. Never concatenate user input directly into SQL strings. """ if raw_args and raw_kwargs: # pragma: no cover raise ValueError("Please provide either args or kwargs, not both.") - elif raw_args: + if SYSTEM_SUPPORTS_TEMPLATES and isinstance(sql_fragment, Template): # pragma: no cover + return sql_escape_template(db, sql_fragment) + + if raw_args: # list return sql_fragment % tuple(smarter_adapt(db, placeholder) for placeholder in raw_args) else: @@ -386,23 +502,58 @@ def sql_escape(db: TypeDAL, sql_fragment: str, *raw_args: Any, **raw_kwargs: Any def sql_expression( db: TypeDAL, - sql_fragment: str, - *raw_args: Any, + sql_fragment: str | Template, + *raw_args: t.Any, output_type: str | None = None, - **raw_kwargs: Any, + **raw_kwargs: t.Any, ) -> Expression: """ - Creates a pydal Expression object representing a raw SQL fragment. + Create a PyDAL Expression object from a raw SQL fragment with safe parameter substitution. + + This function combines SQL escaping with PyDAL's Expression system, allowing you to + create database expressions from raw SQL while maintaining security through proper + parameter escaping. Args: - db: The TypeDAL object. - sql_fragment: The raw SQL fragment. - *raw_args: Arguments to be interpolated into the SQL fragment. - output_type: The expected output type of the expression. - **raw_kwargs: Keyword arguments to be interpolated into the SQL fragment. + db: The TypeDAL database connection object. + sql_fragment: Raw SQL fragment with placeholders (%s for positional, %(name)s for named). + In Python 3.14+, this can also be a Template (t-string) with + interpolated values that will be automatically escaped. 
+ *raw_args: Positional arguments to be escaped and interpolated into the SQL fragment. + Only use with string fragments, not Template objects. + output_type: Optional type hint for the expected output type of the expression. + This can help with query analysis and optimization. + **raw_kwargs: Keyword arguments to be escaped and interpolated into the SQL fragment. + Only use with string fragments, not Template objects. Returns: - A pydal Expression object. + Expression: A PyDAL Expression object wrapping the safely escaped SQL fragment. + + Examples: + Creating a complex WHERE clause: + >>> expr = sql_expression(db, + ... "age > %s AND status = %s", + ... 18, "active", + ... output_type="boolean") + >>> query = db(expr).select() + + Using keyword arguments: + >>> expr = sql_expression(db, + ... "EXTRACT(year FROM %(date_col)s) = %(year)s", + ... date_col="created_at", year=2023, + ... output_type="boolean") + + Template strings (Python 3.14+): + >>> min_age = 21 + >>> expr = sql_expression(db, t"age >= {min_age}", output_type="boolean") + + Security: + All parameters are escaped using sql_escape() before being wrapped in the Expression, + ensuring protection against SQL injection attacks. + + Note: + The returned Expression can be used anywhere PyDAL expects an expression, + such as in db().select(), .update(), or .delete() operations. """ safe_sql = sql_escape(db, sql_fragment, *raw_args, **raw_kwargs) @@ -413,3 +564,59 @@ def sql_expression( safe_sql, type=output_type, # optional type hint ) + + +def normalize_table_keys(row: Row, pattern: re.Pattern[str] = re.compile(r"^([a-zA-Z_]+)_(\d{5,})$")) -> Row: + """ + Normalize table keys in a PyDAL Row object by stripping numeric hash suffixes from table names, \ + only if the suffix is 5 or more digits. + + For example: + Row({'articles_12345': {...}}) -> Row({'articles': {...}}) + Row({'articles_123': {...}}) -> unchanged + + Returns: + Row: A new Row object with normalized keys. + """ + new_data: dict[str, t.Any] = {} + for key, value in row.items(): + if match := pattern.match(key): + base, _suffix = match.groups() + normalized_key = base + new_data[normalized_key] = value + else: + new_data[key] = value + return Row(new_data) + + +def default_representer(field: TypedField[T], value: T, table: t.Type[TypedTable]) -> str: + """ + Simply call field.represent on the value. + """ + if represent := getattr(field, "represent", None): + return str(represent(value, table)) + else: + return repr(value) + + +def throw(exc: BaseException) -> t.Never: + """ + Raise the given exception. + + This function provides a functional way to raise exceptions, allowing + exception raising to be used in expressions where a statement wouldn't work. + + Args: + exc: The exception to be raised. + + Returns: + Never returns normally as an exception is always raised. + + Raises: + BaseException: Always raises the provided exception. 
+ + Examples: + >>> value = get_value() or throw(ValueError("No value available")) + >>> result = data.get('key') if data else throw(KeyError("Missing data")) + """ + raise exc diff --git a/src/typedal/mixins.py b/src/typedal/mixins.py index c98246e..ebf6000 100644 --- a/src/typedal/mixins.py +++ b/src/typedal/mixins.py @@ -5,26 +5,22 @@ """ import base64 +import datetime as dt import os -import typing +import typing as t import warnings -from datetime import datetime -from typing import Any, Optional from pydal import DAL from pydal.validators import IS_NOT_IN_DB, ValidationError from slugify import slugify -from .core import ( # noqa F401 - used by example in docstring - QueryBuilder, - T_MetaInstance, - TableMeta, - TypeDAL, - TypedTable, - _TypedTable, -) +from .core import TypeDAL from .fields import DatetimeField, StringField -from .types import OpRow, Set +from .tables import _TypedTable +from .types import OpRow, Set, T_MetaInstance + +if t.TYPE_CHECKING: + from .tables import TypedTable # noqa: F401 class Mixin(_TypedTable): @@ -38,9 +34,9 @@ class Mixin(_TypedTable): ('inconsistent method resolution' or 'metaclass conflicts') """ - __settings__: typing.ClassVar[dict[str, Any]] + __settings__: t.ClassVar[dict[str, t.Any]] - def __init_subclass__(cls, **kwargs: Any): + def __init_subclass__(cls, **kwargs: t.Any): """ Ensures __settings__ exists for other mixins. """ @@ -52,8 +48,8 @@ class TimestampsMixin(Mixin): A Mixin class for adding timestamp fields to a model. """ - created_at = DatetimeField(default=datetime.now, writable=False) - updated_at = DatetimeField(default=datetime.now, writable=False) + created_at = DatetimeField(default=dt.datetime.now, writable=False) + updated_at = DatetimeField(default=dt.datetime.now, writable=False) @classmethod def __on_define__(cls, db: TypeDAL) -> None: @@ -73,7 +69,7 @@ def set_updated_at(_: Set, row: OpRow) -> None: _: Set: Unused parameter. row (OpRow): The row to update. """ - row["updated_at"] = datetime.now() + row["updated_at"] = dt.datetime.now() cls._before_update.append(set_updated_at) @@ -89,7 +85,7 @@ def slug_random_suffix(length: int = 8) -> str: return base64.urlsafe_b64encode(os.urandom(length)).rstrip(b"=").decode().strip("=") -T = typing.TypeVar("T") +T = t.TypeVar("T") # noinspection PyPep8Naming @@ -112,7 +108,7 @@ def __init__( """ super().__init__(db, field, error_message) - def validate(self, original: T, record_id: Optional[int] = None) -> T: + def validate(self, original: T, record_id: t.Optional[int] = None) -> T: """ Performs checks to see if the slug already exists for a different row. """ @@ -154,7 +150,7 @@ class SlugMixin(Mixin): # pub: slug = StringField(unique=True, writable=False) # priv: - __settings__: typing.TypedDict( # type: ignore + __settings__: t.TypedDict( # type: ignore "SlugFieldSettings", { "slug_field": str, @@ -164,10 +160,10 @@ class SlugMixin(Mixin): def __init_subclass__( cls, - slug_field: typing.Optional[str] = None, + slug_field: t.Optional[str] = None, slug_suffix_length: int = 0, - slug_suffix: Optional[int] = None, - **kw: Any, + slug_suffix: t.Optional[int] = None, + **kw: t.Any, ) -> None: """ Bind 'slug field' option to be used later (on_define). 
@@ -235,7 +231,7 @@ def __on_define__(cls, db: TypeDAL) -> None:
         slug_field.requires = current_requires

     @classmethod
-    def from_slug(cls: typing.Type[T_MetaInstance], slug: str, join: bool = True) -> Optional[T_MetaInstance]:
+    def from_slug(cls: t.Type[T_MetaInstance], slug: str, join: bool = True) -> t.Optional[T_MetaInstance]:
         """
         Find a row by its slug.
         """
@@ -246,7 +242,7 @@ def from_slug(cls: typing.Type[T_MetaInstance], slug: str, join: bool = True) ->
         return builder.first()

     @classmethod
-    def from_slug_or_fail(cls: typing.Type[T_MetaInstance], slug: str, join: bool = True) -> T_MetaInstance:
+    def from_slug_or_fail(cls: t.Type[T_MetaInstance], slug: str, join: bool = True) -> T_MetaInstance:
         """
         Find a row by its slug, or raise an error if it doesn't exist.
         """
diff --git a/src/typedal/query_builder.py b/src/typedal/query_builder.py
new file mode 100644
index 0000000..ea22730
--- /dev/null
+++ b/src/typedal/query_builder.py
@@ -0,0 +1,1059 @@
+"""
+Contains base functionality related to the Query Builder.
+"""
+
+from __future__ import annotations
+
+import datetime as dt
+import math
+import typing as t
+from collections import defaultdict
+
+import pydal.objects
+
+from .constants import DEFAULT_JOIN_OPTION, JOIN_OPTIONS
+from .core import TypeDAL
+from .fields import TypedField, is_typed_field
+from .helpers import DummyQuery, as_lambda, looks_like, normalize_table_keys, throw
+from .tables import TypedTable
+from .types import (
+    CacheMetadata,
+    Condition,
+    Expression,
+    Field,
+    Metadata,
+    OnQuery,
+    OrderBy,
+    Query,
+    Rows,
+    SelectKwargs,
+    T,
+    T_MetaInstance,
+)
+
+
+class QueryBuilder(t.Generic[T_MetaInstance]):
+    """
+    Abstraction on top of pydal's query system.
+    """
+
+    model: t.Type[T_MetaInstance]
+    query: Query
+    select_args: list[t.Any]
+    select_kwargs: SelectKwargs
+    relationships: dict[str, Relationship[t.Any]]
+    metadata: Metadata
+
+    def __init__(
+        self,
+        model: t.Type[T_MetaInstance],
+        add_query: t.Optional[Query] = None,
+        select_args: t.Optional[list[t.Any]] = None,
+        select_kwargs: t.Optional[SelectKwargs] = None,
+        relationships: dict[str, Relationship[t.Any]] = None,
+        metadata: Metadata = None,
+    ):
+        """
+        Normally, you wouldn't manually initialize a QueryBuilder but start using a method on a TypedTable.
+
+        Example:
+            MyTable.where(...) -> QueryBuilder[MyTable]
+        """
+        self.model = model
+        table = model._ensure_table_defined()
+        default_query = t.cast(Query, table.id > 0)
+        self.query = add_query or default_query
+        self.select_args = select_args or []
+        self.select_kwargs = select_kwargs or {}
+        self.relationships = relationships or {}
+        self.metadata = metadata or {}
+
+    def __str__(self) -> str:
+        """
+        Simple string representation for the query builder.
+        """
+        return f"QueryBuilder for {self.model}"
+
+    def __repr__(self) -> str:
+        """
+        Advanced string representation for the query builder.
+        """
+        return (
+            f"<QueryBuilder for {self.model} with query {self.query}; "
+            f"select args {self.select_args} and kwargs {self.select_kwargs}; "
+            f"relationships {self.relationships}; metadata {self.metadata}>"
+        )
+
+    def __bool__(self) -> bool:
+        """
+        QueryBuilder is truthy if it has any conditions.
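+
+        Example (illustrative; `MyTable` is a hypothetical model):
+
+            bool(QueryBuilder(MyTable))          # False: nothing added yet
+            bool(MyTable.where(MyTable.id > 5))  # True: a condition was added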
+        """
+        table = self.model._ensure_table_defined()
+        default_query = t.cast(Query, table.id > 0)
+        return any(
+            [
+                self.query != default_query,
+                self.select_args,
+                self.select_kwargs,
+                self.relationships,
+                self.metadata,
+            ],
+        )
+
+    def _extend(
+        self,
+        add_query: t.Optional[Query] = None,
+        overwrite_query: t.Optional[Query] = None,
+        select_args: t.Optional[list[t.Any]] = None,
+        select_kwargs: t.Optional[SelectKwargs] = None,
+        relationships: dict[str, Relationship[t.Any]] = None,
+        metadata: Metadata = None,
+    ) -> "QueryBuilder[T_MetaInstance]":
+        return QueryBuilder(
+            self.model,
+            (add_query & self.query) if add_query else overwrite_query or self.query,
+            (self.select_args + select_args) if select_args else self.select_args,
+            (self.select_kwargs | select_kwargs) if select_kwargs else self.select_kwargs,
+            (self.relationships | relationships) if relationships else self.relationships,
+            (self.metadata | (metadata or {})) if metadata else self.metadata,
+        )
+
+    def select(self, *fields: t.Any, **options: t.Unpack[SelectKwargs]) -> "QueryBuilder[T_MetaInstance]":
+        """
+        Fields: database columns by name ('id'), by field reference (table.id) or other (e.g. table.ALL).
+
+        Options:
+            paraphrased from the web2py pydal docs,
+            For more info, see http://www.web2py.com/books/default/chapter/29/06/the-database-abstraction-layer#orderby-groupby-limitby-distinct-having-orderby_on_limitby-join-left-cache
+
+            orderby: field(s) to order by. Supported:
+                table.name - sort by name, ascending
+                ~table.name - sort by name, descending
+                '<random>' - sort randomly
+                table.name|table.id - sort by two fields (first name, then id)
+
+            groupby, having: together with orderby:
+                groupby can be a field (e.g. table.name) to group records by
+                having can be a query, only those `having` the condition are grouped
+
+            limitby: tuple of min and max. When using the query builder, .paginate(limit, page) is recommended.
+            distinct: bool/field. Only select rows that differ
+            orderby_on_limitby (bool, default: True): by default, an implicit orderby is added when doing limitby.
+            join: othertable.on(query) - do an INNER JOIN. Using TypeDAL relationships with .join() is recommended!
+            left: othertable.on(query) - do a LEFT JOIN. Using TypeDAL relationships with .join() is recommended!
+            cache: cache the query result to speed up repeated queries; e.g. (cache=(cache.ram, 3600), cacheable=True)
+        """
+        return self._extend(select_args=list(fields), select_kwargs=options)
+
+    def orderby(self, *fields: OrderBy) -> "QueryBuilder[T_MetaInstance]":
+        """
+        Order the query results by specified fields.
+
+        Args:
+            fields: field(s) to order by. Supported:
+                table.name - sort by name, ascending
+                ~table.name - sort by name, descending
+                '<random>' - sort randomly
+                table.name|table.id - sort by two fields (first name, then id)
+
+        Returns:
+            QueryBuilder: A new QueryBuilder instance with the ordering applied.
+        """
+        return self.select(orderby=fields)
+
+    def where(
+        self,
+        *queries_or_lambdas: Query | t.Callable[[t.Type[T_MetaInstance]], Query] | dict[str, t.Any],
+        **filters: t.Any,
+    ) -> "QueryBuilder[T_MetaInstance]":
+        """
+        Extend the builder's query.
+
+        Can be used in multiple ways:
+        .where(Query) -> with a direct query such as `Table.id == 5`
+        .where(lambda table: table.id == 5) -> with a query via a lambda
+        .where(id=5) -> via keyword arguments
+
+        When using multiple where's, they will be ANDed:
+        .where(lambda table: table.id == 5).where(lambda table: table.id == 6) == (table.id == 5) & (table.id == 6)
+        When passing multiple queries to a single .where, they will be ORed:
+        .where(lambda table: table.id == 5, lambda table: table.id == 6) == (table.id == 5) | (table.id == 6)
+        """
+        new_query = self.query
+        table = self.model._ensure_table_defined()
+
+        queries_or_lambdas = (
+            *queries_or_lambdas,
+            filters,
+        )
+
+        subquery = t.cast(Query, DummyQuery())
+        for query_part in queries_or_lambdas:
+            if isinstance(query_part, (Field, pydal.objects.Field)) or is_typed_field(query_part):
+                subquery |= t.cast(Query, query_part != None)
+            elif isinstance(query_part, (pydal.objects.Query, Expression, pydal.objects.Expression)):
+                subquery |= t.cast(Query, query_part)
+            elif callable(query_part):
+                if result := query_part(self.model):
+                    subquery |= result
+            elif isinstance(query_part, dict):
+                subsubquery = DummyQuery()
+                for field, value in query_part.items():
+                    subsubquery &= table[field] == value
+                if subsubquery:
+                    subquery |= subsubquery
+            else:
+                raise ValueError(f"Unexpected query type ({type(query_part)}).")
+
+        if subquery:
+            new_query &= subquery
+
+        return self._extend(overwrite_query=new_query)
+
+    def _parse_relationships(
+        self, fields: t.Iterable[str | t.Type[TypedTable]], method: JOIN_OPTIONS = None, **update: t.Any
+    ) -> dict[str, Relationship[t.Any]]:
+        """
+        Parse relationship fields into a dict of base relationships with nested relationships.
+
+        Args:
+            fields: Iterable of relationship field names
+                (e.g., ['relationship', 'relationship.with_nested', 'relationship.no2'])
+            condition_and: Optional condition to pass to relationship clones
+
+        Returns:
+            Dict mapping base relationship names to Relationship objects with nested relationships
+            Example: {'relationship': Relationship('relationship',
+                                                   nested={'with_nested': Relationship(),
+                                                           'no2': Relationship()})}
+        """
+        relationships: dict[str, Relationship[t.Any]] = {}
+        base_relationships = self.model.get_relationships()
+        db = self._get_db()
+
+        for field in fields:
+            relation_name = str(field)
+            parts = relation_name.split(".")
+            base_name = parts[0]
+
+            # Create base relationship if it doesn't exist
+            if base_name not in relationships:
+                relationships[base_name] = base_relationships[base_name].clone(join=method, **update)
+
+            # If this is a nested relationship, traverse and add it
+            if len(parts) > 1:
+                current = relationships[base_name]
+
+                for level in parts[1:]:
+                    # Check if this nested relationship already exists
+                    if level not in current.nested:
+                        # Create new nested relationship
+                        subrelationship = current.get_table(db).get_relationships()[level].clone(join=method)
+                        current.nested[level] = subrelationship
+
+                    current = current.nested[level]
+
+        return relationships
+
+    def join(
+        self,
+        *fields: str | t.Type[TypedTable],
+        method: JOIN_OPTIONS = None,
+        on: OnQuery | list[Expression] | Expression = None,
+        condition: Condition = None,
+        condition_and: Condition = None,
+    ) -> "QueryBuilder[T_MetaInstance]":
+        """
+        Include relationship fields in the result.
+
+        `fields` can be names of Relationships on the current model.
+        If no fields are passed, all will be used.
+
+        By default, the `method` defined in the relationship is used.
+        This can be overwritten with the `method` keyword argument (left or inner)
+
+        `condition_and` can be used to add extra conditions to an inner join.
+        """
+        # todo: allow limiting amount of related rows returned for join?
+        # todo: it would be nice if 'fields' could be an actual relationship
+        #  (Article.tags = list[Tag]) and you could change the .condition and .on
+        #  this could deprecate condition_and
+        relationships = self.model.get_relationships()
+
+        if condition and on:
+            raise ValueError("condition and on can not be used together!")
+        elif condition:
+            if len(fields) != 1:
+                raise ValueError("join(field, condition=...) can only be used with exactly one field!")
+
+            if isinstance(condition, pydal.objects.Query):
+                condition = as_lambda(condition)
+
+            to_field = t.cast(t.Type[TypedTable], fields[0])
+            relationships = {
+                str(to_field): Relationship(to_field, condition=condition, join=method, condition_and=condition_and)
+            }
+        elif on:
+            if len(fields) != 1:
+                raise ValueError("join(field, on=...) can only be used with exactly one field!")
+
+            if isinstance(on, pydal.objects.Expression):
+                on = [on]
+
+            if isinstance(on, list):
+                on = as_lambda(on)
+
+            to_field = t.cast(t.Type[TypedTable], fields[0])
+            relationships = {str(to_field): Relationship(to_field, on=on, join=method, condition_and=condition_and)}
+
+        else:
+            if fields:
+                # join on every relationship
+                # simple: 'relationship'
+                #   -> {'relationship': Relationship('relationship')}
+                # complex with one: relationship.with_nested
+                #   -> {'relationship': Relationship('relationship', nested=[Relationship('with_nested')])
+                # complex with two: relationship.with_nested, relationship.no2
+                #   -> {'relationship': Relationship('relationship',
+                #                                    nested=[Relationship('with_nested'), Relationship('no2')])
+
+                relationships = self._parse_relationships(fields, method=method, condition_and=condition_and)
+
+            if method:
+                relationships = {
+                    str(k): r.clone(join=method, condition_and=condition_and) for k, r in relationships.items()
+                }
+
+        return self._extend(relationships=relationships)
+
+    def cache(
+        self,
+        *deps: t.Any,
+        expires_at: t.Optional[dt.datetime] = None,
+        ttl: t.Optional[int | dt.timedelta] = None,
+    ) -> "QueryBuilder[T_MetaInstance]":
+        """
+        Enable caching for this query to load repeated calls from a dill row \
+        instead of executing the sql and collecting matching rows again.
+        """
+        existing = self.metadata.get("cache", {})
+
+        metadata: Metadata = {}
+
+        cache_meta = t.cast(
+            CacheMetadata,
+            self.metadata.get("cache", {})
+            | {
+                "enabled": True,
+                "depends_on": existing.get("depends_on", []) + [str(_) for _ in deps],
+                "expires_at": get_expire(expires_at=expires_at, ttl=ttl),
+            },
+        )
+
+        metadata["cache"] = cache_meta
+        return self._extend(metadata=metadata)
+
+    def _get_db(self) -> TypeDAL:
+        return self.model._db or throw(EnvironmentError("@define or db.define is not called on this class yet!"))
+
+    def _select_arg_convert(self, arg: t.Any) -> t.Any:
+        # typedfields are not really used at runtime anymore, but leave this in for safety:
+        if isinstance(arg, TypedField):  # pragma: no cover
+            arg = arg._field
+
+        return arg
+
+    def delete(self) -> list[int]:
+        """
+        Based on the current query, delete rows and return a list of deleted IDs.
+        """
+        db = self._get_db()
+        removed_ids = [_.id for _ in db(self.query).select("id")]
+        if db(self.query).delete():
+            # success!
+ return removed_ids
+
+ return []
+
+ def _delete(self) -> str:
+ db = self._get_db()
+ return str(db(self.query)._delete())
+
+ def update(self, **fields: t.Any) -> list[int]:
+ """
+ Based on the current query, update `fields` and return a list of updated IDs.
+ """
+ # todo: limit?
+ db = self._get_db()
+ updated_ids = db(self.query).select("id").column("id")
+ if db(self.query).update(**fields):
+ # success!
+ return updated_ids
+
+ return []
+
+ def _update(self, **fields: t.Any) -> str:
+ db = self._get_db()
+ return str(db(self.query)._update(**fields))
+
+ def _before_query(self, mut_metadata: Metadata, add_id: bool = True) -> tuple[Query, list[t.Any], SelectKwargs]:
+ select_args = [self._select_arg_convert(_) for _ in self.select_args] or [self.model.ALL]
+ select_kwargs = self.select_kwargs.copy()
+ query = self.query
+ model = self.model
+ mut_metadata["query"] = query
+ # require at least id of main table:
+ select_fields = ", ".join([str(_) for _ in select_args])
+ tablename = str(model)
+
+ if add_id and f"{tablename}.id" not in select_fields:
+ # other fields were selected, but the required id of the main table is missing:
+ select_args.append(model.id)
+
+ if self.relationships:
+ query, select_args = self._handle_relationships_pre_select(query, select_args, select_kwargs, mut_metadata)
+
+ return query, select_args, select_kwargs
+
+ def to_sql(self, add_id: bool = False) -> str:
+ """
+ Generate the SQL for the built query.
+ """
+ db = self._get_db()
+
+ query, select_args, select_kwargs = self._before_query({}, add_id=add_id)
+
+ return str(db(query)._select(*select_args, **select_kwargs))
+
+ def _collect(self) -> str:
+ """
+ Alias for to_sql, pydal-like syntax.
+ """
+ return self.to_sql()
+
+ def _collect_cached(self, metadata: Metadata) -> "TypedRows[T_MetaInstance] | None":
+ expires_at = metadata["cache"].get("expires_at")
+ metadata["cache"] |= {
+ # key is partly dependent on cache metadata but not these:
+ "key": None,
+ "status": None,
+ "cached_at": None,
+ "expires_at": None,
+ }
+
+ _, key = create_and_hash_cache_key(
+ self.model,
+ metadata,
+ self.query,
+ self.select_args,
+ self.select_kwargs,
+ self.relationships.keys(),
+ )
+
+ # re-set after creating key:
+ metadata["cache"]["expires_at"] = expires_at
+ metadata["cache"]["key"] = key
+
+ return load_from_cache(key, self._get_db())
+
+ def execute(self, add_id: bool = False) -> Rows:
+ """
+ Raw version of .collect which only executes the SQL, without performing any magic afterwards.
+ """
+ db = self._get_db()
+ metadata = t.cast(Metadata, self.metadata.copy())
+
+ query, select_args, select_kwargs = self._before_query(metadata, add_id=add_id)
+
+ return db(query).select(*select_args, **select_kwargs)
+
+ def collect(
+ self,
+ verbose: bool = False,
+ _to: t.Type["TypedRows[t.Any]"] = None,
+ add_id: bool = True,
+ ) -> "TypedRows[T_MetaInstance]":
+ """
+ Execute the built query and turn it into model instances, while handling relationships.
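+
+ Example (a minimal usage sketch; the `User` model and its 'posts' relationship are hypothetical):
+ ```
+ rows = User.where(lambda u: u.id > 0).join("posts").collect()
+ for user in rows:
+ print(user.id, user.posts)
+ ```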
+ """ + if _to is None: + _to = TypedRows + + db = self._get_db() + metadata = t.cast(Metadata, self.metadata.copy()) + + if metadata.get("cache", {}).get("enabled") and (result := self._collect_cached(metadata)): + return result + + query, select_args, select_kwargs = self._before_query(metadata, add_id=add_id) + + metadata["sql"] = db(query)._select(*select_args, **select_kwargs) + + if verbose: # pragma: no cover + print(metadata["sql"]) + + rows: Rows = db(query).select(*select_args, **select_kwargs) + + metadata["final_query"] = str(query) + metadata["final_args"] = [str(_) for _ in select_args] + metadata["final_kwargs"] = select_kwargs + + if verbose: # pragma: no cover + print(rows) + + if not self.relationships: + # easy + typed_rows = _to.from_rows(rows, self.model, metadata=metadata) + + else: + # harder: try to match rows to the belonging objects + # assume structure of {'table': } per row. + # if that's not the case, return default behavior again + typed_rows = self._collect_with_relationships(rows, metadata=metadata, _to=_to) + + # only saves if requested in metadata: + return save_to_cache(typed_rows, rows) + + @t.overload + def column(self, field: TypedField[T], **options: t.Unpack[SelectKwargs]) -> list[T]: + """ + If a typedfield is passed, the output type can be safely determined. + """ + + @t.overload + def column(self, field: T, **options: t.Unpack[SelectKwargs]) -> list[T]: + """ + Otherwise, the output type is loosely determined (assumes `field: type` or t.Any). + """ + + def column(self, field: TypedField[T] | T, **options: t.Unpack[SelectKwargs]) -> list[T]: + """ + Get all values in a specific column. + + Shortcut for `.select(field).execute().column(field)`. + """ + return self.select(field, **options).execute().column(field) + + def _handle_relationships_pre_select( + self, + query: Query, + select_args: list[t.Any], + select_kwargs: SelectKwargs, + metadata: Metadata, + ) -> tuple[Query, list[t.Any]]: + """Handle relationship joins and field selection for database query.""" + # Collect all relationship keys including nested ones + metadata["relationships"] = self._collect_all_relationship_keys() + + # Build joins and apply limitby optimization if needed + inner_joins = self._build_inner_joins() + query = self._apply_limitby_optimization(query, select_kwargs, inner_joins, metadata) + + if inner_joins: + select_kwargs["join"] = inner_joins + + # Build left joins and handle field selection + left_joins: list[Expression] = [] + select_args = self._build_left_joins_and_fields(select_args, left_joins) + + select_kwargs["left"] = left_joins + return query, select_args + + def _collect_all_relationship_keys(self) -> set[str]: + """Collect all relationship keys including nested ones.""" + keys = set(self.relationships.keys()) + + for relation in self.relationships.values(): + keys.update(self._collect_nested_keys(relation)) + + return keys + + def _collect_nested_keys(self, relation: Relationship[t.Any], prefix: str = "") -> set[str]: + """Recursively collect nested relationship keys.""" + keys = set() + + for name, nested in relation.nested.items(): + nested_key = f"{prefix}.{name}" if prefix else name + keys.add(nested_key) + keys.update(self._collect_nested_keys(nested, nested_key)) + + return keys + + def _build_inner_joins(self) -> list[t.Any]: + """Build inner joins for relationships with conditions.""" + joins = [] + + for key, relation in self.relationships.items(): + joins.extend(self._build_inner_joins_recursive(relation, self.model, key)) + + return joins + + 
def _build_inner_joins_recursive( + self, relation: Relationship[t.Any], parent_table: t.Type[TypedTable], key: str, parent_key: str = "" + ) -> list[t.Any]: + """Recursively build inner joins for a relationship and its nested relationships.""" + db = self._get_db() + joins = [] + + # Handle current level + if relation.condition and relation.join == "inner": + other = relation.get_table(db) + other = other.with_alias(f"{key}_{hash(relation)}") + condition = relation.condition(parent_table, other) + + if callable(relation.condition_and): + condition &= relation.condition_and(parent_table, other) + + joins.append(other.on(condition)) + + # Process nested relationships + for nested_name, nested in relation.nested.items(): + # todo: add additional test, deduplicate + nested_key = f"{parent_key}.{nested_name}" if parent_key else f"{key}.{nested_name}" + joins.extend(self._build_inner_joins_recursive(nested, other, nested_name, nested_key)) + + return joins + + def _apply_limitby_optimization( + self, + query: Query, + select_kwargs: SelectKwargs, + joins: list[t.Any], + metadata: Metadata, + ) -> Query: + """Apply limitby optimization when relationships are present.""" + if not (limitby := select_kwargs.pop("limitby", ())): + return query + + db = self._get_db() + model = self.model + + kwargs: SelectKwargs = select_kwargs.copy() + kwargs["limitby"] = limitby + + if joins: + kwargs["join"] = joins + + ids = db(query)._select(model.id, **kwargs) + query = model.id.belongs(ids) + metadata["ids"] = ids + + return query + + def _build_left_joins_and_fields(self, select_args: list[t.Any], left_joins: list[Expression]) -> list[t.Any]: + """ + Build left joins and ensure required fields are selected. + """ + for key, relation in self.relationships.items(): + select_args = self._process_relationship_for_left_join(relation, key, select_args, left_joins, self.model) + + return select_args + + def _process_relationship_for_left_join( + self, + relation: Relationship[t.Any], + key: str, + select_args: list[t.Any], + left_joins: list[Expression], + parent_table: t.Type[TypedTable], + parent_key: str = "", + ) -> list[t.Any]: + """Process a single relationship for left join and field selection.""" + db = self._get_db() + other = relation.get_table(db) + method: JOIN_OPTIONS = relation.join or DEFAULT_JOIN_OPTION + + select_fields = ", ".join([str(_) for _ in select_args]) + pre_alias = str(other) + + # Ensure required fields are selected + select_args = self._ensure_relationship_fields(select_args, other, select_fields) + + # Build join condition + if relation.on: + # Custom .on condition - always left join + on = relation.on(parent_table, other) + if not isinstance(on, list): + on = [on] + + on = [_ for _ in on if isinstance(_, pydal.objects.Expression)] + left_joins.extend(on) + elif method == "left": + # Generate left join condition + other = other.with_alias(f"{key}_{hash(relation)}") + condition = t.cast(Query, relation.condition(parent_table, other)) + + if callable(relation.condition_and): + condition &= relation.condition_and(parent_table, other) + + left_joins.append(other.on(condition)) + else: + # Inner join (handled in _build_inner_joins) + other = other.with_alias(f"{key}_{hash(relation)}") + + # Handle aliasing in select_args + select_args = self._update_select_args_with_alias(select_args, pre_alias, other) + + # Process nested relationships + for nested_name, nested in relation.nested.items(): + # todo: add additional test, deduplicate + nested_key = f"{parent_key}.{nested_name}" if parent_key 
else f"{key}.{nested_name}" + select_args = self._process_relationship_for_left_join( + nested, nested_name, select_args, left_joins, other, nested_key + ) + + return select_args + + def _ensure_relationship_fields( + self, select_args: list[t.Any], other: t.Type[TypedTable], select_fields: str + ) -> list[t.Any]: + """Ensure required fields from relationship table are selected.""" + if f"{other}." not in select_fields: + # No fields of other selected, add .ALL + select_args.append(other.ALL) + elif f"{other}.id" not in select_fields: + # Fields of other selected, but required ID is missing + select_args.append(other.id) + + return select_args + + def _update_select_args_with_alias( + self, select_args: list[t.Any], pre_alias: str, other: t.Type[TypedTable] + ) -> list[t.Any]: + """Update select_args to use aliased table names.""" + post_alias = str(other).split(" AS ")[-1] + + if pre_alias != post_alias: + select_fields = ", ".join([str(_) for _ in select_args]) + select_fields = select_fields.replace(f"{pre_alias}.", f"{post_alias}.") + select_args = select_fields.split(", ") + + return select_args + + def _collect_with_relationships( + self, + rows: Rows, + metadata: Metadata, + _to: t.Type["TypedRows[T_MetaInstance]"], + ) -> "TypedRows[T_MetaInstance]": + """ + Transform the raw rows into Typed Table model instances with nested relationships. + """ + db = self._get_db() + main_table = self.model._ensure_table_defined() + + # id: Model + records: dict[t.Any, T_MetaInstance] = {} + + # id: [Row] + raw_per_id: dict[t.Any, list[t.Any]] = defaultdict(list) + + # Track what we've seen: main_id -> "column-relation_id" + seen_relations: dict[str, set[str]] = defaultdict(set) + + for row in rows: + main = row[main_table] + main_id = main.id + + raw_per_id[main_id].append(normalize_table_keys(row)) + + if main_id not in records: + records[main_id] = self.model(main) + records[main_id]._with = list(self.relationships.keys()) + + # Setup all relationship defaults (once) + for col, relationship in self.relationships.items(): + records[main_id][col] = [] if relationship.multiple else None + + # Process each top-level relationship + for column, relation in self.relationships.items(): + self._process_relationship_data( + row=row, + column=column, + relation=relation, + parent_record=records[main_id], + parent_id=main_id, + seen_relations=seen_relations, + db=db, + ) + + return _to(rows, self.model, records, metadata=metadata, raw=raw_per_id) + + def _process_relationship_data( + self, + row: t.Any, + column: str, + relation: Relationship[t.Any], + parent_record: t.Any, + parent_id: t.Any, + seen_relations: dict[str, set[str]], + db: t.Any, + path: str = "", + ) -> t.Any | None: + """ + Process relationship data from a row and attach it to the parent record. + + Returns the created instance (for nested processing). 
+
+ Args:
+ row: The database row containing relationship data
+ column: The relationship column name
+ relation: The Relationship object
+ parent_record: The parent model instance to attach data to
+ parent_id: ID of the parent for tracking
+ seen_relations: Dict tracking which relationships we've already processed
+ db: Database instance
+ path: Current relationship path (e.g., "users.bestie")
+
+ Returns:
+ The created relationship instance, or None if skipped
+ """
+ # Build the full path for tracking (e.g., "users", "users.bestie", "users.bestie.articles")
+ current_path = f"{path}.{column}" if path else column
+
+ # Get the relationship column name (with hash for alias)
+ relationship_column = f"{column}_{hash(relation)}"
+
+ # Get relation data from row
+ relation_data = row[relationship_column] if relationship_column in row else row.get(relation.get_table_name())
+
+ # Skip if no data or NULL id
+ if not relation_data or relation_data.id is None:
+ return None
+
+ # Check if we've already seen this relationship instance
+ seen_key = f"{current_path}-{relation_data.id}"
+ if seen_key in seen_relations[parent_id]:
+ return None # Already processed
+
+ seen_relations[parent_id].add(seen_key)
+
+ # Create the relationship instance
+ relation_table = relation.get_table(db)
+ instance = relation_table(relation_data) if looks_like(relation_table, TypedTable) else relation_data
+
+ # Process nested relationships on this instance
+ if relation.nested:
+ self._process_nested_relationships(
+ row=row,
+ relation=relation,
+ instance=instance,
+ parent_id=parent_id,
+ seen_relations=seen_relations,
+ db=db,
+ path=current_path,
+ )
+
+ # Attach to parent
+ if relation.multiple:
+ # current_value = parent_record.get(column)
+ # if not isinstance(current_value, list):
+ # setattr(parent_record, column, [])
+ parent_record[column].append(instance)
+ else:
+ parent_record[column] = instance
+
+ return instance
+
+ def _process_nested_relationships(
+ self,
+ row: t.Any,
+ relation: Relationship[t.Any],
+ instance: t.Any,
+ parent_id: t.Any,
+ seen_relations: dict[str, set[str]],
+ db: t.Any,
+ path: str,
+ ) -> None:
+ """
+ Process all nested relationships for a given instance.
+
+ Args:
+ row: The database row containing relationship data
+ relation: The parent Relationship object containing nested relationships
+ instance: The instance to attach nested data to
+ parent_id: ID of the root parent for tracking
+ seen_relations: Dict tracking which relationships we've already processed
+ db: Database instance
+ path: Current relationship path
+ """
+ # Initialize nested relationship defaults on the instance
+ # Use __dict__ to avoid triggering __get__ descriptors
+ for nested_col, nested_relation in relation.nested.items():
+ if nested_col not in instance.__dict__:
+ instance.__dict__[nested_col] = [] if nested_relation.multiple else None
+
+ # Process each nested relationship
+ for nested_col, nested_relation in relation.nested.items():
+ self._process_relationship_data(
+ row=row,
+ column=nested_col,
+ relation=nested_relation,
+ parent_record=instance,
+ parent_id=parent_id,
+ seen_relations=seen_relations,
+ db=db,
+ path=path,
+ )
+
+ def collect_or_fail(self, exception: t.Optional[Exception] = None) -> "TypedRows[T_MetaInstance]":
+ """
+ Call .collect() and raise an error if nothing found.
+
+ Basically unwraps the Optional type.
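+
+ Example (a minimal sketch; `User` is a hypothetical model):
+ ```
+ rows = User.where(lambda u: u.id > 0).collect_or_fail() # raises ValueError if no rows match
+ ```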
+ """ + return self.collect() or throw(exception or ValueError("Nothing found!")) + + def __iter__(self) -> t.Generator[T_MetaInstance, None, None]: + """ + You can start iterating a Query Builder object before calling collect, for ease of use. + """ + yield from self.collect() + + def __count(self, db: TypeDAL, distinct: t.Optional[bool] = None) -> Query: + # internal, shared logic between .count and ._count + model = self.model + query = self.query + for key, relation in self.relationships.items(): + if (not relation.condition or relation.join != "inner") and not distinct: + continue + + other = relation.get_table(db) + if not distinct: + # todo: can this lead to other issues? + other = other.with_alias(f"{key}_{hash(relation)}") + query &= relation.condition(model, other) + + return query + + def count(self, distinct: t.Optional[bool] = None) -> int: + """ + Return the amount of rows matching the current query. + """ + db = self._get_db() + query = self.__count(db, distinct=distinct) + + return db(query).count(distinct) + + def _count(self, distinct: t.Optional[bool] = None) -> str: + """ + Return the SQL for .count(). + """ + db = self._get_db() + query = self.__count(db, distinct=distinct) + + return t.cast(str, db(query)._count(distinct)) + + def exists(self) -> bool: + """ + Determines if t.Any records exist matching the current query. + + Returns True if one or more records exist; otherwise, False. + + Returns: + bool: A boolean indicating whether t.Any records exist. + """ + return bool(self.count()) + + def __paginate( + self, + limit: int, + page: int = 1, + ) -> "QueryBuilder[T_MetaInstance]": + available = self.count() + + _from = limit * (page - 1) + _to = (limit * page) if limit else available + + metadata: Metadata = {} + + metadata["pagination"] = { + "limit": limit, + "current_page": page, + "max_page": math.ceil(available / limit) if limit else 1, + "rows": available, + "min_max": (_from, _to), + } + + return self._extend(select_kwargs={"limitby": (_from, _to)}, metadata=metadata) + + def paginate(self, limit: int, page: int = 1, verbose: bool = False) -> "PaginatedRows[T_MetaInstance]": + """ + Paginate transforms the more readable `page` and `limit` to pydals internal limit and offset. + + Note: when using relationships, this limit is only applied to the 'main' table and t.Any number of extra rows \ + can be loaded with relationship data! + """ + builder = self.__paginate(limit, page) + + rows = t.cast(PaginatedRows[T_MetaInstance], builder.collect(verbose=verbose, _to=PaginatedRows)) + + rows._query_builder = builder + return rows + + def _paginate( + self, + limit: int, + page: int = 1, + ) -> str: + builder = self.__paginate(limit, page) + return builder._collect() + + def chunk(self, chunk_size: int) -> t.Generator["TypedRows[T_MetaInstance]", t.Any, None]: + """ + Generator that yields rows from a paginated source in chunks. + + This function retrieves rows from a paginated data source in chunks of the + specified `chunk_size` and yields them as TypedRows. + + Example: + ``` + for chunk_of_rows in Table.where(SomeTable.id > 5).chunk(100): + for row in chunk_of_rows: + # Process each row within the chunk. + pass + ``` + """ + page = 1 + + while rows := self.__paginate(chunk_size, page).collect(): + yield rows + page += 1 + + def first(self, verbose: bool = False) -> T_MetaInstance | None: + """ + Get the first row matching the currently built query. + + Also adds paginate, since it would be a waste to select more rows than needed. 
+ """ + if row := self.paginate(page=1, limit=1, verbose=verbose).first(): + return self.model.from_row(row) + else: + return None + + def _first(self) -> str: + return self._paginate(page=1, limit=1) + + def first_or_fail(self, exception: t.Optional[BaseException] = None, verbose: bool = False) -> T_MetaInstance: + """ + Call .first() and raise an error if nothing found. + + Basically unwraps t.Optional type. + """ + return self.first(verbose=verbose) or throw(exception or ValueError("Nothing found!")) + + +# note: these imports exist at the bottom of this file to prevent circular import issues: + +from .caching import ( # noqa: E402 + create_and_hash_cache_key, + get_expire, + load_from_cache, + save_to_cache, +) +from .relationships import Relationship # noqa: E402 +from .rows import PaginatedRows, TypedRows # noqa: E402 diff --git a/src/typedal/relationships.py b/src/typedal/relationships.py new file mode 100644 index 0000000..81d0151 --- /dev/null +++ b/src/typedal/relationships.py @@ -0,0 +1,264 @@ +""" +Contains base functionality related to Relationships. +""" + +import inspect +import typing as t +import warnings + +import pydal.objects + +from .constants import JOIN_OPTIONS +from .core import TypeDAL +from .fields import TypedField +from .helpers import extract_type_optional, looks_like, unwrap_type +from .types import Condition, OnQuery, T_Field + +To_Type = t.TypeVar("To_Type") + + +class Relationship(t.Generic[To_Type]): + """ + Define a relationship to another table. + """ + + _type: t.Type[To_Type] + table: t.Type["TypedTable"] | type | str + condition: Condition + condition_and: Condition + on: OnQuery + multiple: bool + join: JOIN_OPTIONS + nested: dict[str, t.Self] + + def __init__( + self, + _type: t.Type[To_Type], + condition: Condition = None, + join: JOIN_OPTIONS = None, + on: OnQuery = None, + condition_and: Condition = None, + nested: dict[str, t.Self] = None, + ): + """ + Should not be called directly, use relationship() instead! + """ + if condition and on: + warnings.warn(f"Relation | Both specified! {condition=} {on=} {_type=}") + raise ValueError("Please specify either a condition or an 'on' statement for this relationship!") + + self._type = _type + self.condition = condition + self.join = "left" if on else join # .on is always left join! + self.on = on + self.condition_and = condition_and + + if args := t.get_args(_type): + self.table = unwrap_type(args[0]) + self.multiple = True + else: + self.table = t.cast(type[TypedTable], _type) + self.multiple = False + + if isinstance(self.table, str): + self.table = TypeDAL.to_snake(self.table) + + self.nested = nested or {} + + def clone(self, **update: t.Any) -> "Relationship[To_Type]": + """ + Create a copy of the relationship, possibly updated. + """ + return self.__class__( + update.get("_type") or self._type, + update.get("condition") or self.condition, + update.get("join") or self.join, + update.get("on") or self.on, + update.get("condition_and") or self.condition_and, + (self.nested | extra) if (extra := update.get("nested")) else self.nested, # type: ignore + ) + + def __repr__(self) -> str: + """ + Representation of the relationship. 
+ """ + if callback := self.condition or self.on: + src_code = inspect.getsource(callback).strip() + + if c_and := self.condition_and: + and_code = inspect.getsource(c_and).strip() + src_code += " AND " + and_code + else: + cls_name = self._type if isinstance(self._type, str) else self._type.__name__ + src_code = f"to {cls_name} (missing condition)" + + join = f":{self.join}" if self.join else "" + return f"" + + def get_table(self, db: "TypeDAL") -> t.Type["TypedTable"]: + """ + Get the table this relationship is bound to. + """ + table = self.table # can be a string because db wasn't available yet + + if isinstance(table, str): + if mapped := db._class_map.get(table): + # yay + return mapped + + # boo, fall back to untyped table but pretend it is typed: + return t.cast(t.Type["TypedTable"], db[table]) # eh close enough! + + return table + + def get_table_name(self) -> str: + """ + Get the name of the table this relationship is bound to. + """ + if isinstance(self.table, str): + return self.table + + if isinstance(self.table, pydal.objects.Table): + return str(self.table) + + # else: typed table + try: + table = self.table._ensure_table_defined() if issubclass(self.table, TypedTable) else self.table + except Exception: # pragma: no cover + table = self.table + + return str(table) + + def __get__(self, instance: t.Any, owner: t.Any) -> "t.Optional[list[t.Any]] | Relationship[To_Type]": + """ + Relationship is a descriptor class, which can be returned from a class but not an instance. + + For an instance, using .join() will replace the Relationship with the actual data. + If you forgot to join, a warning will be shown and empty data will be returned. + """ + if not instance: + # relationship queried on class, that's allowed + return self + + warnings.warn( + "Trying to get data from a relationship object! Did you forget to join it?", + category=RuntimeWarning, + ) + if self.multiple: + return [] + else: + return None + + +def relationship( + _type: t.Type[To_Type], + condition: Condition = None, + join: JOIN_OPTIONS = None, + on: OnQuery = None, +) -> To_Type: + """ + Define a relationship to another table, when its id is not stored in the current table. + + Example: + class User(TypedTable): + name: str + + posts = relationship(list["Post"], condition=lambda self, post: self.id == post.author, join='left') + + class Post(TypedTable): + title: str + author: User + + User.join("posts").first() # User instance with list[Post] in .posts + + Here, Post stores the User ID, but `relationship(list["Post"])` still allows you to get the user's posts. + In this case, the join strategy is set to LEFT so users without posts are also still selected. + + For complex queries with a pivot table, a `on` can be set insteaad of `condition`: + class User(TypedTable): + ... + + tags = relationship(list["Tag"], on=lambda self, tag: [ + Tagged.on(Tagged.entity == entity.gid), + Tag.on((Tagged.tag == tag.id)), + ]) + + If you'd try to capture this in a single 'condition', pydal would create a cross join which is much less efficient. + """ + return t.cast( + # note: The descriptor `Relationship[To_Type]` is more correct, but pycharm doesn't really get that. + # so for ease of use, just cast to the refered type for now! + # e.g. 
+ To_Type,
+ Relationship(_type, condition, join, on),
+ )
+
+
+def _generate_relationship_condition(_: t.Type["TypedTable"], key: str, field: T_Field) -> Condition:
+ origin = t.get_origin(field)
+ # else: generic
+
+ if origin is list:
+ # field = typing.get_args(field)[0] # actual field
+ # return lambda _self, _other: cls[key].contains(field)
+
+ return lambda _self, _other: _self[key].contains(_other.id)
+ else:
+ # normal reference
+ # return lambda _self, _other: cls[key] == field.id
+ return lambda _self, _other: _self[key] == _other.id
+
+
+def to_relationship(
+ cls: t.Type["TypedTable"] | type[t.Any],
+ key: str,
+ field: T_Field,
+) -> t.Optional[Relationship[t.Any]]:
+ """
+ Used to automatically create a relationship instance for reference fields.
+
+ Example:
+ class MyTable(TypedTable):
+ reference: OtherTable
+
+ `reference` contains the id of an Other Table row.
+ MyTable.relationships should have 'reference' as a relationship, so `MyTable.join('reference')` should work.
+
+ This function will automatically perform this logic (called in db.define):
+ to_relationship(MyTable, 'reference', OtherTable) -> Relationship[OtherTable]
+
+ Also works for list:reference (list[OtherTable]) and TypedField[OtherTable].
+ """
+ if looks_like(field, TypedField):
+ # typing.get_args works for list[str] but not for TypedField[role] :(
+ if args := t.get_args(field):
+ # TypedField[SomeType] -> SomeType
+ field = args[0]
+ elif hasattr(field, "_type"):
+ # TypedField(SomeType) -> SomeType
+ field = t.cast(T_Field, field._type)
+ else: # pragma: no cover
+ # weird
+ return None
+
+ field, optional = extract_type_optional(field)
+
+ try:
+ condition = _generate_relationship_condition(cls, key, field)
+ except Exception as e: # pragma: no cover
+ warnings.warn("Could not generate Relationship condition", source=e)
+ condition = None
+
+ if not condition: # pragma: no cover
+ # something went wrong, not a valid relationship
+ warnings.warn(f"Invalid relationship for {cls.__name__}.{key}: {field}")
+ return None
+
+ join = "left" if optional or t.get_origin(field) is list else "inner"
+
+ return Relationship(t.cast(type[TypedTable], field), condition, t.cast(JOIN_OPTIONS, join))
+
+
+# note: these imports exist at the bottom of this file to prevent circular import issues:
+
+from .tables import TypedTable # noqa: E402
diff --git a/src/typedal/rows.py b/src/typedal/rows.py
new file mode 100644
index 0000000..6590f91
--- /dev/null
+++ b/src/typedal/rows.py
@@ -0,0 +1,524 @@
+"""
+Contains base functionality related to Rows (raw result of a database query).
+"""
+
+from __future__ import annotations
+
+import csv
+import json
+import typing as t
+
+import pydal.objects
+
+from .core import TypeDAL
+from .helpers import mktable
+from .query_builder import QueryBuilder
+from .serializers import as_json
+from .tables import TypedTable
+from .types import (
+ AnyDict,
+ Field,
+ Metadata,
+ PaginateDict,
+ Pagination,
+ Query,
+ Row,
+ Rows,
+ T,
+ T_MetaInstance,
+)
+
+
+class TypedRows(t.Collection[T_MetaInstance], Rows):
+ """
+ Slightly enhanced and typed functionality on top of pydal Rows (the result of a select).
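+
+ Example (a minimal sketch; `User` is a hypothetical model):
+ ```
+ rows: TypedRows[User] = User.collect()
+ first = rows.first() # User | None
+ data = rows.as_dict() # {id: {...}, ...}
+ ```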
+ """ + + records: dict[int, T_MetaInstance] + # _rows: Rows + model: t.Type[T_MetaInstance] + metadata: Metadata + + # pseudo-properties: actually stored in _rows + db: TypeDAL + colnames: list[str] + fields: list[Field] + colnames_fields: list[Field] + response: list[tuple[t.Any, ...]] + + def __init__( + self, + rows: Rows, + model: t.Type[T_MetaInstance], + records: dict[int, T_MetaInstance] = None, + metadata: Metadata = None, + raw: dict[int, list[Row]] = None, + ) -> None: + """ + Should not be called manually! + + Normally, the `records` from an existing `Rows` object are used + but these can be overwritten with a `records` dict. + `metadata` can be t.Any (un)structured data + `model` is a Typed Table class + """ + + def _get_id(row: Row) -> int: + """ + Try to find the id field in a row. + + If _extra exists, the row changes: + + """ + if idx := getattr(row, "id", None): + return t.cast(int, idx) + elif main := getattr(row, str(model), None): + return t.cast(int, main.id) + else: # pragma: no cover + raise NotImplementedError(f"`id` could not be found for {row}") + + records = records or {_get_id(row): model(row) for row in rows} + raw = raw or {} + + for idx, entity in records.items(): + entity._rows = tuple(raw.get(idx, [])) + + super().__init__(rows.db, records, rows.colnames, rows.compact, rows.response, rows.fields) + self.model = model + self.metadata = metadata or {} + self.colnames = rows.colnames + + def __len__(self) -> int: + """ + Return the count of rows. + """ + return len(self.records) + + def __iter__(self) -> t.Iterator[T_MetaInstance]: + """ + Loop through the rows. + """ + yield from self.records.values() + + def __contains__(self, ind: t.Any) -> bool: + """ + Check if an id exists in this result set. + """ + return ind in self.records + + def first(self) -> T_MetaInstance | None: + """ + Get the row with the lowest id. + """ + if not self.records: + return None + + return next(iter(self)) + + def last(self) -> T_MetaInstance | None: + """ + Get the row with the highest id. + """ + if not self.records: + return None + + max_id = max(self.records.keys()) + return self[max_id] + + def find( + self, + f: t.Callable[[T_MetaInstance], Query], + limitby: tuple[int, int] = None, + ) -> "TypedRows[T_MetaInstance]": + """ + Returns a new Rows object, a subset of the original object, filtered by the function `f`. + """ + if not self.records: + return self.__class__(self, self.model, {}) + + records = {} + if limitby: + _min, _max = limitby + else: + _min, _max = 0, len(self) + count = 0 + for i, row in self.records.items(): + if f(row): + if _min <= count: + records[i] = row + count += 1 + if count == _max: + break + + return self.__class__(self, self.model, records) + + def exclude(self, f: t.Callable[[T_MetaInstance], Query]) -> "TypedRows[T_MetaInstance]": + """ + Removes elements from the calling Rows object, filtered by the function `f`, \ + and returns a new Rows object containing the removed elements. + """ + if not self.records: + return self.__class__(self, self.model, {}) + removed = {} + to_remove = [] + for i in self.records: + row = self[i] + if f(row): + removed[i] = self.records[i] + to_remove.append(i) + + [self.records.pop(i) for i in to_remove] + + return self.__class__( + self, + self.model, + removed, + ) + + def sort(self, f: t.Callable[[T_MetaInstance], t.Any], reverse: bool = False) -> list[T_MetaInstance]: + """ + Returns a list of sorted elements (not sorted in place). 
+ """ + return [r for (r, s) in sorted(zip(self.records.values(), self), key=lambda r: f(r[1]), reverse=reverse)] + + def __str__(self) -> str: + """ + Simple string representation. + """ + return f"" + + def __repr__(self) -> str: + """ + Print a table on repr(). + """ + data = self.as_dict() + try: + headers = list(next(iter(data.values())).keys()) + except StopIteration: + headers = [] + + return mktable(data, headers) + + def group_by_value( + self, + *fields: "str | Field | TypedField[T]", + one_result: bool = False, + **kwargs: t.Any, + ) -> dict[T, list[T_MetaInstance]]: + """ + Group the rows by a specific field (which will be the dict key). + """ + kwargs["one_result"] = one_result + result = super().group_by_value(*fields, **kwargs) + return t.cast(dict[T, list[T_MetaInstance]], result) + + def as_csv(self) -> str: + """ + Dump the data to csv. + """ + return t.cast(str, super().as_csv()) + + def as_dict( + self, + key: str | Field | None = None, + compact: bool = False, + storage_to_dict: bool = False, + datetime_to_str: bool = False, + custom_types: list[type] | None = None, + ) -> dict[int, AnyDict]: + """ + Get the data in a dict of dicts. + """ + if any([key, compact, storage_to_dict, datetime_to_str, custom_types]): + # functionality not guaranteed + if isinstance(key, Field): + key = key.name + + return t.cast( + dict[int, AnyDict], + super().as_dict( + key or "id", + compact, + storage_to_dict, + datetime_to_str, + custom_types, + ), + ) + + return {k: v.as_dict() for k, v in self.records.items()} + + def as_json( + self, default: t.Callable[[t.Any], t.Any] = None, indent: t.Optional[int] = None, **kwargs: t.Any + ) -> str: + """ + Turn the data into a dict and then dump to JSON. + """ + data = self.as_list() + + return as_json.encode(data, default=default, indent=indent, **kwargs) + + def json(self, default: t.Callable[[t.Any], t.Any] = None, indent: t.Optional[int] = None, **kwargs: t.Any) -> str: + """ + Turn the data into a dict and then dump to JSON. + """ + return self.as_json(default=default, indent=indent, **kwargs) + + def as_list( + self, + compact: bool = False, + storage_to_dict: bool = False, + datetime_to_str: bool = False, + custom_types: list[type] = None, + ) -> list[AnyDict]: + """ + Get the data in a list of dicts. + """ + if any([compact, storage_to_dict, datetime_to_str, custom_types]): + return t.cast(list[AnyDict], super().as_list(compact, storage_to_dict, datetime_to_str, custom_types)) + + return [_.as_dict() for _ in self.records.values()] + + def __getitem__(self, item: int) -> T_MetaInstance: + """ + You can get a specific row by ID from a typedrows by using rows[idx] notation. + + Since pydal's implementation differs (they expect a list instead of a dict with id keys), + using rows[0] will return the first row, regardless of its id. + """ + try: + return self.records[item] + except KeyError as e: + if item == 0 and (row := self.first()): + # special case: pydal internals think Rows.records is a list, not a dict + return row + + raise e + + def get(self, item: int) -> t.Optional[T_MetaInstance]: + """ + Get a row by ID, or receive None if it isn't in this result set. + """ + return self.records.get(item) + + def update(self, **new_values: t.Any) -> bool: + """ + Update the current rows in the database with new_values. + """ + # cast to make mypy understand .id is a TypedField and not an int! 
+ table = t.cast(t.Type[TypedTable], self.model._ensure_table_defined())
+
+ ids = set(self.column("id"))
+ query = table.id.belongs(ids)
+ return bool(self.db(query).update(**new_values))
+
+ def delete(self) -> bool:
+ """
+ Delete the currently selected rows from the database.
+ """
+ # cast to make mypy understand .id is a TypedField and not an int!
+ table = t.cast(t.Type[TypedTable], self.model._ensure_table_defined())
+
+ ids = set(self.column("id"))
+ query = table.id.belongs(ids)
+ return bool(self.db(query).delete())
+
+ def join(
+ self,
+ field: "Field | TypedField[t.Any]",
+ name: str = None,
+ constraint: Query = None,
+ fields: list[str | Field] = None,
+ orderby: t.Optional[str | Field] = None,
+ ) -> T_MetaInstance:
+ """
+ This can be used to JOIN with some relationships after the initial select.
+
+ Using the QueryBuilder's .join() method is preferred!
+ """
+ result = super().join(field, name, constraint, fields or [], orderby)
+ return t.cast(T_MetaInstance, result)
+
+ def export_to_csv_file(
+ self,
+ ofile: t.TextIO,
+ null: t.Any = "",
+ delimiter: str = ",",
+ quotechar: str = '"',
+ quoting: int = csv.QUOTE_MINIMAL,
+ represent: bool = False,
+ colnames: list[str] = None,
+ write_colnames: bool = True,
+ *args: t.Any,
+ **kwargs: t.Any,
+ ) -> None:
+ """
+ Shadow export_to_csv_file from Rows, but with typing.
+
+ See http://web2py.com/books/default/chapter/29/06/the-database-abstraction-layer?search=export_to_csv_file#Exporting-and-importing-data
+ """
+ super().export_to_csv_file(
+ ofile,
+ null,
+ *args,
+ delimiter=delimiter,
+ quotechar=quotechar,
+ quoting=quoting,
+ represent=represent,
+ colnames=colnames or self.colnames,
+ write_colnames=write_colnames,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_rows(
+ cls,
+ rows: Rows,
+ model: t.Type[T_MetaInstance],
+ metadata: Metadata = None,
+ ) -> "TypedRows[T_MetaInstance]":
+ """
+ Internal method to convert a Rows object to a TypedRows.
+ """
+ return cls(rows, model, metadata=metadata)
+
+ def __getstate__(self) -> AnyDict:
+ """
+ Used by dill to dump to bytes (exclude db connection etc).
+ """
+ return {
+ "metadata": json.dumps(self.metadata, default=str),
+ "records": self.records,
+ "model": str(self.model._table),
+ "colnames": self.colnames,
+ }
+
+ def __setstate__(self, state: AnyDict) -> None:
+ """
+ Used by dill when loading from a bytestring.
+ """
+ state["metadata"] = json.loads(state["metadata"])
+ self.__dict__.update(state)
+ # db etc. set after undill by caching.py
+
+ def render(
+ self,
+ i: int | None = None,
+ fields: list[Field] | None = None,
+ ) -> t.Generator[T_MetaInstance, None, None]:
+ """
+ Takes an index and returns a copy of the indexed row with values \
+ transformed via the "represent" attributes of the associated fields.
+
+ Args:
+ i: index. If not specified, a generator is returned for iteration
+ over all the rows.
+ fields: a list of fields to transform (if None, all fields with
+ "represent" attributes will be transformed)
+ """
+ if i is None:
+ # difference: uses .keys() instead of index
+ return (self.render(i, fields=fields) for i in self.records)
+
+ if not self.db.has_representer("rows_render"): # pragma: no cover
+ raise RuntimeError(
+ "Rows.render() needs a `rows_render` representer in DAL instance",
+ )
+
+ row = self.records[i]
+ return row.render(fields, compact=self.compact)
+
+
+class PaginatedRows(TypedRows[T_MetaInstance]):
+ """
+ Extension on top of rows that is used when calling .paginate() instead of .collect().
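+
+ Example (a minimal sketch; `User` is a hypothetical model):
+ ```
+ page = User.paginate(limit=10, page=1)
+ page.pagination["has_next_page"] # bool
+ page.next() # next PaginatedRows; raises StopIteration on the final page
+ ```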
+ """ + + _query_builder: QueryBuilder[T_MetaInstance] + + @property + def data(self) -> list[T_MetaInstance]: + """ + Get the underlying data. + """ + return list(self.records.values()) + + @property + def pagination(self) -> Pagination: + """ + Get all page info. + """ + pagination_data = self.metadata["pagination"] + + has_next_page = pagination_data["current_page"] < pagination_data["max_page"] + has_prev_page = pagination_data["current_page"] > 1 + return { + "total_items": pagination_data["rows"], + "current_page": pagination_data["current_page"], + "per_page": pagination_data["limit"], + "total_pages": pagination_data["max_page"], + "has_next_page": has_next_page, + "has_prev_page": has_prev_page, + "next_page": pagination_data["current_page"] + 1 if has_next_page else None, + "prev_page": pagination_data["current_page"] - 1 if has_prev_page else None, + } + + def next(self) -> t.Self: + """ + Get the next page. + """ + data = self.metadata["pagination"] + if data["current_page"] >= data["max_page"]: + raise StopIteration("Final Page") + + return self._query_builder.paginate(limit=data["limit"], page=data["current_page"] + 1) + + def previous(self) -> t.Self: + """ + Get the previous page. + """ + data = self.metadata["pagination"] + if data["current_page"] <= 1: + raise StopIteration("First Page") + + return self._query_builder.paginate(limit=data["limit"], page=data["current_page"] - 1) + + def as_dict(self, *_: t.Any, **__: t.Any) -> PaginateDict: # type: ignore + """ + Convert to a dictionary with pagination info and original data. + + All arguments are ignored! + """ + return {"data": super().as_dict(), "pagination": self.pagination} + + +class TypedSet(pydal.objects.Set): # type: ignore # pragma: no cover + """ + Used to make pydal Set more typed. + + This class is not actually used, only 'cast' by TypeDAL.__call__ + """ + + def count(self, distinct: t.Optional[bool] = None, cache: AnyDict = None) -> int: + """ + Count returns an int. + """ + result = super().count(distinct, cache) + return t.cast(int, result) + + def select(self, *fields: t.Any, **attributes: t.Any) -> TypedRows[T_MetaInstance]: + """ + Select returns a TypedRows of a user defined table. + + Example: + result: TypedRows[MyTable] = db(MyTable.id > 0).select() + + for row in result: + reveal_type(row) # MyTable + """ + rows = super().select(*fields, **attributes) + return t.cast(TypedRows[T_MetaInstance], rows) + + +# note: these imports exist at the bottom of this file to prevent circular import issues: + +from .fields import TypedField # noqa: E402 diff --git a/src/typedal/serializers/as_json.py b/src/typedal/serializers/as_json.py index e8ed715..2100534 100644 --- a/src/typedal/serializers/as_json.py +++ b/src/typedal/serializers/as_json.py @@ -3,8 +3,7 @@ """ import json -import typing -from typing import Any +import typing as t from configurablejson import ConfigurableJsonEncoder, JSONRule @@ -14,7 +13,7 @@ class SerializedJson(ConfigurableJsonEncoder): Custom encoder class with slightly improved defaults. 
""" - def _default(self, o: Any) -> Any: # pragma: no cover + def _default(self, o: t.Any) -> t.Any: # pragma: no cover if hasattr(o, "as_dict"): return o.as_dict() elif hasattr(o, "asdict"): @@ -41,25 +40,25 @@ def _default(self, o: Any) -> Any: # pragma: no cover return str(o) - @typing.overload - def rules(self, o: Any, with_default: typing.Literal[False]) -> JSONRule | None: + @t.overload + def rules(self, o: t.Any, with_default: t.Literal[False]) -> JSONRule | None: """ If you pass with_default=False, you could get a None result. """ - @typing.overload - def rules(self, o: Any, with_default: typing.Literal[True] = True) -> JSONRule: + @t.overload + def rules(self, o: t.Any, with_default: t.Literal[True] = True) -> JSONRule: """ If you don't pass with_default=False, you will always get a JSONRule result. """ - def rules(self, o: Any, with_default: bool = True) -> JSONRule | None: + def rules(self, o: t.Any, with_default: bool = True) -> JSONRule | None: """ Custom rules, such as set to list and as_dict/__json__ etc. lookups. """ _type = type(o) - _rules: dict[type[Any], JSONRule] = { + _rules: dict[type[t.Any], JSONRule] = { # convert set to list set: JSONRule(preprocess=lambda o: list(o)), } @@ -68,7 +67,7 @@ def rules(self, o: Any, with_default: bool = True) -> JSONRule | None: return _rules.get(_type, JSONRule(transform=self._default) if with_default else None) -def encode(something: Any, indent: typing.Optional[int] = None, **kw: Any) -> str: +def encode(something: t.Any, indent: t.Optional[int] = None, **kw: t.Any) -> str: """ Encode anything to JSON with some improved defaults. """ diff --git a/src/typedal/tables.py b/src/typedal/tables.py new file mode 100644 index 0000000..2ad6699 --- /dev/null +++ b/src/typedal/tables.py @@ -0,0 +1,1122 @@ +""" +Contains base functionality related to Tables. +""" + +from __future__ import annotations + +import copy +import csv +import functools +import json +import typing as t +import uuid + +import pydal.objects +from pydal._globals import DEFAULT + +from .constants import JOIN_OPTIONS +from .core import TypeDAL +from .helpers import classproperty, throw +from .serializers import as_json +from .types import ( + AnyDict, + Condition, + Expression, + Field, + OnQuery, + OpRow, + OrderBy, + P, + Query, + R, + Reference, + Row, + SelectKwargs, + Set, + T, + T_MetaInstance, + T_Query, + Table, +) + +if t.TYPE_CHECKING: + from .relationships import Relationship + from .rows import PaginatedRows, TypedRows + + +def reorder_fields( + table: pydal.objects.Table, + fields: t.Iterable[str | TypedField[t.Any] | Field], + keep_others: bool = True, +) -> None: + """ + Reorder fields of a pydal table. + + Args: + table: The pydal table object (e.g., db.mytable). + fields: List of field names (str) or Field objects in desired order. + keep_others (bool): + - True (default): keep other fields at the end, in their original order. + - False: remove other fields (only keep what's specified). + """ + # Normalize input to field names + desired = [f.name if isinstance(f, (TypedField, Field, pydal.objects.Field)) else str(f) for f in fields] + + new_order = [f for f in desired if f in table._fields] + + if keep_others: + # Start with desired fields, then append the rest + new_order.extend(f for f in table._fields if f not in desired) + + table._fields = new_order + + +class TableMeta(type): + """ + This metaclass contains functionality on table classes, that doesn't exist on its instances. 
+
+ Example:
+ class MyTable(TypedTable):
+ some_field: TypedField[int]
+
+ MyTable.update_or_insert(...) # should work
+
+ MyTable.some_field # -> Field, can be used to query etc.
+
+ row = MyTable.first() # returns instance of MyTable
+
+ # row.update_or_insert(...) # shouldn't work!
+
+ row.some_field # -> int, with actual data
+
+ """
+
+ # set up by db.define:
+ # _db: TypeDAL | None = None
+ # _table: Table | None = None
+ _db: TypeDAL | None = None
+ _table: Table | None = None
+ _relationships: dict[str, Relationship[t.Any]] | None = None
+
+ #########################
+ # TypeDAL custom logic: #
+ #########################
+
+ def __set_internals__(self, db: pydal.DAL, table: Table, relationships: dict[str, Relationship[t.Any]]) -> None:
+ """
+ Store the related database and pydal table for later usage.
+ """
+ self._db = db
+ self._table = table
+ self._relationships = relationships
+
+ def __getattr__(self, col: str) -> t.Optional[Field]:
+ """
+ Magic method used by TypedTableMeta to get a database field with dot notation on a class.
+
+ Example:
+ SomeTypedTable.col -> db.table.col (via TypedTableMeta.__getattr__)
+
+ """
+ if self._table:
+ return getattr(self._table, col, None)
+
+ return None
+
+ def _ensure_table_defined(self) -> Table:
+ if not self._table:
+ raise EnvironmentError("@define or db.define is not called on this class yet!")
+ return self._table
+
+ def __iter__(self) -> t.Generator[Field, None, None]:
+ """
+ Loop through the columns of this model.
+ """
+ table = self._ensure_table_defined()
+ yield from iter(table)
+
+ def __getitem__(self, item: str) -> Field:
+ """
+ Allow dict notation to get a column of this table (-> Field instance).
+ """
+ table = self._ensure_table_defined()
+ return table[item]
+
+ def __str__(self) -> str:
+ """
+ Normally, just returns the underlying table name, but with a fallback if the model is unbound.
+ """
+ if self._table:
+ return str(self._table)
+ else:
+ return f"<unbound table {self.__name__}>"
+
+ def from_row(self: t.Type[T_MetaInstance], row: pydal.objects.Row) -> T_MetaInstance:
+ """
+ Create a model instance from a pydal row.
+ """
+ return self(row)
+
+ def all(self: t.Type[T_MetaInstance]) -> "TypedRows[T_MetaInstance]":
+ """
+ Return all rows for this model.
+ """
+ return self.collect()
+
+ def get_relationships(self) -> dict[str, Relationship[t.Any]]:
+ """
+ Return the registered relationships of the current model.
+ """
+ return self._relationships or {}
+
+ ##########################
+ # TypeDAL Modified Logic #
+ ##########################
+
+ def insert(self: t.Type[T_MetaInstance], **fields: t.Any) -> T_MetaInstance:
+ """
+ This is only called when db.define is not used as a decorator.
+
+ cls.__table functions as 'self'
+
+ Args:
+ **fields: Anything you want to insert in the database
+
+ Returns: an instance representing the newly inserted row.
+
+ """
+ table = self._ensure_table_defined()
+
+ result = table.insert(**fields)
+ # result is the new row's id; wrap it in a model instance:
+ return self(result)
+
+ def _insert(self, **fields: t.Any) -> str:
+ table = self._ensure_table_defined()
+
+ return str(table._insert(**fields))
+
+ def bulk_insert(self: t.Type[T_MetaInstance], items: list[AnyDict]) -> "TypedRows[T_MetaInstance]":
+ """
+ Insert multiple rows, returns a TypedRows set of new instances.
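+
+ Example (a minimal sketch; `User` is a hypothetical model):
+ ```
+ users = User.bulk_insert([{"name": "Alice"}, {"name": "Bob"}]) # TypedRows[User]
+ ```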
+ """ + table = self._ensure_table_defined() + result = table.bulk_insert(items) + return self.where(lambda row: row.id.belongs(result)).collect() + + def update_or_insert( + self: t.Type[T_MetaInstance], + query: T_Query | AnyDict = DEFAULT, + **values: t.Any, + ) -> T_MetaInstance: + """ + Update a row if query matches, else insert a new one. + + Returns the created or updated instance. + """ + table = self._ensure_table_defined() + + if query is DEFAULT: + record = table(**values) + elif isinstance(query, dict): + record = table(**query) + else: + record = table(query) + + if not record: + return self.insert(**values) + + record.update_record(**values) + return self(record) + + def validate_and_insert( + self: t.Type[T_MetaInstance], + **fields: t.Any, + ) -> tuple[t.Optional[T_MetaInstance], t.Optional[dict[str, str]]]: + """ + Validate input data and then insert a row. + + Returns a tuple of (the created instance, a dict of errors). + """ + table = self._ensure_table_defined() + result = table.validate_and_insert(**fields) + if row_id := result.get("id"): + return self(row_id), None + else: + return None, result.get("errors") + + def validate_and_update( + self: t.Type[T_MetaInstance], + query: Query, + **fields: t.Any, + ) -> tuple[t.Optional[T_MetaInstance], t.Optional[dict[str, str]]]: + """ + Validate input data and then update max 1 row. + + Returns a tuple of (the updated instance, a dict of errors). + """ + table = self._ensure_table_defined() + + result = table.validate_and_update(query, **fields) + + if errors := result.get("errors"): + return None, errors + elif row_id := result.get("id"): + return self(row_id), None + else: # pragma: no cover + # update on query without result (shouldnt happen) + return None, None + + def validate_and_update_or_insert( + self: t.Type[T_MetaInstance], + query: Query, + **fields: t.Any, + ) -> tuple[t.Optional[T_MetaInstance], t.Optional[dict[str, str]]]: + """ + Validate input data and then update_and_insert (on max 1 row). + + Returns a tuple of (the updated/created instance, a dict of errors). + """ + table = self._ensure_table_defined() + result = table.validate_and_update_or_insert(query, **fields) + + if errors := result.get("errors"): + return None, errors + elif row_id := result.get("id"): + return self(row_id), None + else: # pragma: no cover + # update on query without result (shouldnt happen) + return None, None + + def select(self: t.Type[T_MetaInstance], *a: t.Any, **kw: t.Any) -> "QueryBuilder[T_MetaInstance]": + """ + See QueryBuilder.select! + """ + return QueryBuilder(self).select(*a, **kw) + + def column(self: t.Type[T_MetaInstance], field: T | TypedField[T], **options: t.Unpack[SelectKwargs]) -> list[T]: + """ + Get all values in a specific column. + + Shortcut for `.select(field).execute().column(field)`. + """ + return QueryBuilder(self).select(field, **options).execute().column(field) + + def paginate(self: t.Type[T_MetaInstance], limit: int, page: int = 1) -> "PaginatedRows[T_MetaInstance]": + """ + See QueryBuilder.paginate! + """ + return QueryBuilder(self).paginate(limit=limit, page=page) + + def chunk(self: t.Type[T_MetaInstance], chunk_size: int) -> t.Generator["TypedRows[T_MetaInstance]", t.Any, None]: + """ + See QueryBuilder.chunk! + """ + return QueryBuilder(self).chunk(chunk_size) + + def where(self: t.Type[T_MetaInstance], *a: t.Any, **kw: t.Any) -> "QueryBuilder[T_MetaInstance]": + """ + See QueryBuilder.where! 
+ """ + return QueryBuilder(self).where(*a, **kw) + + def orderby(self: t.Type[T_MetaInstance], *fields: OrderBy) -> "QueryBuilder[T_MetaInstance]": + """ + See QueryBuilder.orderby! + """ + return QueryBuilder(self).orderby(*fields) + + def cache(self: t.Type[T_MetaInstance], *deps: t.Any, **kwargs: t.Any) -> "QueryBuilder[T_MetaInstance]": + """ + See QueryBuilder.cache! + """ + return QueryBuilder(self).cache(*deps, **kwargs) + + def count(self: t.Type[T_MetaInstance]) -> int: + """ + See QueryBuilder.count! + """ + return QueryBuilder(self).count() + + def exists(self: t.Type[T_MetaInstance]) -> bool: + """ + See QueryBuilder.exists! + """ + return QueryBuilder(self).exists() + + def first(self: t.Type[T_MetaInstance]) -> T_MetaInstance | None: + """ + See QueryBuilder.first! + """ + return QueryBuilder(self).first() + + def first_or_fail(self: t.Type[T_MetaInstance]) -> T_MetaInstance: + """ + See QueryBuilder.first_or_fail! + """ + return QueryBuilder(self).first_or_fail() + + def join( + self: t.Type[T_MetaInstance], + *fields: str | t.Type["TypedTable"], + method: JOIN_OPTIONS = None, + on: OnQuery | list[Expression] | Expression = None, + condition: Condition = None, + condition_and: Condition = None, + ) -> "QueryBuilder[T_MetaInstance]": + """ + See QueryBuilder.join! + """ + return QueryBuilder(self).join(*fields, on=on, condition=condition, method=method, condition_and=condition_and) + + def collect(self: t.Type[T_MetaInstance], verbose: bool = False) -> "TypedRows[T_MetaInstance]": + """ + See QueryBuilder.collect! + """ + return QueryBuilder(self).collect(verbose=verbose) + + @property + def ALL(cls) -> pydal.objects.SQLALL: + """ + Select all fields for this table. + """ + table = cls._ensure_table_defined() + + return table.ALL + + ########################## + # TypeDAL Shadowed Logic # + ########################## + fields: list[str] + + # other table methods: + + def truncate(self, mode: str = "") -> None: + """ + Remove all data and reset index. + """ + table = self._ensure_table_defined() + table.truncate(mode) + + def drop(self, mode: str = "") -> None: + """ + Remove the underlying table. + """ + table = self._ensure_table_defined() + table.drop(mode) + + def create_index(self, name: str, *fields: str | Field, **kwargs: t.Any) -> bool: + """ + Add an index on some columns of this table. + """ + table = self._ensure_table_defined() + result = table.create_index(name, *fields, **kwargs) + return t.cast(bool, result) + + def drop_index(self, name: str, if_exists: bool = False) -> bool: + """ + Remove an index from this table. + """ + table = self._ensure_table_defined() + result = table.drop_index(name, if_exists) + return t.cast(bool, result) + + def import_from_csv_file( + self, + csvfile: t.TextIO, + id_map: dict[str, str] = None, + null: t.Any = "", + unique: str = "uuid", + id_offset: dict[str, int] = None, # id_offset used only when id_map is None + transform: t.Callable[[dict[t.Any, t.Any]], dict[t.Any, t.Any]] = None, + validate: bool = False, + encoding: str = "utf-8", + delimiter: str = ",", + quotechar: str = '"', + quoting: int = csv.QUOTE_MINIMAL, + restore: bool = False, + **kwargs: t.Any, + ) -> None: + """ + Load a csv file into the database. 
+ """ + table = self._ensure_table_defined() + table.import_from_csv_file( + csvfile, + id_map=id_map, + null=null, + unique=unique, + id_offset=id_offset, + transform=transform, + validate=validate, + encoding=encoding, + delimiter=delimiter, + quotechar=quotechar, + quoting=quoting, + restore=restore, + **kwargs, + ) + + def on(self, query: bool | Query) -> Expression: + """ + Shadow Table.on. + + Used for joins. + + See Also: + http://web2py.com/books/default/chapter/29/06/the-database-abstraction-layer?search=export_to_csv_file#One-to-mt.Any-relation + """ + table = self._ensure_table_defined() + return t.cast(Expression, table.on(query)) + + def with_alias(self: t.Type[T_MetaInstance], alias: str) -> t.Type[T_MetaInstance]: + """ + Shadow Table.with_alias. + + Useful for joins when joining the same table multiple times. + + See Also: + http://web2py.com/books/default/chapter/29/06/the-database-abstraction-layer#One-to-mt.Any-relation + """ + table = self._ensure_table_defined() + return t.cast(t.Type[T_MetaInstance], table.with_alias(alias)) + + def unique_alias(self: t.Type[T_MetaInstance]) -> t.Type[T_MetaInstance]: + """ + Generates a unique alias for this table. + + Useful for joins when joining the same table multiple times + and you don't want to keep track of aliases yourself. + """ + key = f"{self.__name__.lower()}_{hash(uuid.uuid4())}" + return self.with_alias(key) + + # hooks: + def _hook_once( + cls: t.Type[T_MetaInstance], + hooks: list[t.Callable[P, R]], + fn: t.Callable[P, R], + ) -> t.Type[T_MetaInstance]: + @functools.wraps(fn) + def wraps(*a: P.args, **kw: P.kwargs) -> R: + try: + return fn(*a, **kw) + finally: + hooks.remove(wraps) + + hooks.append(wraps) + return cls + + def before_insert( + cls: t.Type[T_MetaInstance], + fn: t.Callable[[T_MetaInstance], t.Optional[bool]] | t.Callable[[OpRow], t.Optional[bool]], + ) -> t.Type[T_MetaInstance]: + """ + Add a before insert hook. + """ + if fn not in cls._before_insert: + cls._before_insert.append(fn) + return cls + + def before_insert_once( + cls: t.Type[T_MetaInstance], + fn: t.Callable[[T_MetaInstance], t.Optional[bool]] | t.Callable[[OpRow], t.Optional[bool]], + ) -> t.Type[T_MetaInstance]: + """ + Add a before insert hook that only fires once and then removes itself. + """ + return cls._hook_once(cls._before_insert, fn) # type: ignore + + def after_insert( + cls: t.Type[T_MetaInstance], + fn: ( + t.Callable[[T_MetaInstance, Reference], t.Optional[bool]] | t.Callable[[OpRow, Reference], t.Optional[bool]] + ), + ) -> t.Type[T_MetaInstance]: + """ + Add an after insert hook. + """ + if fn not in cls._after_insert: + cls._after_insert.append(fn) + return cls + + def after_insert_once( + cls: t.Type[T_MetaInstance], + fn: ( + t.Callable[[T_MetaInstance, Reference], t.Optional[bool]] | t.Callable[[OpRow, Reference], t.Optional[bool]] + ), + ) -> t.Type[T_MetaInstance]: + """ + Add an after insert hook that only fires once and then removes itself. + """ + return cls._hook_once(cls._after_insert, fn) # type: ignore + + def before_update( + cls: t.Type[T_MetaInstance], + fn: t.Callable[[Set, T_MetaInstance], t.Optional[bool]] | t.Callable[[Set, OpRow], t.Optional[bool]], + ) -> t.Type[T_MetaInstance]: + """ + Add a before update hook. 
+        """
+        if fn not in cls._before_update:
+            cls._before_update.append(fn)
+        return cls
+
+    def before_update_once(
+        cls: t.Type[T_MetaInstance],
+        fn: t.Callable[[Set, T_MetaInstance], t.Optional[bool]] | t.Callable[[Set, OpRow], t.Optional[bool]],
+    ) -> t.Type[T_MetaInstance]:
+        """
+        Add a before update hook that only fires once and then removes itself.
+        """
+        return cls._hook_once(cls._before_update, fn)  # type: ignore
+
+    def after_update(
+        cls: t.Type[T_MetaInstance],
+        fn: t.Callable[[Set, T_MetaInstance], t.Optional[bool]] | t.Callable[[Set, OpRow], t.Optional[bool]],
+    ) -> t.Type[T_MetaInstance]:
+        """
+        Add an after update hook.
+        """
+        if fn not in cls._after_update:
+            cls._after_update.append(fn)
+        return cls
+
+    def after_update_once(
+        cls: t.Type[T_MetaInstance],
+        fn: t.Callable[[Set, T_MetaInstance], t.Optional[bool]] | t.Callable[[Set, OpRow], t.Optional[bool]],
+    ) -> t.Type[T_MetaInstance]:
+        """
+        Add an after update hook that only fires once and then removes itself.
+        """
+        return cls._hook_once(cls._after_update, fn)  # type: ignore
+
+    def before_delete(cls: t.Type[T_MetaInstance], fn: t.Callable[[Set], t.Optional[bool]]) -> t.Type[T_MetaInstance]:
+        """
+        Add a before delete hook.
+        """
+        if fn not in cls._before_delete:
+            cls._before_delete.append(fn)
+        return cls
+
+    def before_delete_once(
+        cls: t.Type[T_MetaInstance],
+        fn: t.Callable[[Set], t.Optional[bool]],
+    ) -> t.Type[T_MetaInstance]:
+        """
+        Add a before delete hook that only fires once and then removes itself.
+        """
+        return cls._hook_once(cls._before_delete, fn)
+
+    def after_delete(cls: t.Type[T_MetaInstance], fn: t.Callable[[Set], t.Optional[bool]]) -> t.Type[T_MetaInstance]:
+        """
+        Add an after delete hook.
+        """
+        if fn not in cls._after_delete:
+            cls._after_delete.append(fn)
+        return cls
+
+    def after_delete_once(
+        cls: t.Type[T_MetaInstance],
+        fn: t.Callable[[Set], t.Optional[bool]],
+    ) -> t.Type[T_MetaInstance]:
+        """
+        Add an after delete hook that only fires once and then removes itself.
+        """
+        return cls._hook_once(cls._after_delete, fn)
+
+    def reorder_fields(cls, *fields: str | Field | TypedField[t.Any], keep_others: bool = True) -> None:
+        """
+        Reorder the fields of a TypeDAL table.
+
+        Args:
+            fields: List of field names (str) or Field objects in desired order.
+            keep_others (bool):
+                - True (default): keep other fields at the end, in their original order.
+                - False: remove other fields (only keep what's specified).
+        """
+        return reorder_fields(cls._table, fields, keep_others=keep_others)
+
+
+class _TypedTable:
+    """
+    This class is a final shared parent between TypedTable and Mixins.
+
+    This needs to exist because otherwise the __on_define__ hooks of Mixins are not executed.
+    Notably, this class exists at a level ABOVE the `metaclass=TableMeta`,
+    because otherwise typing gets confused when Mixins are used and multiple types could satisfy
+    generic 'T subclass of TypedTable'
+    -> Setting 'TypedTable' as the parent for Mixin does not work at runtime (and only partially works at type-check time)
+    """
+
+    id: "TypedField[int]"
+
+    _before_insert: list[t.Callable[[t.Self], t.Optional[bool]] | t.Callable[[OpRow], t.Optional[bool]]]
+    _after_insert: list[
+        t.Callable[[t.Self, Reference], t.Optional[bool]] | t.Callable[[OpRow, Reference], t.Optional[bool]]
+    ]
+    _before_update: list[t.Callable[[Set, t.Self], t.Optional[bool]] | t.Callable[[Set, OpRow], t.Optional[bool]]]
+    _after_update: list[t.Callable[[Set, t.Self], t.Optional[bool]] | t.Callable[[Set, OpRow], t.Optional[bool]]]
+    _before_delete: list[t.Callable[[Set], t.Optional[bool]]]
+    _after_delete: list[t.Callable[[Set], t.Optional[bool]]]
+
+    @classmethod
+    def __on_define__(cls, db: TypeDAL) -> None:
+        """
+        Method that can be implemented by tables to do an action after db.define is completed.
+
+        This can be useful if you need to add something like requires=IS_NOT_IN_DB(db, "table.field"),
+        where you need a reference to the current database, which may not exist yet when defining the model.
+        """
+
+    @classproperty
+    def _hooks(cls) -> dict[str, list[t.Callable[..., t.Optional[bool]]]]:
+        return {
+            "before_insert": cls._before_insert,
+            "after_insert": cls._after_insert,
+            "before_update": cls._before_update,
+            "after_update": cls._after_update,
+            "before_delete": cls._before_delete,
+            "after_delete": cls._after_delete,
+        }
+
+
+class TypedTable(_TypedTable, metaclass=TableMeta):
+    """
+    Enhanced modeling system on top of pydal's Table that adds typing and additional functionality.
+    """
+
+    # set up by 'new':
+    _row: Row | None = None
+    _rows: tuple[Row, ...] = ()
+
+    _with: list[str]
+
+    def _setup_instance_methods(self) -> None:
+        self.as_dict = self._as_dict  # type: ignore
+        self.__json__ = self.as_json = self._as_json  # type: ignore
+        # self.as_yaml = self._as_yaml # type: ignore
+        self.as_xml = self._as_xml  # type: ignore
+
+        self.update = self._update  # type: ignore
+
+        self.delete_record = self._delete_record  # type: ignore
+        self.update_record = self._update_record  # type: ignore
+
+    def __new__(
+        cls,
+        row_or_id: t.Union[Row, Query, pydal.objects.Set, int, str, None, "TypedTable"] = None,
+        **filters: t.Any,
+    ) -> t.Self:
+        """
+        Create a TypedTable model instance from an existing row, ID or query.
+
+        Examples:
+            MyTable(1)
+            MyTable(id=1)
+            MyTable(MyTable.id == 1)
+        """
+        table = cls._ensure_table_defined()
+        inst = super().__new__(cls)
+
+        if isinstance(row_or_id, TypedTable):
+            # existing typed table instance!
+            return t.cast(t.Self, row_or_id)
+
+        elif isinstance(row_or_id, pydal.objects.Row):
+            row = row_or_id
+        elif row_or_id is not None:
+            row = table(row_or_id, **filters)
+        elif filters:
+            row = table(**filters)
+        else:
+            # dummy object
+            return inst
+
+        if not row:
+            return None  # type: ignore
+
+        inst._row = row
+
+        if hasattr(row, "id"):
+            inst.__dict__.update(row)
+        else:
+            # deal with _extra (and possibly others?)
+            # Row <{actual: {}, _extra: ...}>
+            inst.__dict__.update(row[str(cls)])
+
+        inst._setup_instance_methods()
+        return inst
+
+    def __iter__(self) -> t.Generator[t.Any, None, None]:
+        """
+        Allows looping through the columns.
+        """
+        row = self._ensure_matching_row()
+        yield from iter(row)
+
+    def __getitem__(self, item: str) -> t.Any:
+        """
+        Allows dictionary notation to get columns.
+        """
+        if item in self.__dict__:
+            return self.__dict__.get(item)
+
+        # fallback to lookup in row
+        if self._row:
+            return self._row[item]
+
+        # nothing found!
+        raise KeyError(item)
+
+    def __getattr__(self, item: str) -> t.Any:
+        """
+        Allows dot notation to get columns.
+        """
+        # use a sentinel instead of truthiness, so falsy column values (0, "", False) are still returned:
+        sentinel = object()
+        if (value := self.get(item, sentinel)) is not sentinel:
+            return value
+
+        raise AttributeError(item)
+
+    def keys(self) -> list[str]:
+        """
+        Return the combination of row + relationship keys.
+
+        Used by dict(row).
+        """
+        return list(self._row.keys() if self._row else ()) + getattr(self, "_with", [])
+
+    def get(self, item: str, default: t.Any = None) -> t.Any:
+        """
+        Try to get a column from this instance, else return default.
+        """
+        try:
+            return self.__getitem__(item)
+        except KeyError:
+            return default
+
+    def __setitem__(self, key: str, value: t.Any) -> None:
+        """
+        Data can both be updated via dot and dict notation.
+        """
+        return setattr(self, key, value)
+
+    def __int__(self) -> int:
+        """
+        Calling int on a model instance will return its id.
+        """
+        return getattr(self, "id", 0)
+
+    def __bool__(self) -> bool:
+        """
+        If the instance has an underlying row with data, it is truthy.
+        """
+        return bool(getattr(self, "_row", False))
+
+    def _ensure_matching_row(self) -> Row:
+        row = getattr(self, "_row", None)
+        return t.cast(Row, row) or throw(
+            EnvironmentError("Trying to access non-existent row. Maybe it was deleted or not yet initialized?")
+        )
+
+    def __repr__(self) -> str:
+        """
+        String representation of the model instance.
+        """
+        model_name = self.__class__.__name__
+        model_data = {}
+
+        if self._row:
+            model_data = self._row.as_json()
+
+        details = model_name
+        details += f"({model_data})"
+
+        if relationships := getattr(self, "_with", []):
+            details += f" + {relationships}"
+
+        return f"<{details}>"
+
+    # serialization
+    # underscore variants work for class instances (set up by _setup_instance_methods)
+
+    @classmethod
+    def as_dict(cls, flat: bool = False, sanitize: bool = True) -> AnyDict:
+        """
+        Dump the object to a plain dict.
+
+        Can be used as both a class or instance method:
+        - dumps the table info if it's a class
+        - dumps the row info if it's an instance (see _as_dict)
+        """
+        table = cls._ensure_table_defined()
+        result = table.as_dict(flat, sanitize)
+        return t.cast(AnyDict, result)
+
+    @classmethod
+    def as_json(cls, sanitize: bool = True, indent: t.Optional[int] = None, **kwargs: t.Any) -> str:
+        """
+        Dump the object to json.
+
+        Can be used as both a class or instance method:
+        - dumps the table info if it's a class
+        - dumps the row info if it's an instance (see _as_json)
+        """
+        data = cls.as_dict(sanitize=sanitize)
+        return as_json.encode(data, indent=indent, **kwargs)
+
+    @classmethod
+    def as_xml(cls, sanitize: bool = True) -> str:  # pragma: no cover
+        """
+        Dump the object to xml.
+
+        Can be used as both a class or instance method:
+        - dumps the table info if it's a class
+        - dumps the row info if it's an instance (see _as_xml)
+        """
+        table = cls._ensure_table_defined()
+        return t.cast(str, table.as_xml(sanitize))
+
+    @classmethod
+    def as_yaml(cls, sanitize: bool = True) -> str:
+        """
+        Dump the object to yaml.
+ + Can be used as both a class or instance method: + - dumps the table info if it's a class + - dumps the row info if it's an instance (see _as_yaml) + """ + table = cls._ensure_table_defined() + return t.cast(str, table.as_yaml(sanitize)) + + def _as_dict( + self, + datetime_to_str: bool = False, + custom_types: t.Iterable[type] | type | None = None, + ) -> AnyDict: + row = self._ensure_matching_row() + + result = row.as_dict(datetime_to_str=datetime_to_str, custom_types=custom_types) + + def asdict_method(obj: t.Any) -> t.Any: # pragma: no cover + if hasattr(obj, "_as_dict"): # typedal + return obj._as_dict() + elif hasattr(obj, "as_dict"): # pydal + return obj.as_dict() + else: # something else?? + return obj.__dict__ + + if _with := getattr(self, "_with", None): + for relationship in _with: + data = self.get(relationship) + + if isinstance(data, list): + data = [asdict_method(_) for _ in data] + elif data: + data = asdict_method(data) + + result[relationship] = data + + return t.cast(AnyDict, result) + + def _as_json( + self, + default: t.Callable[[t.Any], t.Any] = None, + indent: t.Optional[int] = None, + **kwargs: t.Any, + ) -> str: + data = self._as_dict() + return as_json.encode(data, default=default, indent=indent, **kwargs) + + def _as_xml(self, sanitize: bool = True) -> str: # pragma: no cover + row = self._ensure_matching_row() + return t.cast(str, row.as_xml(sanitize)) + + # def _as_yaml(self, sanitize: bool = True) -> str: + # row = self._ensure_matching_row() + # return t.cast(str, row.as_yaml(sanitize)) + + def __setattr__(self, key: str, value: t.Any) -> None: + """ + When setting a property on a Typed Table model instance, also update the underlying row. + """ + if self._row and key in self._row.__dict__ and not callable(value): + # enables `row.key = value; row.update_record()` + self._row[key] = value + + super().__setattr__(key, value) + + @classmethod + def update(cls: t.Type[T_MetaInstance], query: Query, **fields: t.Any) -> T_MetaInstance | None: + """ + Update one record. + + Example: + MyTable.update(MyTable.id == 1, name="NewName") -> MyTable + """ + # todo: update multiple? + if record := cls(query): + return record.update_record(**fields) + else: + return None + + def _update(self: T_MetaInstance, **fields: t.Any) -> T_MetaInstance: + row = self._ensure_matching_row() + row.update(**fields) + self.__dict__.update(**fields) + return self + + def _update_record(self: T_MetaInstance, **fields: t.Any) -> T_MetaInstance: + row = self._ensure_matching_row() + new_row = row.update_record(**fields) + self.update(**new_row) + return self + + def update_record(self: T_MetaInstance, **fields: t.Any) -> T_MetaInstance: # pragma: no cover + """ + Here as a placeholder for _update_record. + + Will be replaced on instance creation! + """ + return self._update_record(**fields) + + def _delete_record(self) -> int: + """ + Actual logic in `pydal.helpers.classes.RecordDeleter`. + """ + row = self._ensure_matching_row() + result = row.delete_record() + self.__dict__ = {} # empty self, since row is no more. + self._row = None # just to be sure + self._setup_instance_methods() + # ^ instance methods might've been deleted by emptying dict, + # but we still want .as_dict to show an error, not the table's as_dict. + return t.cast(int, result) + + def delete_record(self) -> int: # pragma: no cover + """ + Here as a placeholder for _delete_record. + + Will be replaced on instance creation! 
+        """
+        return self._delete_record()
+
+    # __del__ is also called at the end of a scope so don't remove records on every del!!
+
+    # pickling:
+
+    def __getstate__(self) -> AnyDict:
+        """
+        State to save when pickling.
+
+        Prevents db connection from being pickled.
+        Similar to as_dict but without changing the data of the relationships (dill does that recursively)
+        """
+        row = self._ensure_matching_row()
+        result: AnyDict = row.as_dict()
+
+        if _with := getattr(self, "_with", None):
+            result["_with"] = _with
+            for relationship in _with:
+                data = self.get(relationship)
+
+                result[relationship] = data
+
+        result["_row"] = self._row.as_json() if self._row else ""
+        return result
+
+    def __setstate__(self, state: AnyDict) -> None:
+        """
+        Used by dill when loading from a bytestring.
+        """
+        # as_dict also includes table info, so dump as json to only get the actual row data
+        # then create a new (more empty) row object:
+        state["_row"] = Row(json.loads(state["_row"]))
+        self.__dict__ |= state
+
+    @classmethod
+    def _sql(cls) -> str:
+        """
+        Generate SQL Schema for this table via pydal2sql (if 'migrations' extra is installed).
+        """
+        try:
+            import pydal2sql
+        except ImportError as e:  # pragma: no cover
+            raise RuntimeError("Cannot generate SQL without the 'migrations' extra or `pydal2sql` installed!") from e
+
+        return pydal2sql.generate_sql(cls)
+
+    def render(self, fields: list[Field] = None, compact: bool = False) -> t.Self:
+        """
+        Renders a copy of the object with potentially modified values.
+
+        Args:
+            fields: A list of fields to render. Defaults to all representable fields in the table.
+            compact: Whether to return only the value of the first field if there is only one field.
+
+        Returns:
+            A copy of the object with potentially modified values.
+        """
+        row = copy.deepcopy(self)
+        keys = list(row)
+        if not fields:
+            fields = [self._table[f] for f in self._table._fields]
+            fields = [f for f in fields if isinstance(f, Field) and f.represent]
+
+        for field in fields:
+            if field._table == self._table:
+                row[field.name] = self._db.represent(
+                    "rows_render",
+                    field,
+                    row[field.name],
+                    row,
+                )
+            # else: relationship, different logic:
+
+        for relation_name in getattr(row, "_with", []):
+            if relation := self._relationships.get(relation_name):
+                relation_table = relation.table
+                if isinstance(relation_table, str):
+                    relation_table = self._db[relation_table]
+
+                relation_row = row[relation_name]
+
+                if isinstance(relation_row, list):
+                    # list of rows
+                    combined = []
+
+                    for related_og in relation_row:
+                        related = copy.deepcopy(related_og)
+                        for fieldname in related:
+                            field = relation_table[fieldname]
+                            related[field.name] = self._db.represent(
+                                "rows_render",
+                                field,
+                                related[field.name],
+                                related,
+                            )
+                        combined.append(related)
+
+                    row[relation_name] = combined
+                else:
+                    # 1 row
+                    for fieldname in relation_row:
+                        field = relation_table[fieldname]
+                        row[relation_name][fieldname] = self._db.represent(
+                            "rows_render",
+                            field,
+                            relation_row[field.name],
+                            relation_row,
+                        )
+
+        if compact and len(keys) == 1 and keys[0] != "_extra":  # pragma: no cover
+            return t.cast(t.Self, row[keys[0]])
+        return row
+
+
+# backwards compat:
+TypedRow = TypedTable
+
+# note: at the bottom to prevent circular import issues:
+from .fields import TypedField  # noqa: E402
+from .query_builder import QueryBuilder  # noqa: E402
diff --git a/src/typedal/types.py b/src/typedal/types.py
index 9957efb..b9810b9 100644
--- a/src/typedal/types.py
+++ b/src/typedal/types.py
@@ -2,10 +2,18 @@
 Stuff to make mypy happy.
""" -import typing -from datetime import datetime -from typing import Any, Callable, Optional, TypedDict +# --------------------------------------------------------------------------- +# Imports +# --------------------------------------------------------------------------- +# Standard library +import datetime as dt +import types +import typing as t + +import pydal.objects + +# Third-party from pydal.adapters.base import BaseAdapter from pydal.helpers.classes import OpRow as _OpRow from pydal.helpers.classes import Reference as _Reference @@ -13,126 +21,157 @@ from pydal.objects import Expression as _Expression from pydal.objects import Field as _Field from pydal.objects import Query as _Query +from pydal.objects import Row as _Row from pydal.objects import Rows as _Rows from pydal.objects import Set as _Set from pydal.objects import Table as _Table from pydal.validators import Validator as _Validator -from typing_extensions import NotRequired -if typing.TYPE_CHECKING: - from .core import TypedField +try: + from string.templatelib import Template +except ImportError: + Template: t.TypeAlias = str # type: ignore -AnyDict: typing.TypeAlias = dict[str, Any] +# Internal references +if t.TYPE_CHECKING: + from .fields import TypedField + from .tables import TypedTable +# --------------------------------------------------------------------------- +# Aliases +# --------------------------------------------------------------------------- -class Query(_Query): # type: ignore - """ - Pydal Query object. +AnyCallable: t.TypeAlias = t.Callable[..., t.Any] +AnyDict: t.TypeAlias = dict[str, t.Any] - Makes mypy happy. - """ +# --------------------------------------------------------------------------- +# Protocols +# --------------------------------------------------------------------------- -class Expression(_Expression): # type: ignore - """ - Pydal Expression object. - Make mypy happy. - """ +class TableProtocol(t.Protocol): # pragma: no cover + """Protocol to make mypy happy for Tables.""" + id: "TypedField[int]" + + def __getitem__(self, item: str) -> "Field": + """ + Tables have table[field] syntax. + """ -class Set(_Set): # type: ignore - """ - Pydal Set object. - Make mypy happy. +class CacheFn(t.Protocol): """ + The cache model (e.g. cache.ram) accepts these parameters (all filled by default). + """ + + def __call__( + self: BaseAdapter, + sql: str = "", + fields: t.Iterable[str] = (), + attributes: t.Iterable[str] = (), + colnames: t.Iterable[str] = (), + ) -> "Rows": + """Signature for calling this object.""" + + +class FileSystemLike(t.Protocol): # pragma: no cover + """Protocol for any class that has an 'open' function (e.g. OSFS).""" + + def open(self, file: str, mode: str = "r") -> t.IO[t.Any]: + """We assume every object with an open function this shape, is basically a file.""" + + +# --------------------------------------------------------------------------- +# pydal Wrappers (to help mypy understand these classes) +# --------------------------------------------------------------------------- + + +class Query(_Query): # type: ignore + """Pydal Query object. Makes mypy happy.""" + + +class Expression(_Expression): # type: ignore + """Pydal Expression object. Make mypy happy.""" + + +class Set(_Set): # type: ignore + """Pydal Set object. Make mypy happy.""" -if typing.TYPE_CHECKING: +if t.TYPE_CHECKING: class OpRow: """ Pydal OpRow object for typing (otherwise mypy thinks it's Any). - - Make mypy happy. """ - def __getitem__(self, item: str) -> typing.Any: - """ - Dict [] get notation. 
-            """
+        def __getitem__(self, item: str) -> t.Any:
+            """row[item] syntax."""
 
-        def __setitem__(self, key: str, value: typing.Any) -> None:
-            """
-            Dict [] set notation.
-            """
+        def __setitem__(self, key: str, value: t.Any) -> None:
+            """row[key] = value syntax."""
 
-        # ... and more methods
+        # more methods could be added
 
 else:
 
     class OpRow(_OpRow):  # type: ignore
-        """
-        Pydal OpRow object at runtime just uses pydal's version.
-
-        Make mypy happy.
-        """
+        """Runtime OpRow, using pydal's version."""
 
 
 class Reference(_Reference):  # type: ignore
-    """
-    Pydal Reference object.
-
-    Make mypy happy.
-    """
+    """Pydal Reference object. Make mypy happy."""
 
 
 class Field(_Field):  # type: ignore
-    """
-    Pydal Field object.
-
-    Make mypy happy.
-    """
+    """Pydal Field object. Make mypy happy."""
 
 
 class Rows(_Rows):  # type: ignore
-    """
-    Pydal Rows object.
-
-    Make mypy happy.
-    """
+    """Pydal Rows object. Make mypy happy."""
 
-    def column(self, column: typing.Any = None) -> list[typing.Any]:
+    def column(self, column: t.Any = None) -> list[t.Any]:
         """
         Get a list of all values in a specific column.
 
         Example:
-        rows.column('name') -> ['Name 1', 'Name 2', ...]
+            rows.column('name') -> ['Name 1', 'Name 2', ...]
         """
         return [r[str(column) if column else self.colnames[0]] for r in self]
 
 
+class Row(_Row):
+    """Pydal Row object. Make mypy happy."""
+
+
 class Validator(_Validator):  # type: ignore
-    """
-    Pydal Validator object.
+    """Pydal Validator object. Make mypy happy."""
+
+
+class Table(_Table, TableProtocol):  # type: ignore
+    """Table with protocol support. Make mypy happy."""
 
-    Make mypy happy.
-    """
+
+# ---------------------------------------------------------------------------
+# Utility Types
+# ---------------------------------------------------------------------------
 
 
 class _Types:
-    """
-    Internal type storage for stuff that mypy otherwise won't understand.
-    """
+    """Internal type storage for stuff mypy otherwise won't understand."""
 
     NONETYPE = type(None)
 
 
-class Pagination(TypedDict):
-    """
-    Pagination key of a paginate dict has these items.
-    """
+# ---------------------------------------------------------------------------
+# TypedDicts
+# ---------------------------------------------------------------------------
+
+
+class Pagination(t.TypedDict):
+    """Pagination key of a paginate dict has these items."""
 
     total_items: int
     current_page: int
@@ -140,36 +179,30 @@ class Pagination(TypedDict):
     total_pages: int
     has_next_page: bool
     has_prev_page: bool
-    next_page: Optional[int]
-    prev_page: Optional[int]
+    next_page: t.Optional[int]
+    prev_page: t.Optional[int]
 
 
-class PaginateDict(TypedDict):
-    """
-    Result of PaginatedRows.as_dict().
-    """
+class PaginateDict(t.TypedDict):
+    """Result of PaginatedRows.as_dict()."""
 
     data: dict[int, AnyDict]
     pagination: Pagination
 
 
-class CacheMetadata(TypedDict):
-    """
-    Used by query builder metadata in the 'cache' key.
-    """
+class CacheMetadata(t.TypedDict):
+    """Used by query builder metadata in the 'cache' key."""
 
     enabled: bool
-    depends_on: list[Any]
-    key: NotRequired[str | None]
-    status: NotRequired[str | None]
-    expires_at: NotRequired[datetime | None]
-    cached_at: NotRequired[datetime | None]
+    depends_on: list[t.Any]
+    key: t.NotRequired[str | None]
+    status: t.NotRequired[str | None]
+    expires_at: t.NotRequired[dt.datetime | None]
+    cached_at: t.NotRequired[dt.datetime | None]
 
 
-class PaginationMetadata(TypedDict):
-    """
-    Used by query builder metadata in the 'pagination' key.
- """ +class PaginationMetadata(t.TypedDict): + """Used by query builder metadata in the 'pagination' key.""" limit: int current_page: int @@ -178,101 +211,34 @@ class PaginationMetadata(TypedDict): min_max: tuple[int, int] -class TableProtocol(typing.Protocol): # pragma: no cover - """ - Make mypy happy. - """ - - id: "TypedField[int]" - - def __getitem__(self, item: str) -> Field: - """ - Tell mypy a Table supports dictionary notation for columns. - """ - - -class Table(_Table, TableProtocol): # type: ignore - """ - Make mypy happy. - """ - - -class CacheFn(typing.Protocol): - """ - The cache model (e.g. cache.ram) accepts these parameters (all filled by dfeault). - """ - - def __call__( - self: BaseAdapter, - sql: str = "", - fields: typing.Iterable[str] = (), - attributes: typing.Iterable[str] = (), - colnames: typing.Iterable[str] = (), - ) -> Rows: - """ - Only used for type-hinting. - """ - - -# CacheFn = typing.Callable[[], Rows] -CacheModel = typing.Callable[[str, CacheFn, int], Rows] -CacheTuple = tuple[CacheModel, int] - -OrderBy: typing.TypeAlias = Expression | str - +class SelectKwargs(t.TypedDict, total=False): + """Possible keyword arguments for .select().""" -class SelectKwargs(TypedDict, total=False): - """ - Possible keyword arguments for .select(). - """ - - join: Optional[list[Expression]] - left: Optional[list[Expression]] - orderby: OrderBy | typing.Iterable[OrderBy] | None - limitby: Optional[tuple[int, int]] + join: t.Optional[list[Expression]] + left: t.Optional[list[Expression]] + orderby: "OrderBy | t.Iterable[OrderBy] | None" + limitby: t.Optional[tuple[int, int]] distinct: bool | Field | Expression orderby_on_limitby: bool cacheable: bool - cache: CacheTuple - - -class Metadata(TypedDict): - """ - Loosely structured metadata used by Query Builder. - """ - - cache: NotRequired[CacheMetadata] - pagination: NotRequired[PaginationMetadata] + cache: "CacheTuple" - query: NotRequired[Query | str | None] - ids: NotRequired[str] - final_query: NotRequired[Query | str | None] - final_args: NotRequired[list[Any]] - final_kwargs: NotRequired[SelectKwargs] - relationships: NotRequired[set[str]] +class Metadata(t.TypedDict): + """Loosely structured metadata used by Query Builder.""" - sql: NotRequired[str] + cache: t.NotRequired[CacheMetadata] + pagination: t.NotRequired[PaginationMetadata] + query: t.NotRequired[Query | str | None] + ids: t.NotRequired[str] + final_query: t.NotRequired[Query | str | None] + final_args: t.NotRequired[list[t.Any]] + final_kwargs: t.NotRequired[SelectKwargs] + relationships: t.NotRequired[set[str]] + sql: t.NotRequired[str] -class FileSystemLike(typing.Protocol): # pragma: no cover - """ - Protocol for any class that has an 'open' function. - - An example of this is OSFS from PyFilesystem2. - """ - - def open(self, file: str, mode: str = "r") -> typing.IO[typing.Any]: - """ - Opens a file for reading, writing or other modes. - """ - ... - - -AnyCallable: typing.TypeAlias = Callable[..., Any] - - -class FieldSettings(TypedDict, total=False): +class FieldSettings(t.TypedDict, total=False): """ The supported keyword arguments for `pydal.Field()`. 
@@ -281,9 +247,9 @@ class FieldSettings(TypedDict, total=False):
 
     type: str | type | SQLCustomType
     length: int
-    default: Any
+    default: t.Any
     required: bool
-    requires: list[AnyCallable | Any | Validator] | Validator | AnyCallable
+    requires: list[AnyCallable | t.Any | Validator] | Validator | AnyCallable
     ondelete: str
     onupdate: str
     notnull: bool
@@ -297,8 +263,8 @@ class FieldSettings(TypedDict, total=False):
     searchable: bool
     listable: bool
     regex: str
-    options: list[Any] | AnyCallable
-    update: Any
+    options: list[t.Any] | AnyCallable
+    update: t.Any
     authorize: AnyCallable
     autodelete: bool
     represent: AnyCallable
@@ -312,6 +278,46 @@ class FieldSettings(TypedDict, total=False):
     custom_delete: AnyCallable
     filter_in: AnyCallable
     filter_out: AnyCallable
-    custom_qualifier: Any
-    map_none: Any
+    custom_qualifier: t.Any
+    map_none: t.Any
     rname: str
+
+
+# ---------------------------------------------------------------------------
+# Generics & Query Helpers
+# ---------------------------------------------------------------------------
+
+T = t.TypeVar("T", bound=t.Any)
+P = t.ParamSpec("P")
+R = t.TypeVar("R")
+
+T_MetaInstance = t.TypeVar("T_MetaInstance", bound="TypedTable")
+T_Query = t.Union[
+    "Table",
+    Query,
+    bool,
+    None,
+    "TypedTable",
+    t.Type["TypedTable"],
+    Expression,
+]
+
+T_subclass = t.TypeVar("T_subclass", "TypedTable", Table)
+T_Field: t.TypeAlias = t.Union["TypedField[t.Any]", "Table", t.Type["TypedTable"]]
+
+# use typing.cast(type, ...) to make mypy happy with unions
+T_Value = t.TypeVar("T_Value")  # actual type of the Field (via Generic)
+
+# table-ish parameter:
+P_Table = t.Union[t.Type["TypedTable"], pydal.objects.Table]
+
+Condition: t.TypeAlias = t.Optional[t.Callable[[P_Table, P_Table], Query | bool]]
+
+OnQuery: t.TypeAlias = t.Optional[t.Callable[[P_Table, P_Table], list[Expression]]]
+
+CacheModel = t.Callable[[str, CacheFn, int], Rows]
+CacheTuple = tuple[CacheModel, int]
+
+OrderBy: t.TypeAlias = str | Expression
+
+T_annotation = t.Type[t.Any] | types.UnionType
diff --git a/src/typedal/web2py_py4web_shared.py b/src/typedal/web2py_py4web_shared.py
index 1805c32..91c89bb 100644
--- a/src/typedal/web2py_py4web_shared.py
+++ b/src/typedal/web2py_py4web_shared.py
@@ -6,7 +6,7 @@
 from pydal.validators import CRYPT, IS_EMAIL, IS_NOT_EMPTY, IS_NOT_IN_DB, IS_STRONG
 
-from .core import TypeDAL, TypedField, TypedTable
+from . import TypeDAL, TypedField, TypedTable
 
 from .fields import PasswordField
diff --git a/tests/py314_tests.py b/tests/py314_tests.py
new file mode 100644
index 0000000..0b625cc
--- /dev/null
+++ b/tests/py314_tests.py
@@ -0,0 +1,146 @@
+import sqlite3
+
+import pytest
+
+from src.typedal import TypeDAL
+from src.typedal.helpers import process_tstring, sql_escape, sql_escape_template, sql_expression
+
+
+def test_process_tstring_basic(database: TypeDAL):
+    """Test the basic f-string functionality example from process_tstring docstring."""
+
+    def fstring_operation(interpolation):
+        return str(interpolation.value)
+
+    value = "test"
+    template = t"{value = }"
+    result = process_tstring(template, fstring_operation)
+    assert result == "value = test"
+
+
+def test_sql_escape_template_security(database: TypeDAL):
+    """Test the SQL injection prevention example from sql_escape_template docstring."""
+    user_input = "'; DROP TABLE users; --"
+    query = t"SELECT * FROM users WHERE name = {user_input}"
+    safe_query = sql_escape_template(database, query)
+
+    # The exact escaping format depends on the database adapter, but it should be escaped
+    assert "DROP TABLE users" in safe_query  # The dangerous part should still be there
+    assert safe_query != f"SELECT * FROM users WHERE name = {user_input}"  # But it should be escaped
+    # For most SQL adapters, strings are wrapped in quotes and escaped
+    assert "'" in safe_query or '"' in safe_query  # Should have some form of quoting
+
+
+def test_sql_escape_positional_example(database: TypeDAL):
+    """Test the positional arguments example from sql_escape docstring."""
+    user_id = 123
+    safe_sql = sql_escape(database, "SELECT * FROM users WHERE id = %s", user_id)
+    assert safe_sql == "SELECT * FROM users WHERE id = 123"
+
+
+def test_sql_escape_keyword_example(database: TypeDAL):
+    """Test the keyword arguments example from sql_escape docstring."""
+    username = "john_doe"
+    safe_sql = sql_escape(database, "SELECT * FROM users WHERE name = %(name)s", name=username)
+    assert safe_sql == "SELECT * FROM users WHERE name = 'john_doe'"
+
+
+def test_sql_escape_template_example(database: TypeDAL):
+    """Test the Template string example from sql_escape docstring."""
+    user_id = 456
+    safe_sql = sql_escape(database, t"SELECT * FROM users WHERE id = {user_id}")
+    assert safe_sql == "SELECT * FROM users WHERE id = 456"
+
+
+def test_sql_expression_complex_where(database: TypeDAL):
+    """Test the complex WHERE clause example from sql_expression docstring."""
+    expr = sql_expression(database, "age > %s AND status = %s", 18, "active", output_type="boolean")
+
+    expected = "age > 18 AND status = 'active'"
+    assert str(expr) == expected
+    assert expr.type == "boolean"
+
+
+def test_sql_expression_keyword_extract(database: TypeDAL):
+    """Test the keyword arguments EXTRACT example from sql_expression docstring."""
+    expr = sql_expression(
+        database, "EXTRACT(year FROM %(date_col)s) = %(year)s", date_col="created_at", year=2023, output_type="boolean"
+    )
+
+    expected = "EXTRACT(year FROM 'created_at') = 2023"
+    assert str(expr) == expected
+    assert expr.type == "boolean"
+
+
+def test_sql_expression_template_age(database: TypeDAL):
+    """Test the Template string age example from sql_expression docstring."""
+    min_age = 21
+    expr = sql_expression(database, t"age >= {min_age}", output_type="boolean")
+
+    expected = "age >= 21"
+    assert str(expr) == expected
+    assert expr.type == "boolean"
+
+
+def test_date_expression_similar_to_other_test(database: TypeDAL):
+    start_date = "2025-01-01"
+    expr1 = database.sql_expression(t"date('now') > {start_date}")
+    assert str(expr1) == "date('now') > '2025-01-01'"
+
+
+def test_executesql_without_tstring(database: TypeDAL):
+    bobby_tables = "Robert'); DROP TABLE Students;--"
+
+    database.executesql(f"""
+        CREATE TABLE hackable (
+            name VARCHAR(100)
+        )
+    """)
+
+    with pytest.raises(sqlite3.OperationalError):
+        database.executesql(f"INSERT INTO hackable(name) VALUES ({bobby_tables})")
+
+    with pytest.raises(sqlite3.OperationalError):
+        database.executesql(f"SELECT * FROM hackable where name = {bobby_tables}")
+
+
+def test_executesql_with_tstring(database: TypeDAL):
+    bobby_tables = "Robert'); DROP TABLE Students;--"
+
+    database.executesql(t"""
+        CREATE TABLE unhackable (
+            name VARCHAR(100)
+        )
+    """)
+
+    database.executesql(t"INSERT INTO unhackable(name) VALUES ({bobby_tables})")
+
+    rows = database.executesql(t"SELECT * FROM unhackable where name = {bobby_tables}")
+
+    assert len(rows) == 1
+    assert rows[0][0] == bobby_tables
+
+    # alternative using magic:
+    name = bobby_tables
+    rows = database.executesql(t"SELECT * FROM unhackable where {name = }")
+
+    assert len(rows) == 1
+    assert rows[0][0] == bobby_tables
+
+
+def test_sql_expression_314(database: TypeDAL):
+    """Main test function that calls all example tests to verify docstring examples."""
+    # Call all the docstring example tests
+    test_process_tstring_basic(database)
+    test_sql_escape_template_security(database)
+    test_sql_escape_positional_example(database)
+    test_sql_escape_keyword_example(database)
+    test_sql_escape_template_example(database)
+    test_sql_expression_complex_where(database)
+    test_sql_expression_keyword_extract(database)
+    test_sql_expression_template_age(database)
+    # + the one similar to the non-tstring test:
+    test_date_expression_similar_to_other_test(database)
+    # executesql with string:
+    test_executesql_without_tstring(database)
+    test_executesql_with_tstring(database)
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
index 8297d78..779bee8 100644
--- a/tests/test_helpers.py
+++ b/tests/test_helpers.py
@@ -1,5 +1,5 @@
+import sys
 import typing
-from datetime import datetime, timedelta
 
 import pydal
 import pytest
@@ -8,6 +8,7 @@
 from src.typedal import TypeDAL, TypedTable, sql_expression
 from src.typedal.caching import get_expire
 from src.typedal.helpers import (
+    SYSTEM_SUPPORTS_TEMPLATES,
    DummyQuery,
     all_annotations,
     as_lambda,
@@ -26,6 +27,7 @@
 )
 from src.typedal.types import Field
 
+import datetime as dt
 
 def test_is_union():
     assert is_union(int | str)
@@ -140,11 +142,11 @@
 
 def test_get_expire():
-    now = datetime(year=2023, hour=12, minute=1, second=1, month=1, day=1)
+    now = dt.datetime(year=2023, hour=12, minute=1, second=1, month=1, day=1)
 
     assert get_expire() is None
-    assert get_expire(ttl=2, now=now) == datetime(year=2023, hour=12, minute=1, second=3, month=1, day=1)
-    assert get_expire(ttl=timedelta(seconds=2), now=now) == datetime(
+    assert get_expire(ttl=2, now=now) == dt.datetime(year=2023, hour=12, minute=1, second=3, month=1, day=1)
+    assert get_expire(ttl=dt.timedelta(seconds=2), now=now) == dt.datetime(
         year=2023, hour=12, minute=1, second=3, month=1, day=1
     )
@@ -209,6 +211,26 @@
     assert isinstance(field, Field)
 
 
+def test_forward_reference_annotation_314():
+    if sys.version_info.minor < 14:
+        return
+
+    class WithForwardRef:
+        fwd: Future
+
+    class Future: ...
+ + assert all_annotations(WithForwardRef) + + print(all_annotations(WithForwardRef)) + + class WithFakeForwardRef: + fwd: Fake + + with pytest.raises(NameError): + all_annotations(WithFakeForwardRef) + + def test_sql_expression(): # note: only %s works since .adapt does something like # -> "'%s'" % obj.replace("'", "''") @@ -244,3 +266,9 @@ class TestSqlExpression(TypedTable): # test quoting fields and tables: assert str(database.sql_expression("LOWER(%s)", TestSqlExpression.value)) == 'LOWER("test_sql_expression"."value")' assert str(database.sql_expression("LOWER(%s.value)", TestSqlExpression)) == 'LOWER("test_sql_expression".value)' + +@pytest.mark.skipif(not SYSTEM_SUPPORTS_TEMPLATES, reason="t-strings contain breaking syntax!") +def test_sql_expression_314(): + from .py314_tests import test_sql_expression_314 + + test_sql_expression_314(database) diff --git a/tests/test_main.py b/tests/test_main.py index 9bb0f31..a18e201 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,14 +1,15 @@ import re +import typing +import sys from copy import copy from sqlite3 import IntegrityError +from typing import ForwardRef -import pydal import pytest -from src.typedal import * +from src.typedal import TypedRows from src.typedal.__about__ import __version__ from src.typedal.fields import * -from typedal.types import Expression def test_about(): @@ -580,6 +581,45 @@ class SomeTableToRetry(TypedTable): assert db.try_define(SomeTableToRetry, verbose=True) +def test_forward_reference_class_314(): + if sys.version_info.minor < 14: + return + + class WithForwardRef(TypedTable): + fwd: Future + + class WithFakeRef(TypedTable): + fwd: Fake + + class Future(TypedTable): ... + + # note: this still has to be defined first because otherwise pydal can't create a database relation!: + assert db.define(Future) + + assert db.define(WithForwardRef) + + with pytest.raises(NameError): + assert db.define(WithFakeRef) + + +def test_forward_reference_class_explicit(): + class ExplicitWithForwardRef(TypedTable): + fwd: ForwardRef("ExplicitFuture") + + class WithFakeRef(TypedTable): + fwd: ForwardRef("Fake") + + class ExplicitFuture(TypedTable): ... 
+ + # note: this still has to be defined first because otherwise pydal can't create a database relation!: + assert db.define(ExplicitFuture) + + assert db.define(ExplicitWithForwardRef) + + with pytest.raises(NameError): + assert db.define(WithFakeRef) + + def test_reorder_fields(): @db.define() class Base(TypedTable): diff --git a/tests/test_mixins.py b/tests/test_mixins.py index e2fc546..facc8fd 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -1,12 +1,9 @@ import time -import uuid -from datetime import datetime from typing import Optional import pytest from src.typedal import TypeDAL, TypedTable -from src.typedal.fields import StringField, TypedField, UUIDField from src.typedal.mixins import Mixin, SlugMixin, TimestampsMixin diff --git a/tests/test_mypy.py b/tests/test_mypy.py index 063e597..4b60108 100644 --- a/tests/test_mypy.py +++ b/tests/test_mypy.py @@ -1,3 +1,4 @@ +import sys import typing import pydal.objects @@ -43,16 +44,16 @@ def mypy_test_typedal_define() -> None: reveal_type(MyTable.normal) # R: builtins.str reveal_type(MyTable().normal) # R: builtins.str - reveal_type(MyTable.fancy) # R: typedal.core.TypedField[builtins.str] + reveal_type(MyTable.fancy) # R: typedal.fields.TypedField[builtins.str] reveal_type(MyTable().fancy) # R: builtins.str - reveal_type(MyTable.options) # R: typedal.core.TypedField[builtins.str] + reveal_type(MyTable.options) # R: typedal.fields.TypedField[builtins.str] reveal_type(MyTable().options) # R: builtins.str reveal_type(MyTable.fancy.lower()) # R: typedal.types.Expression reveal_type(MyTable().fancy.lower()) # R: builtins.str aliased_cls = MyTable.with_alias("---") - (reveal_type(aliased_cls),) # R: type[tests.test_mypy.MyTable] + reveal_type(aliased_cls) # R: type[tests.test_mypy.MyTable] aliased_instance = aliased_cls() reveal_type(aliased_instance) # R: tests.test_mypy.MyTable @@ -80,10 +81,10 @@ def somefunc_err(row: str, _: Reference) -> None: ... @pytest.mark.mypy_testing -def test_update() -> None: +def test_update_modern_union() -> None: query: pydal.objects.Query = MyTable.id == 3 new = MyTable.update(query) - reveal_type(new) # R: Union[tests.test_mypy.MyTable, None] + reveal_type(new) # R: tests.test_mypy.MyTable | None inst = MyTable(3) # could also actually be None! 
reveal_type(inst) # R: tests.test_mypy.MyTable @@ -97,7 +98,7 @@ def test_update() -> None: @pytest.mark.mypy_testing -def mypy_test_typedset() -> None: +def mypy_test_typedset_modern_union() -> None: counted1 = db(MyTable).count() counted2 = db(db.old_style).count() counted3 = db(old_style).count() @@ -112,13 +113,13 @@ def mypy_test_typedset() -> None: select2: TypedRows[MyTable] = db(MyTable).select() select3 = MyTable.select().collect() - reveal_type(select1) # R: typedal.core.TypedRows[Any] - reveal_type(select2) # R: typedal.core.TypedRows[tests.test_mypy.MyTable] - reveal_type(select3) # R: typedal.core.TypedRows[tests.test_mypy.MyTable] + reveal_type(select1) # R: typedal.rows.TypedRows[Any] + reveal_type(select2) # R: typedal.rows.TypedRows[tests.test_mypy.MyTable] + reveal_type(select3) # R: typedal.rows.TypedRows[tests.test_mypy.MyTable] - reveal_type(select1.first()) # R: Union[Any, None] - reveal_type(select2.first()) # R: Union[tests.test_mypy.MyTable, None] - reveal_type(select3.first()) # R: Union[tests.test_mypy.MyTable, None] + reveal_type(select1.first()) # R: Any | None + reveal_type(select2.first()) # R: tests.test_mypy.MyTable | None + reveal_type(select3.first()) # R: tests.test_mypy.MyTable | None for row in select2: reveal_type(row) # R: tests.test_mypy.MyTable diff --git a/tests/test_orm.py b/tests/test_orm.py index 4ed8372..6fa6a9e 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -1,6 +1,5 @@ import typing import uuid -from collections import ChainMap from typing_extensions import reveal_type @@ -9,36 +8,11 @@ T_MetaInstance = typing.TypeVar("T_MetaInstance") -def _all_annotations(cls: type) -> ChainMap[str, type]: - """ - Returns a dictionary-like ChainMap that includes annotations for all \ - attributes defined in cls or inherited from superclasses. - """ - return ChainMap(*(c.__annotations__ for c in getattr(cls, "__mro__", []) if "__annotations__" in c.__dict__)) - - -def all_annotations(cls: type, _except: typing.Optional[typing.Iterable[str]] = None) -> dict[str, type]: - """ - Wrapper around `_all_annotations` that filters away any keys in _except. - - It also flattens the ChainMap to a regular dict. - """ - if _except is None: - _except = set() - - _all = _all_annotations(cls) - return {k: v for k, v in _all.items() if k not in _except} - - T_Table = typing.TypeVar("T_Table", bound=TypedTable) TypeTable = typing.Type[T_Table] - T_Value = typing.TypeVar("T_Value") # actual type of the Field (via Generic) -# T_Table = typing.TypeVar("T_Table") # typevar used by __get__ - - ### db = TypeDAL("sqlite:memory:") diff --git a/tests/test_query_builder.py b/tests/test_query_builder.py index 2fef068..9dbbd40 100644 --- a/tests/test_query_builder.py +++ b/tests/test_query_builder.py @@ -521,8 +521,7 @@ def test_collect_with_extra_fields(): assert builder.execute() - class HTTP(BaseException): - ... + class HTTP(BaseException): ... row = builder.first_or_fail(HTTP(404)) diff --git a/tests/test_relationships.py b/tests/test_relationships.py index 84fedc7..d1a4cf7 100644 --- a/tests/test_relationships.py +++ b/tests/test_relationships.py @@ -87,7 +87,8 @@ class Tagged(TypedTable): # pivot table @db.define() -class Empty(TypedTable): ... +class Empty(TypedTable): + ... 
 def _setup_data():
@@ -108,7 +109,7 @@ def _setup_data():
             {"name": "Reader 1", "roles": [reader], "main_role": reader, "extra_roles": []},
             {"name": "Writer 1", "roles": [reader, writer], "main_role": writer, "extra_roles": []},
             {"name": "Editor 1", "roles": [reader, writer, editor], "main_role": editor, "extra_roles": []},
-        ]
+        ],
     )
 
     # no relationships:
@@ -122,7 +123,7 @@ def _setup_data():
         [
             {"title": "Article 1", "author": writer, "final_editor": editor},
             {"title": "Article 2", "author": editor, "secondary_author": editor},
-        ]
+        ],
     )
 
     # tags
@@ -134,7 +135,7 @@ def _setup_data():
             {"name": "breaking-news"},
             {"name": "trending"},
             {"name": "off-topic"},
-        ]
+        ],
    )
 
     # tagged
@@ -150,7 +151,7 @@ def _setup_data():
             {"entity": writer.gid, "tag": tag_trending},
             # tags
             {"entity": tag_offtopic.gid, "tag": tag_draft},
-        ]
+        ],
     )
 
     BestFriend.insert(friend=reader, name="Reader's Bestie")
@@ -296,7 +297,8 @@ def test_typedal_way():
     author1 = User.where(id=4).join().first()
 
     assert (
-        len(author1.as_dict()["articles"]) == len(author1.__dict__["articles"]) == len(dict(author1)["articles"]) == 2
+        len(author1.as_dict()["articles"]) == len(author1.__dict__["articles"]) == len(
+            dict(author1)["articles"]) == 2
     )
 
@@ -384,7 +386,7 @@ def test_join_with_different_condition():
     assert role_with_users.users[0].name == "Reader 1"
 
     role_with_users = Role.join(
-        "users", method="inner", condition_and=lambda role, user: ~user.name.like("Reader%")
+        "users", method="inner", condition_and=lambda role, user: ~user.name.like("Reader%"),
     ).first()
 
     assert role_with_users.users
@@ -392,7 +394,7 @@ def test_join_with_different_condition():
 
     # left:
     role_with_users = Role.join(
-        "users", method="left", condition_and=lambda role, user: ~user.name.like("Reader%")
+        "users", method="left", condition_and=lambda role, user: ~user.name.like("Reader%"),
     ).first()
 
     assert role_with_users.users
@@ -429,12 +431,12 @@ def test_caching():
     cached_user_only2 = User.join().cache(User.id).collect_or_fail()
 
     assert (
-        len(uncached2)
-        == len(uncached)
-        == len(cached2)
-        == len(cached)
-        == len(cached_user_only2)
-        == len(cached_user_only)
+            len(uncached2)
+            == len(uncached)
+            == len(cached2)
+            == len(cached)
+            == len(cached_user_only2)
+            == len(cached_user_only)
     )
 
     assert uncached.as_json() == uncached2.as_json() == cached.as_json() == cached2.as_json()
@@ -442,9 +444,9 @@ def test_caching():
     assert cached.first().gid == cached2.first().gid
 
     assert (
-        [_.name for _ in uncached2.first().roles]
-        == [_.name for _ in cached.first().roles]
-        == [_.name for _ in cached2.first().roles]
+            [_.name for _ in uncached2.first().roles]
+            == [_.name for _ in cached.first().roles]
+            == [_.name for _ in cached2.first().roles]
     )
 
     assert not uncached2.metadata.get("cache", {}).get("enabled")
@@ -552,7 +554,7 @@ def test_caching_dependencies():
         [
             {"name": "een"},
             {"name": "twee"},
-        ]
+        ],
     )
 
     CacheTwoRelationships.insert(first=first_one, second=second_one)
@@ -582,7 +584,6 @@ def test_caching_dependencies():
 
 def test_illegal():
     with pytest.raises(ValueError), pytest.warns(UserWarning):
-
         class HasRelationship:
             something = relationship("...", condition=lambda: 1, on=lambda: 2)
@@ -642,3 +643,46 @@ def test_accessing_raw_data():
 
     assert {row.user.id for row in user._rows} == {4}
     assert {row.articles.id for row in user._rows} == {1, 2}
+
+
+def test_nested_relationships():
+    _setup_data()
+
+    # old:
+    users = Role.where(name="reader").join("users").first().users
+    # 2 queries
+    old_besties = {
+        user.name: user.bestie.name if user.bestie else "-"
+        for user in User.where(User.id.belongs(u.id for u in users)).join("bestie").orderby(User.name)
+    }
+
+    # new:
+
+    new_besties = {
+        user.name: user.bestie.name if user.bestie else "-"
+        for user in Role.where(name="reader").join("users.bestie", "users.articles").first().users  # 1 query
+    }
+
+    # check:
+
+    assert old_besties == new_besties == {"Editor 1": "-", "Reader 1": "Reader's Bestie", "Writer 1": "-"}
+
+    # more complex:
+    role = Role.where(name="reader").join(
+        "users.bestie", "users.articles.final_editor", "users.articles.secondary_author"
+    )
+
+    nested_article = role.first().users[2].articles[0]
+
+    assert nested_article.title == "Article 2"
+
+    assert nested_article.secondary_author
+    assert not nested_article.final_editor
+
+    # complex, inner:
+    role_inner = Role.where(name="reader").join(
+        "users.bestie", "users.articles.final_editor", "users.articles.secondary_author", method="inner"
+    )
+
+    # no final_editor -> inner join should fail:
+    assert not role_inner.first()
\ No newline at end of file
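
A note on the `*_once` hook helpers added to tables.py above: `_hook_once` wraps the callback with `functools.wraps` and removes the wrapper from the hook list in a `finally` block, so the hook fires at most once even if it raises. A minimal sketch of the intended usage, assuming a hypothetical model `Thing` and an in-memory database:

```python
from src.typedal import TypeDAL, TypedTable

db = TypeDAL("sqlite:memory:")


class Thing(TypedTable):
    name: str


db.define(Thing)


def announce(row) -> None:
    # receives the pending row; returning False would veto the insert
    print("first insert only:", row.name)


Thing.before_insert_once(announce)

Thing.insert(name="one")  # fires announce, which then unregisters itself
Thing.insert(name="two")  # no hook fires anymore
```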
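The new `reorder_fields` method on the metaclass (delegating to the helper of the same name) can then be used per model; a sketch, reusing the hypothetical `Thing` from the previous snippet:

```python
# keep_others=True (the default) keeps any unmentioned fields
# after the ones given, in their original order:
Thing.reorder_fields("id", "name", keep_others=True)

# keep_others=False would instead drop every field not listed.
```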
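The new tests/py314_tests.py file doubles as documentation for the PEP 750 template-string support: interpolations in a `t"..."` literal are escaped through the adapter before the SQL ever reaches the database. Condensed from the tests above (Python 3.14+ only, since t-strings are new syntax):

```python
from src.typedal import TypeDAL
from src.typedal.helpers import sql_escape, sql_expression

database = TypeDAL("sqlite:memory:")
database.executesql("CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR(100))")

payload = "Robert'); DROP TABLE Students;--"

# t-string interpolations are escaped, so this is safe to execute:
database.executesql(t"INSERT INTO users(name) VALUES ({payload})")
rows = database.executesql(t"SELECT * FROM users WHERE name = {payload}")

# the same escaping is available without t-strings via %s placeholders:
safe_sql = sql_escape(database, "SELECT * FROM users WHERE id = %s", 123)

# and sql_expression builds a typed pydal Expression from either form:
expr = sql_expression(database, t"id >= {1}", output_type="boolean")
```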
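The forward-reference tests in test_main.py pin down the ordering rule for deferred annotations: the referenced table must be *defined* first (pydal needs it to build the relation), but it no longer has to exist when the referencing class body runs -- implicitly on 3.14, or via an explicit `typing.ForwardRef` on older versions. Roughly, with hypothetical `Book`/`Author` models:

```python
from typing import ForwardRef

from src.typedal import TypeDAL, TypedTable

db = TypeDAL("sqlite:memory:")


class Book(TypedTable):
    title: str
    author: ForwardRef("Author")  # "Author" does not exist yet at this point


class Author(TypedTable):
    name: str


db.define(Author)  # the target still has to be defined first
db.define(Book)    # now the forward reference resolves to a real table
```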
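Finally, `test_nested_relationships` demonstrates the dotted-path joins this diff builds toward: `.join("users.bestie", "users.articles")` pulls a relationship-of-a-relationship in one query instead of a query per level. In terms of the models used by the test suite:

```python
# one query: Role -> users -> each user's bestie and articles
role = Role.where(name="reader").join("users.bestie", "users.articles").first()

for user in role.users:
    print(user.name, user.bestie.name if user.bestie else "-")

# method="inner" applies to the whole path, so a missing final_editor
# anywhere along it drops the role from the result entirely.
```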