diff --git a/docs/queries.md b/docs/queries.md index e1ca8f3..54aefbc 100644 --- a/docs/queries.md +++ b/docs/queries.md @@ -4,16 +4,18 @@ You can query history models just like any other sqlalchemy declarative model. ```python >>> from sqlalchemy_history import version_class +>>> import sqlalchemy as sa >>> ArticleVersion = version_class(Article) ->>> session.query(ArticleVersion).filter_by(name=u'some name').all() +>>> session.scalars(sa.select(ArticleVersion).filter_by(name=u'some name')).all() ``` ## How many transactions have been executed? ```python >>> from sqlalchemy_history import transaction_class +>>> import sqlalchemy as sa >>> Transaction = transaction_class(Article) ->>> Transaction.query.count() +>>> session.scalar(sa.select(sa.func.count()).select_from(Transaction)) ``` ## Querying for entities of a class at a given revision @@ -21,7 +23,7 @@ You can query history models just like any other sqlalchemy declarative model. In the following example we find all articles which were affected by transaction 33. ```python ->>> session.query(ArticleVersion).filter_by(transaction_id=33) +>>> session.scalars(sa.select(ArticleVersion).filter_by(transaction_id=33)).all() ``` ## Querying for transactions, at which entities of a given class changed @@ -30,13 +32,14 @@ In this example we find all transactions which affected any instance of 'Article ```python >>> TransactionChanges = Article.__versioned__['transaction_changes'] ->>> entries = ( -... session.query(Transaction) -... .innerjoin(Transaction.changes) -... .filter( +>>> statement = ( +... sa.select(Transaction) +... .join(Transaction.changes) +... .where( ... TransactionChanges.entity_name.in_(['Article']) ... ) ... ) +... entries = session.scalars(statement).all() ``` ## Querying for versions of entity that modified given property @@ -46,5 +49,5 @@ PropertyModTrackerPlugin. ```python >>> ArticleVersion = version_class(Article) ->>> session.query(ArticleHistory).filter(ArticleVersion.name_mod).all() +>>> session.scalars(sa.select(ArticleHistory).filter(ArticleVersion.name_mod)).all() ``` diff --git a/docs/revert.md b/docs/revert.md index 4aac48d..8c2b27c 100644 --- a/docs/revert.md +++ b/docs/revert.md @@ -28,7 +28,7 @@ One of the major benefits of SQLAlchemy-History is its ability to revert changes >>> session.commit() >>> version.revert() >>> session.commit() # article lives again! ->>> session.query(Article).first() +>>> session.scalars(sa.select(Article).limit(1)).first() ``` ## Revert relationships @@ -64,7 +64,7 @@ Now lets say some user first adds an article with couple of tags: Then lets say another user deletes one of the tags: ```python ->>> tag = session.query(Tag).filter_by(name=u'Interesting') +>>> tag = session.scalar(sa.select(Tag).where(Tag.name == "Interesting")) >>> session.delete(tag) >>> session.commit() ``` @@ -72,7 +72,7 @@ Then lets say another user deletes one of the tags: Now the first user wants to set the article back to its original state. It can be achieved as follows (notice how we use the relations parameter): ```python ->>> article = session.query(Article).get(1) +>>> article = session.get(Article, 1) >>> article.versions[0].revert(relations=['tags']) >>> session.commit() ``` diff --git a/docs/transactions.md b/docs/transactions.md index 509ce09..c677b01 100644 --- a/docs/transactions.md +++ b/docs/transactions.md @@ -9,7 +9,7 @@ Transaction can be queried just like any other sqlalchemy declarative model. ```python >>> from sqlalchemy_history import transaction_class >>> Transaction = transaction_class(Article) ->>> session.query(Transaction).all() # find all transactions +>>> session.scalars(sa.select(Transaction)).all() # find all transactions ``` ## UnitOfWork diff --git a/sqlalchemy_history/fetcher.py b/sqlalchemy_history/fetcher.py index b6c18e0..4abf4aa 100644 --- a/sqlalchemy_history/fetcher.py +++ b/sqlalchemy_history/fetcher.py @@ -35,7 +35,8 @@ def previous(self, obj): history. If current version is the first version this method returns None. """ - return self.previous_query(obj).first() + session = sa.orm.object_session(obj) + return session.scalars(self.previous_query(obj).limit(1)).first() def index(self, obj): """ @@ -50,7 +51,8 @@ def next(self, obj): history. If current version is the last version this method returns None. """ - return self.next_query(obj).first() + session = sa.orm.object_session(obj) + return session.scalars(self.next_query(obj).limit(1)).first() def _transaction_id_subquery(self, obj, next_or_prev="next", alias=None): if next_or_prev == "next": @@ -91,12 +93,10 @@ def _transaction_id_subquery(self, obj, next_or_prev="next", alias=None): return query.scalar_subquery() def _next_prev_query(self, obj, next_or_prev="next"): - session = sa.orm.object_session(obj) - subquery = self._transaction_id_subquery(obj, next_or_prev=next_or_prev) subquery = subquery.scalar_subquery() - return session.query(obj.__class__).filter( + return sa.select(obj.__class__).filter( sa.and_(getattr(obj.__class__, tx_column_name(obj)) == subquery, *parent_criteria(obj)) ) @@ -145,9 +145,7 @@ def next_query(self, obj): Returns the query that fetches the next version relative to this version in the version history. """ - session = sa.orm.object_session(obj) - - return session.query(obj.__class__).filter( + return sa.select(obj.__class__).filter( sa.and_( getattr(obj.__class__, tx_column_name(obj)) == getattr(obj, end_tx_column_name(obj)), *parent_criteria(obj), @@ -159,9 +157,7 @@ def previous_query(self, obj): Returns the query that fetches the previous version relative to this version in the version history. """ - session = sa.orm.object_session(obj) - - return session.query(obj.__class__).filter( + return sa.select(obj.__class__).filter( sa.and_( getattr(obj.__class__, end_tx_column_name(obj)) == getattr(obj, tx_column_name(obj)), *parent_criteria(obj), diff --git a/sqlalchemy_history/plugins/activity.py b/sqlalchemy_history/plugins/activity.py index 3ff9805..885583e 100644 --- a/sqlalchemy_history/plugins/activity.py +++ b/sqlalchemy_history/plugins/activity.py @@ -145,12 +145,9 @@ ```python >>> import sqlalchemy as sa ->>> activities = session.query(Activity).filter( -... sa.or_( -... Activity.object == article, -... Activity.target == article -... ) -... ) +>>> activities = session.scalars( +... sa.select(Activity).filter(sa.or_(Activity.object == article, Activity.target == article)) +... ).all() ``` #### Also Read @@ -223,11 +220,13 @@ def _calculate_tx_id(self, obj): model = obj.__class__ version_cls = version_class(model) primary_key = inspect(model).primary_key[0].name - return ( - session.query(sa.func.max(version_cls.transaction_id)) - .filter(getattr(version_cls, primary_key) == getattr(obj, primary_key)) - .scalar() - ) + return session.execute( + ( + sa.select(sa.func.max(version_cls.transaction_id)).where( + getattr(version_cls, primary_key) == getattr(obj, primary_key) + ) + ) + ).scalar_one_or_none() def calculate_object_tx_id(self): self.object_tx_id = self._calculate_tx_id(self.object) diff --git a/sqlalchemy_history/plugins/transaction_meta.py b/sqlalchemy_history/plugins/transaction_meta.py index 883ade4..4137b93 100644 --- a/sqlalchemy_history/plugins/transaction_meta.py +++ b/sqlalchemy_history/plugins/transaction_meta.py @@ -36,7 +36,7 @@ # find all transactions with 'article' tags query = ( - session.query(Transaction) + sa.select(Transaction) .join(Transaction.meta_relation) .filter( db.and_( diff --git a/sqlalchemy_history/relationship_builder.py b/sqlalchemy_history/relationship_builder.py index 677f0b0..1cf34de 100644 --- a/sqlalchemy_history/relationship_builder.py +++ b/sqlalchemy_history/relationship_builder.py @@ -9,10 +9,33 @@ from sqlalchemy_history.operation import Operation from sqlalchemy_history.table_builder import TableBuilder from sqlalchemy_history.utils import adapt_columns, version_class, option +import warnings +import typing as t +from sqlalchemy.orm import Session +from sqlalchemy.orm import RelationshipProperty +from sqlalchemy.sql.selectable import ExecutableReturnsRows + + +_T = t.TypeVar("_T") + + +class _WriteOnlyCollectionAdapter(t.Generic[_T]): + """ + Minimal adapter that exposes a write-only-collection-like `select()` API + backed by a preconstructed SQLAlchemy `Select` clause. + """ + + def __init__(self, statement: sa.Select[_T]): + self._statement = statement + + def select(self) -> sa.Select[_T]: + return self._statement class RelationshipBuilder(object): - def __init__(self, versioning_manager, model, property_): + property: RelationshipProperty + + def __init__(self, versioning_manager, model, property_: RelationshipProperty): self.manager = versioning_manager self.property = property_ self.model = model @@ -55,21 +78,39 @@ def many_to_one_subquery(self, obj): subquery = subquery.scalar_subquery() return getattr(self.remote_cls, tx_column) == subquery - def query(self, obj): - session = sa.orm.object_session(obj) - return session.query(self.remote_cls).filter(self.criteria(obj)) + def select(self, obj): + return sa.select(self.remote_cls).filter(self.criteria(obj)) - def process_query(self, query): + def process_query(self, query: ExecutableReturnsRows, session: Session): """Process given SQLAlchemy Query object depending on the associated RelationshipProperty object. - :param query: SQLAlchemy Query object + This method handles both legacy Query objects (for backward compatibility with + lazy='dynamic' relationships) and modern SQLAlchemy 2.0 select statements, executing + them appropriately based on the relationship's properties. + :param query: SQLAlchemy select clause + + Notes + ----- + The lazy='dynamic' strategy is deemed legacy in SQLAlchemy and maintained here only + for backward compatibility. Users should migrate to lazy='write_only' for similar + functionality in SQLAlchemy 2.0+. + See: https://docs.sqlalchemy.org/en/20/changelog/migration_20.html#dynamic-relationship-loaders-superseded-by-write-only """ if self.property.lazy == "dynamic": - return query + warnings.warn( + f'The lazy="dynamic" strategy is now legacy and is superseded by lazy="write_only" in SQLAlchemy 2.0. ' + f"Please consider migrating to the write_only strategy for relationship {self.property.key!r}.", + DeprecationWarning, + stacklevel=2, + ) + # Build legacy Query object for backward compatibility + return session.query(self.remote_cls).from_statement(query) + elif self.property.lazy == "write_only": + return _WriteOnlyCollectionAdapter(query) if self.property.uselist is False: - return query.first() - return query.all() + return session.scalars(query.limit(1)).first() + return session.scalars(query).all() def criteria(self, obj): direction = self.property.direction @@ -222,8 +263,8 @@ def reflected_relationship(self): @property def relationship(obj): - query = self.query(obj) - return self.process_query(query) + session = sa.orm.object_session(obj) + return self.process_query(self.select(obj), session) return relationship diff --git a/sqlalchemy_history/transaction.py b/sqlalchemy_history/transaction.py index 0088f4d..860f2e3 100644 --- a/sqlalchemy_history/transaction.py +++ b/sqlalchemy_history/transaction.py @@ -51,9 +51,10 @@ def changed_entities(self): tx_column = manager.option(class_, "transaction_column_name") - entities[version_class] = ( - session.query(version_class).filter(getattr(version_class, tx_column) == self.id) + entities[version_class] = session.scalars( + sa.select(version_class).filter(getattr(version_class, tx_column) == self.id) ).all() + return entities diff --git a/sqlalchemy_history/unit_of_work.py b/sqlalchemy_history/unit_of_work.py index b3d3889..9f09fbe 100644 --- a/sqlalchemy_history/unit_of_work.py +++ b/sqlalchemy_history/unit_of_work.py @@ -229,19 +229,21 @@ def update_version_validity(self, parent, version_obj): parent, version_obj, alias=sa.orm.aliased(class_.__table__) ) subquery = subquery.scalar_subquery() - query = session.query(class_.__table__).filter( - sa.and_( - getattr(class_, tx_column_name(version_obj)) == subquery, - *[ - getattr(version_obj, pk) == getattr(class_.__table__.c, pk) - for pk in get_primary_keys(class_) - if pk != tx_column_name(class_) - ], + + session.execute( + sa.update(class_.__table__) + .where( + sa.and_( + getattr(class_, tx_column_name(version_obj)) == subquery, + *[ + getattr(version_obj, pk) == getattr(class_.__table__.c, pk) + for pk in get_primary_keys(class_) + if pk != tx_column_name(class_) + ], + ) ) - ) - query.update( - {end_tx_column_name(version_obj): self.current_transaction.id}, - synchronize_session=False, + .values(**{end_tx_column_name(version_obj): self.current_transaction.id}) + .execution_options(synchronize_session=False) ) def create_association_versions(self, session): diff --git a/sqlalchemy_history/utils.py b/sqlalchemy_history/utils.py index 30e5b36..3c0966d 100644 --- a/sqlalchemy_history/utils.py +++ b/sqlalchemy_history/utils.py @@ -228,9 +228,9 @@ def vacuum(session, model, yield_per=1000): version_cls = version_class(model) versions = defaultdict(list) - query = (session.query(version_cls).order_by(option(version_cls, "transaction_column_name"))).yield_per( - yield_per - ) + query = session.scalars( + sa.select(version_cls).order_by(option(version_cls, "transaction_column_name")) + ).yield_per(yield_per) primary_key_col = sa.inspection.inspect(model).primary_key[0].name diff --git a/tests/builders/test_table_builder.py b/tests/builders/test_table_builder.py index 59c7079..4f9205f 100644 --- a/tests/builders/test_table_builder.py +++ b/tests/builders/test_table_builder.py @@ -16,7 +16,7 @@ def test_assigns_foreign_keys_for_versions(self): self.session.add(article) self.session.commit() cls = version_class(self.Tag) - version = self.session.query(cls).first() + version = self.session.scalars(sa.select(cls)).first() assert version.name == "some tag" assert version.id == 1 assert version.article_id == 1 diff --git a/tests/inheritance/test_concrete_inheritance.py b/tests/inheritance/test_concrete_inheritance.py index ee326c8..afda068 100644 --- a/tests/inheritance/test_concrete_inheritance.py +++ b/tests/inheritance/test_concrete_inheritance.py @@ -80,8 +80,8 @@ def test_transaction_changed_entities(self): self.session.add(article) self.session.commit() Transaction = versioning_manager.transaction_cls - transaction = ( - self.session.query(Transaction).order_by(sa.sql.expression.desc(Transaction.issued_at)) + transaction = self.session.scalars( + sa.select(Transaction).order_by(sa.sql.expression.desc(Transaction.issued_at)) ).first() assert transaction.entity_names == ["Article"] assert transaction.changed_entities diff --git a/tests/inheritance/test_join_table_inheritance.py b/tests/inheritance/test_join_table_inheritance.py index 6747593..c590abc 100644 --- a/tests/inheritance/test_join_table_inheritance.py +++ b/tests/inheritance/test_join_table_inheritance.py @@ -75,7 +75,7 @@ def test_with_polymorphic(self): self.session.add(article) self.session.commit() - version_obj = self.session.query(self.TextItemVersion).first() + version_obj = self.session.scalars(sa.select(self.TextItemVersion)).first() assert isinstance(version_obj, self.ArticleVersion) def test_consecutive_insert_and_delete(self): diff --git a/tests/inheritance/test_single_table_inheritance.py b/tests/inheritance/test_single_table_inheritance.py index bde3fe1..0862d0e 100644 --- a/tests/inheritance/test_single_table_inheritance.py +++ b/tests/inheritance/test_single_table_inheritance.py @@ -78,8 +78,8 @@ def test_transaction_changed_entities(self): self.session.add(article) self.session.commit() Transaction = versioning_manager.transaction_cls - transaction = ( - self.session.query(Transaction).order_by(sa.sql.expression.desc(Transaction.issued_at)) + transaction = self.session.scalars( + sa.select(Transaction).order_by(sa.sql.expression.desc(Transaction.issued_at)) ).first() assert transaction.entity_names == ["Article"] assert transaction.changed_entities diff --git a/tests/plugins/test_activity.py b/tests/plugins/test_activity.py index 12fdc6d..adacfe8 100644 --- a/tests/plugins/test_activity.py +++ b/tests/plugins/test_activity.py @@ -62,7 +62,7 @@ def test_create_activity_with_pk(self): self.session.commit() self.create_activity(not_id_model) self.session.commit() - activity = self.session.query(versioning_manager.activity_cls).first() + activity = self.session.scalars(sa.select(versioning_manager.activity_cls)).first() assert activity assert activity.transaction_id assert activity.object == not_id_model @@ -82,7 +82,7 @@ def test_create_activity(self): self.session.flush() self.create_activity(article) self.session.commit() - activity = self.session.query(versioning_manager.activity_cls).first() + activity = self.session.scalars(sa.select(versioning_manager.activity_cls)).first() assert activity assert activity.transaction_id assert activity.object == article @@ -99,11 +99,9 @@ def test_delete_activity(self): ) self.session.add(activity) self.session.commit() - versions = ( - self.session.query(self.ArticleVersion) - .order_by(sa.desc(self.ArticleVersion.transaction_id)) - .all() - ) + versions = self.session.scalars( + sa.select(self.ArticleVersion).order_by(sa.desc(self.ArticleVersion.transaction_id)) + ).all() assert activity assert activity.transaction_id assert activity.object is None @@ -126,8 +124,8 @@ def test_activity_queries(self): ) self.session.add(activity) self.session.commit() - activities = self.session.query(Activity).filter( - sa.or_(Activity.object == article, Activity.target == article) + activities = self.session.scalars( + sa.select(Activity).filter(sa.or_(Activity.object == article, Activity.target == article)) ) assert activities.count() == 2 @@ -198,7 +196,9 @@ def test_activity_target(self): ) self.session.add(activity) self.session.commit() - activity = self.session.query(versioning_manager.activity_cls).filter_by(id=activity.id).one() + activity = self.session.scalars( + sa.select(versioning_manager.activity_cls).filter_by(id=activity.id) + ).one() assert activity assert activity.transaction_id assert activity.object == tag diff --git a/tests/plugins/test_null_delete.py b/tests/plugins/test_null_delete.py index b17c46a..2484fbb 100644 --- a/tests/plugins/test_null_delete.py +++ b/tests/plugins/test_null_delete.py @@ -1,3 +1,4 @@ +import sqlalchemy as sa from sqlalchemy_history.plugins import NullDeletePlugin from tests import TestCase @@ -15,7 +16,7 @@ def _delete(self): def test_stores_operation_type(self): self._delete() - versions = self.session.query(self.ArticleVersion).all() + versions = self.session.scalars(sa.select(self.ArticleVersion)).all() assert versions[1].operation_type == 2 @@ -24,7 +25,7 @@ class TestDeleteWithoutStoreDataAtDelete(DeleteTestCase): def test_creates_versions_on_delete(self): self._delete() - versions = self.session.query(self.ArticleVersion).all() + versions = self.session.scalars(sa.select(self.ArticleVersion)).all() assert len(versions) == 2 assert versions[1].name is None assert versions[1].content is None diff --git a/tests/plugins/test_property_mod_tracker.py b/tests/plugins/test_property_mod_tracker.py index fbf4ceb..be84df4 100644 --- a/tests/plugins/test_property_mod_tracker.py +++ b/tests/plugins/test_property_mod_tracker.py @@ -55,7 +55,9 @@ def test_mod_properties_with_delete(self): self.session.delete(user) self.session.commit() UserVersion = version_class(self.User) - version = (self.session.query(UserVersion).order_by(sa.desc(UserVersion.transaction_id))).first() + version = self.session.scalars( + sa.select(UserVersion).order_by(sa.desc(UserVersion.transaction_id)) + ).first() assert version.age_mod assert version.name_mod diff --git a/tests/plugins/test_transaction_changes.py b/tests/plugins/test_transaction_changes.py index 77225a8..03cc537 100644 --- a/tests/plugins/test_transaction_changes.py +++ b/tests/plugins/test_transaction_changes.py @@ -1,3 +1,4 @@ +import sqlalchemy as sa from sqlalchemy_history import version_class from sqlalchemy_history.plugins import TransactionChangesPlugin from tests import TestCase @@ -64,4 +65,4 @@ def test_saves_only_modified_entity_names(self): article.name = "Some article" self.session.commit() - assert self.session.query(TransactionChanges).count() == 1 + assert self.session.scalar(sa.select(sa.func.count()).select_from(TransactionChanges)) == 1 diff --git a/tests/relationships/test_dynamic_relationships.py b/tests/relationships/test_dynamic_relationships.py index 0db65c7..3073022 100644 --- a/tests/relationships/test_dynamic_relationships.py +++ b/tests/relationships/test_dynamic_relationships.py @@ -1,6 +1,7 @@ from copy import copy from tests import TestCase import sqlalchemy as sa +import pytest class TestDynamicOneToManyRelationships(TestCase): @@ -37,7 +38,8 @@ def test_reflects_dynamic_relationships_as_dynamic(self): self.session.add(article) self.session.commit() - assert article.versions[0].tags + with pytest.deprecated_call(match='The lazy="dynamic" strategy is now legacy'): + assert article.versions[0].tags class TestDynamicManyToManyRelationships(TestCase): @@ -85,4 +87,5 @@ def test_version_relations(self): article.content = "Some content" self.session.add(article) self.session.commit() - assert article.versions[0].tags + with pytest.deprecated_call(match='The lazy="dynamic" strategy is now legacy'): + assert article.versions[0].tags diff --git a/tests/relationships/test_write_only_relationships.py b/tests/relationships/test_write_only_relationships.py new file mode 100644 index 0000000..ecddeb7 --- /dev/null +++ b/tests/relationships/test_write_only_relationships.py @@ -0,0 +1,150 @@ +from copy import copy +from tests import TestCase +import sqlalchemy as sa + + +class TestWriteOnlyOneToManyRelationships(TestCase): + def create_models(self): + class Article(self.Model): + __tablename__ = "article" + __versioned__ = copy(self.options) + + id = sa.Column( + sa.Integer, sa.Sequence(f"{__tablename__}_seq", start=1), autoincrement=True, primary_key=True + ) + name = sa.Column(sa.Unicode(255), nullable=False) + content = sa.Column(sa.UnicodeText) + description = sa.Column(sa.UnicodeText) + + class Tag(self.Model): + __tablename__ = "tag" + __versioned__ = copy(self.options) + + id = sa.Column( + sa.Integer, sa.Sequence(f"{__tablename__}_seq", start=1), autoincrement=True, primary_key=True + ) + name = sa.Column(sa.Unicode(255)) + article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id)) + article = sa.orm.relationship(Article, backref=sa.orm.backref("tags", lazy="write_only")) + + self.Article = Article + self.Tag = Tag + + def test_reflects_write_only_relationships_as_write_only(self): + article = self.Article() + article.name = "Some article" + article.content = "Some content" + self.session.add(article) + self.session.commit() + + # Verify the relationship is write_only and has select() method + version = article.versions[0] + assert hasattr(version.tags, "select") + + # Verify select() returns a Select statement + select_stmt = version.tags.select() + assert isinstance(select_stmt, sa.sql.Select) + + # Execute the select to verify it's lazy and can be queried + result = self.session.execute(select_stmt).scalars().all() + assert isinstance(result, list) + + def test_write_only_relationship_with_tags(self): + article = self.Article() + article.name = "Article with tags" + article.content = "Content here" + self.session.add(article) + + tag1 = self.Tag(name="Python", article=article) + tag2 = self.Tag(name="SQLAlchemy", article=article) + self.session.add_all([tag1, tag2]) + self.session.commit() + + # Verify the version relationship is lazy + version = article.versions[0] + select_stmt = version.tags.select() + tags = self.session.execute(select_stmt).scalars().all() + + # Tags should be retrievable via select() + assert len(tags) == 2 + assert any(tag.name == "Python" for tag in tags) + assert any(tag.name == "SQLAlchemy" for tag in tags) + + +class TestWriteOnlyManyToManyRelationships(TestCase): + def create_models(self): + class Article(self.Model): + __tablename__ = "article" + __versioned__ = {"base_classes": (self.Model,)} + + id = sa.Column( + sa.Integer, sa.Sequence(f"{__tablename__}_seq", start=1), autoincrement=True, primary_key=True + ) + name = sa.Column(sa.Unicode(255)) + + article_tag = sa.Table( + "article_tag", + self.Model.metadata, + sa.Column( + "article_id", + sa.Integer, + sa.ForeignKey("article.id", ondelete="CASCADE"), + primary_key=True, + ), + sa.Column("tag_id", sa.Integer, sa.ForeignKey("tag.id", ondelete="CASCADE"), primary_key=True), + ) + + class Tag(self.Model): + __tablename__ = "tag" + __versioned__ = {"base_classes": (self.Model,)} + + id = sa.Column( + sa.Integer, sa.Sequence(f"{__tablename__}_seq", start=1), autoincrement=True, primary_key=True + ) + name = sa.Column(sa.Unicode(255)) + + Tag.articles = sa.orm.relationship( + Article, secondary=article_tag, backref=sa.orm.backref("tags", lazy="write_only") + ) + + self.Article = Article + self.Tag = Tag + + def test_version_relations(self): + article = self.Article() + article.name = "Some article" + self.session.add(article) + self.session.commit() + + # Verify the relationship is write_only and has select() method + version = article.versions[0] + assert hasattr(version.tags, "select") + + # Verify select() returns a Select statement + select_stmt = version.tags.select() + assert isinstance(select_stmt, sa.sql.Select) + + def test_write_only_many_to_many_with_data(self): + article = self.Article() + article.name = "Article about Python" + self.session.add(article) + + tag1 = self.Tag(name="Python") + tag2 = self.Tag(name="Programming") + self.session.add_all([tag1, tag2]) + self.session.flush() + + # Add tags to article using add() method + article.tags.add(tag1) + article.tags.add(tag2) + self.session.commit() + + # Verify the version relationship is write_only + version = article.versions[0] + select_stmt = version.tags.select() + tags = self.session.execute(select_stmt).scalars().all() + + # Tags should be retrievable via select() + assert len(tags) == 2 + assert any(tag.name == "Python" for tag in tags) + assert any(tag.name == "Programming" for tag in tags) diff --git a/tests/reported_bugs/test_bug_141_after_flush_postexec_op_type_issue.py b/tests/reported_bugs/test_bug_141_after_flush_postexec_op_type_issue.py index 5cbedc3..8df6eb7 100644 --- a/tests/reported_bugs/test_bug_141_after_flush_postexec_op_type_issue.py +++ b/tests/reported_bugs/test_bug_141_after_flush_postexec_op_type_issue.py @@ -30,14 +30,14 @@ def after_flush_postexec(session, flush_context): self.session.add(author) self.session.commit() - versioned_objs = self.session.query(version_class(self.Author)).all() + versioned_objs = self.session.scalars(sa.select(version_class(self.Author))).all() assert len(versioned_objs) == 1 assert versioned_objs[0].operation_type == 0 assert versioned_objs[0].name == "yoyoyoyoyo" author.name = "sdfeoinfe" self.session.add(author) self.session.commit() - versioned_objs = self.session.query(version_class(self.Author)).all() + versioned_objs = self.session.scalars(sa.select(version_class(self.Author))).all() assert len(versioned_objs) == 2 assert versioned_objs[0].operation_type == 0 assert versioned_objs[1].operation_type == 1 diff --git a/tests/reported_bugs/test_bug_27_datetime_insertion_issue.py b/tests/reported_bugs/test_bug_27_datetime_insertion_issue.py index ec87277..9f2cc10 100644 --- a/tests/reported_bugs/test_bug_27_datetime_insertion_issue.py +++ b/tests/reported_bugs/test_bug_27_datetime_insertion_issue.py @@ -56,6 +56,6 @@ def test_inserting_entries(self): self.session.add(author) self.session.commit() - obj = self.session.query(self.article_author_table).all() + obj = self.session.execute(sa.select(self.article_author_table)).all() assert len(obj) == 1 assert isinstance(obj[0][-1], datetime.datetime) # last col is a datetime! diff --git a/tests/schema/test_update_end_transaction_id.py b/tests/schema/test_update_end_transaction_id.py index 825d720..b822a03 100644 --- a/tests/schema/test_update_end_transaction_id.py +++ b/tests/schema/test_update_end_transaction_id.py @@ -126,11 +126,9 @@ def test_assoc_update_end_transaction_id(self): article.labels = [label] self.session.commit() - rows = ( - self.session.query(article_label_table_version) - .order_by(article_label_table_version.c.transaction_id) - .all() - ) + rows = self.session.execute( + sa.select(article_label_table_version).order_by(article_label_table_version.c.transaction_id) + ).all() if self.versioning_strategy == "validity": assert rows[0].label_id == label.id assert rows[0].transaction_id == 1 diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 8a461f3..b255f24 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -40,8 +40,8 @@ def test_previous_for_deleted_parent(self): self.session.commit() self.session.delete(article) self.session.commit() - versions = ( - self.session.query(self.ArticleVersion).order_by( + versions = self.session.scalars( + sa.select(self.ArticleVersion).order_by( getattr(self.ArticleVersion, self.options["transaction_column_name"]) ) ).all() @@ -57,8 +57,8 @@ def test_previous_chaining(self): self.session.commit() self.session.delete(article) self.session.commit() - version = ( - self.session.query(self.ArticleVersion).order_by( + version = self.session.scalars( + sa.select(self.ArticleVersion).order_by( getattr(self.ArticleVersion, self.options["transaction_column_name"]) ) ).all()[-1] @@ -169,8 +169,8 @@ def test_index_for_deleted_parent(self): self.session.delete(article) self.session.commit() - versions = ( - self.session.query(self.ArticleVersion).order_by( + versions = self.session.scalars( + sa.select(self.ArticleVersion).order_by( getattr(self.ArticleVersion, self.options["transaction_column_name"]) ) ).all() diff --git a/tests/test_changeset.py b/tests/test_changeset.py index 1d5bcb2..29bee27 100644 --- a/tests/test_changeset.py +++ b/tests/test_changeset.py @@ -52,7 +52,7 @@ def test_changeset_for_history_that_does_not_have_first_insert(self): ) ) - assert self.session.query(self.ArticleVersion).first().changeset == { + assert self.session.scalars(sa.select(self.ArticleVersion)).first().changeset == { "content": [None, "some content"], "id": [None, 1], "name": [None, "something"], diff --git a/tests/test_column_aliases.py b/tests/test_column_aliases.py index 79ba53f..bccd243 100644 --- a/tests/test_column_aliases.py +++ b/tests/test_column_aliases.py @@ -55,8 +55,8 @@ def test_previous_for_deleted_parent(self): self.session.commit() TextItemVersion = version_class(self.TextItem) - versions = ( - self.session.query(TextItemVersion).order_by( + versions = self.session.scalars( + sa.select(TextItemVersion).order_by( getattr(TextItemVersion, self.options["transaction_column_name"]) ) ).all() diff --git a/tests/test_delete.py b/tests/test_delete.py index cef7fe7..f1eef32 100644 --- a/tests/test_delete.py +++ b/tests/test_delete.py @@ -15,12 +15,12 @@ def _delete(self): def test_stores_operation_type(self): self._delete() - versions = self.session.query(self.ArticleVersion).all() + versions = self.session.scalars(sa.select(self.ArticleVersion)).all() assert versions[1].operation_type == 2 def test_creates_versions_on_delete(self): self._delete() - versions = self.session.query(self.ArticleVersion).all() + versions = self.session.scalars(sa.select(self.ArticleVersion)).all() assert len(versions) == 2 assert versions[1].name == "Some article" assert versions[1].content == "Some content" diff --git a/tests/test_i18n.py b/tests/test_i18n.py index 598def8..bd45fb0 100644 --- a/tests/test_i18n.py +++ b/tests/test_i18n.py @@ -59,7 +59,7 @@ def test_changed_entities(self): self.session.commit() tx_log = versioning_manager.transaction_cls - tx = self.session.query(tx_log).order_by(sa.desc(tx_log.id)).first() + tx = self.session.scalars(sa.select(tx_log).order_by(sa.desc(tx_log.id))).first() assert "ArticleTranslation" in tx.entity_names def test_history_with_many_translations(self): @@ -73,7 +73,7 @@ def test_history_with_many_translations(self): self.session.commit() Transaction = versioning_manager.transaction_cls - transaction = self.session.query(Transaction).one() + transaction = self.session.scalars(sa.select(Transaction)).one() assert len(transaction.changes) == 2 assert "ArticleTranslation" in {chng.entity_name for chng in transaction.changes} diff --git a/tests/test_insert.py b/tests/test_insert.py index fbdc377..abfe2cf 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -83,4 +83,7 @@ def test_does_not_create_transaction(self): self.session.add(item) self.session.commit() - assert self.session.query(versioning_manager.transaction_cls).count() == 0 + assert ( + self.session.scalar(sa.select(sa.func.count()).select_from(versioning_manager.transaction_cls)) + == 0 + ) diff --git a/tests/test_revert.py b/tests/test_revert.py index 14a0197..405f380 100644 --- a/tests/test_revert.py +++ b/tests/test_revert.py @@ -57,7 +57,7 @@ def test_revert_deletion(self): self.session.commit() version.revert() self.session.commit() - assert self.session.query(self.Article).count() == 1 + assert self.session.scalar(sa.select(sa.func.count()).select_from(self.Article)) == 1 article = self.session.get(self.Article, old_article_id) assert version.next.next diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 30c8de7..96fd1aa 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -28,10 +28,13 @@ def test_only_saves_transaction_if_actual_modifications(self): self.session.commit() self.article.name = "Some article" self.session.commit() - assert self.session.query(versioning_manager.transaction_cls).count() == 1 + assert ( + self.session.scalar(sa.select(sa.func.count()).select_from(versioning_manager.transaction_cls)) + == 1 + ) def test_repr(self): - transaction = self.session.query(versioning_manager.transaction_cls).first() + transaction = self.session.scalars(sa.select(versioning_manager.transaction_cls)).first() assert "" % (transaction.id, transaction.issued_at) == repr( transaction )