Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sqlalchemy_history/reverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ def revert_relationship(self, prop):
value = self.revert_child(child_obj, prop)
if value:
values.append(value)

for value in getattr(self.version_parent, prop.key, []):
collection = getattr(self.version_parent, prop.key, [])
for value in list(collection):
if value not in values:
self.session.delete(value)
collection.remove(value)
else:
self.revert_child(getattr(self.obj, prop.key), prop)

Expand Down
12 changes: 8 additions & 4 deletions sqlalchemy_history/unit_of_work.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sqlalchemy as sa
from sqlalchemy_utils import get_primary_keys, identity

from sqlalchemy_history.operation import Operations
from sqlalchemy_history.operation import Operation, Operations
from sqlalchemy_history.schema import update_end_tx_column
from sqlalchemy_history.utils import (
end_tx_column_name,
Expand Down Expand Up @@ -301,9 +301,13 @@ def assign_attributes(self, parent_obj, version_obj):
:param version_obj: Version object to assign the attribute values to

"""
state = sa.inspect(parent_obj)
for prop in versioned_column_properties(parent_obj):
try:
value = getattr(parent_obj, prop.key)
except sa.orm.exc.ObjectDeletedError:
if version_obj.operation_type == Operation.DELETE and prop.key in state.unloaded:
value = None
else:
try:
value = getattr(parent_obj, prop.key)
except sa.orm.exc.ObjectDeletedError:
value = None
setattr(version_obj, prop.key, value)
65 changes: 65 additions & 0 deletions tests/revert/test_polymorphic_relationship.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import sqlalchemy as sa
from tests import TestCase


class TestRevertPolymorphicRelationship(TestCase):
def create_models(self):
class Car(self.Model):
__tablename__ = "car"
__versioned__ = {}

id = sa.Column(
sa.Integer, sa.Sequence(f"{__tablename__}_seq", start=1), autoincrement=True, primary_key=True
)
parts = sa.orm.relationship(
"Part", back_populates="car", cascade="all, delete-orphan", lazy="selectin"
)

class Part(self.Model):
__tablename__ = "part"
__versioned__ = {}

id = sa.Column(
sa.Integer, sa.Sequence(f"{__tablename__}_seq", start=1), autoincrement=True, primary_key=True
)
car_id = sa.Column(sa.Integer, sa.ForeignKey(Car.id))
car = sa.orm.relationship(Car, back_populates="parts")

type = sa.Column(sa.String(50))

__mapper_args__ = {
"polymorphic_identity": "part",
"polymorphic_on": type,
}

class Tire(Part):
__tablename__ = "tire"
__versioned__ = {}

id = sa.Column(sa.Integer, sa.ForeignKey(Part.id), primary_key=True)
radius = sa.Column(sa.Integer)
width = sa.Column(sa.Integer)

__mapper_args__ = {
"polymorphic_identity": "tire",
}

self.Car = Car
self.Part = Part
self.Tire = Tire

def test_revert_polymorphic_relationship(self):
car = self.Car()
self.session.add(car)
self.session.commit()

tire = self.Tire(radius=15, width=200)
car.parts.append(tire)
self.session.commit()

initial_version = car.versions.all()[0]
reverted_car = initial_version.revert(relations=["parts"])

assert len(reverted_car.parts) == 0

self.session.flush()
Loading