From 215cf24ec5eff4e4d9ab8b38e637c8e6c9692895 Mon Sep 17 00:00:00 2001 From: Pieter Robberechts Date: Tue, 23 Dec 2025 10:25:50 +0100 Subject: [PATCH] feat(eventdataset): export freeze frames to tracking dataset Adds an `EventDataset.to_tracking_data` method to export the freeze frames in an EventDataset to a TrackingDataset. --------- Co-authored-by: UnravelSports [JB] --- kloppy/domain/models/event.py | 18 +++++- .../serializers/event/statsbomb/helpers.py | 2 +- kloppy/tests/test_statsbomb.py | 61 ++++++++++++++++++- 3 files changed, 77 insertions(+), 4 deletions(-) diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index 5c597c425..caf494cc2 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -34,7 +34,7 @@ from .pitch import Point if TYPE_CHECKING: - from .tracking import Frame + from .tracking import Frame, TrackingDataset QualifierValueType = TypeVar("QualifierValueType") EnumQualifierType = TypeVar("EnumQualifierType", bound=Enum) @@ -1496,6 +1496,22 @@ def aggregate(self, type_: str, **aggregator_kwargs) -> list[Any]: return aggregator.aggregate(self) + def to_tracking_data(self) -> "TrackingDataset": + from .tracking import TrackingDataset + + freeze_frames = self.filter( + lambda event: event.freeze_frame is not None + ) + if len(freeze_frames.records) == 0: + raise ValueError( + "EventDataset has 0 freeze frame records making it impossible to convert to a TrackingDataset" + ) + + return TrackingDataset.from_dataset( + freeze_frames, + lambda event: event.freeze_frame, + ) + __all__ = [ "EnumQualifier", diff --git a/kloppy/infra/serializers/event/statsbomb/helpers.py b/kloppy/infra/serializers/event/statsbomb/helpers.py index 757e33c7d..477e3ebe9 100644 --- a/kloppy/infra/serializers/event/statsbomb/helpers.py +++ b/kloppy/infra/serializers/event/statsbomb/helpers.py @@ -156,7 +156,7 @@ def get_player_from_freeze_frame(player_data, team, i): timestamp=event.timestamp, ball_state=event.ball_state,
ball_owning_team=event.ball_owning_team, - other_data={"visible_area": visible_area}, + other_data={"visible_area": visible_area, "event_id": event.event_id}, ) return frame diff --git a/kloppy/tests/test_statsbomb.py b/kloppy/tests/test_statsbomb.py index e7e6e8646..efbb2614d 100644 --- a/kloppy/tests/test_statsbomb.py +++ b/kloppy/tests/test_statsbomb.py @@ -47,7 +47,7 @@ PassType, UnderPressureQualifier, ) -from kloppy.exceptions import DeserializationError +from kloppy.exceptions import DeserializationError, KloppyParameterError from kloppy.infra.serializers.event.statsbomb.helpers import parse_str_ts import kloppy.infra.serializers.event.statsbomb.specification as SB @@ -616,7 +616,6 @@ def test_correct_normalized_deserialization(self): coordinates, ) in pass_event.freeze_frame.players_coordinates.items(): coordinates_per_team[player.team.name].append(coordinates) - print(coordinates_per_team) assert coordinates_per_team == { "Belgium": [ Point(x=0.30230680550305883, y=0.5224074534269804), @@ -1264,3 +1263,61 @@ def test_player_position(self, base_dir): PositionType.LeftMidfield, ) ] + + +class TestStatsBombAsTrackingDataset: + """Tests related to exporting freeze frames to a TrackingDataset""" + + def test_convert_to_tracking(self, dataset: EventDataset): + sb_tracking = dataset.to_tracking_data() + assert len(sb_tracking) == 3346 + + with pytest.raises( + AttributeError, + match=r"'NoneType' object has no attribute 'player_id'", + ): + sb_tracking.to_df(layout="wide") + + with pytest.raises( + KloppyParameterError, + match=r"Row-wise format is only supported for tracking datasets, got DatasetType.EVENT", + ): + dataset.to_df(layout="long") + + df = sb_tracking.to_df(layout="long") + assert list(df.columns) == [ + "period_id", + "timestamp", + "frame_id", + "ball_state", + "ball_owning_team_id", + "visible_area", + "event_id", + "team_id", + "player_id", + "x", + "y", + "z", + "d", + "s", + ] + + assert ( + len( + df[df["frame_id"] == 37].drop_duplicates(
subset=["period_id", "frame_id", "event_id"] + ) + ) + == 2 + ) + + assert ( + len( + df[df["frame_id"] == 37].drop_duplicates( + subset=["period_id", "frame_id", "player_id"] + ) + ) + == 40 + ) + + assert len(df) == 54540