diff --git a/sqlalchemy_history/reverter.py b/sqlalchemy_history/reverter.py index dca6dd1..d529caf 100644 --- a/sqlalchemy_history/reverter.py +++ b/sqlalchemy_history/reverter.py @@ -67,10 +67,11 @@ def revert_relationship(self, prop): value = self.revert_child(child_obj, prop) if value: values.append(value) - - for value in getattr(self.version_parent, prop.key, []): + collection = getattr(self.version_parent, prop.key, []) + for value in list(collection): if value not in values: self.session.delete(value) + collection.remove(value) else: self.revert_child(getattr(self.obj, prop.key), prop) diff --git a/sqlalchemy_history/unit_of_work.py b/sqlalchemy_history/unit_of_work.py index 9f09fbe..806e479 100644 --- a/sqlalchemy_history/unit_of_work.py +++ b/sqlalchemy_history/unit_of_work.py @@ -5,7 +5,7 @@ import sqlalchemy as sa from sqlalchemy_utils import get_primary_keys, identity -from sqlalchemy_history.operation import Operations +from sqlalchemy_history.operation import Operation, Operations from sqlalchemy_history.schema import update_end_tx_column from sqlalchemy_history.utils import ( end_tx_column_name, @@ -301,9 +301,13 @@ def assign_attributes(self, parent_obj, version_obj): :param version_obj: Version object to assign the attribute values to """ + state = sa.inspect(parent_obj) for prop in versioned_column_properties(parent_obj): - try: - value = getattr(parent_obj, prop.key) - except sa.orm.exc.ObjectDeletedError: + if version_obj.operation_type == Operation.DELETE and prop.key in state.unloaded: value = None + else: + try: + value = getattr(parent_obj, prop.key) + except sa.orm.exc.ObjectDeletedError: + value = None setattr(version_obj, prop.key, value) diff --git a/tests/revert/test_polymorphic_relationship.py b/tests/revert/test_polymorphic_relationship.py new file mode 100644 index 0000000..785cf95 --- /dev/null +++ b/tests/revert/test_polymorphic_relationship.py @@ -0,0 +1,65 @@ +import sqlalchemy as sa +from tests import TestCase + + +class TestRevertPolymorphicRelationship(TestCase): + def create_models(self): + class Car(self.Model): + __tablename__ = "car" + __versioned__ = {} + + id = sa.Column( + sa.Integer, sa.Sequence(f"{__tablename__}_seq", start=1), autoincrement=True, primary_key=True + ) + parts = sa.orm.relationship( + "Part", back_populates="car", cascade="all, delete-orphan", lazy="selectin" + ) + + class Part(self.Model): + __tablename__ = "part" + __versioned__ = {} + + id = sa.Column( + sa.Integer, sa.Sequence(f"{__tablename__}_seq", start=1), autoincrement=True, primary_key=True + ) + car_id = sa.Column(sa.Integer, sa.ForeignKey(Car.id)) + car = sa.orm.relationship(Car, back_populates="parts") + + type = sa.Column(sa.String(50)) + + __mapper_args__ = { + "polymorphic_identity": "part", + "polymorphic_on": type, + } + + class Tire(Part): + __tablename__ = "tire" + __versioned__ = {} + + id = sa.Column(sa.Integer, sa.ForeignKey(Part.id), primary_key=True) + radius = sa.Column(sa.Integer) + width = sa.Column(sa.Integer) + + __mapper_args__ = { + "polymorphic_identity": "tire", + } + + self.Car = Car + self.Part = Part + self.Tire = Tire + + def test_revert_polymorphic_relationship(self): + car = self.Car() + self.session.add(car) + self.session.commit() + + tire = self.Tire(radius=15, width=200) + car.parts.append(tire) + self.session.commit() + + initial_version = car.versions.all()[0] + reverted_car = initial_version.revert(relations=["parts"]) + + assert len(reverted_car.parts) == 0 + + self.session.flush()