Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion kloppy/domain/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .pitch import Point

if TYPE_CHECKING:
from .tracking import Frame
from .tracking import Frame, TrackingDataset

QualifierValueType = TypeVar("QualifierValueType")
EnumQualifierType = TypeVar("EnumQualifierType", bound=Enum)
Expand Down Expand Up @@ -1496,6 +1496,22 @@ def aggregate(self, type_: str, **aggregator_kwargs) -> list[Any]:

return aggregator.aggregate(self)

def to_tracking_data(self) -> "TrackingDataset":
from .tracking import TrackingDataset

freeze_frames = self.filter(
lambda event: event.freeze_frame is not None
)
if len(freeze_frames.records) == 0:
raise ValueError(
"EventDataset has 0 freeze frame records making it impossible to convert to a TrackingDataset"
)

return TrackingDataset.from_dataset(
freeze_frames,
lambda event: event.freeze_frame,
)


__all__ = [
"EnumQualifier",
Expand Down
2 changes: 1 addition & 1 deletion kloppy/infra/serializers/event/statsbomb/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_player_from_freeze_frame(player_data, team, i):
timestamp=event.timestamp,
ball_state=event.ball_state,
ball_owning_team=event.ball_owning_team,
other_data={"visible_area": visible_area},
other_data={"visible_area": visible_area, "event_id": event.event_id},
)

return frame
61 changes: 59 additions & 2 deletions kloppy/tests/test_statsbomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
PassType,
UnderPressureQualifier,
)
from kloppy.exceptions import DeserializationError
from kloppy.exceptions import DeserializationError, KloppyParameterError
from kloppy.infra.serializers.event.statsbomb.helpers import parse_str_ts
import kloppy.infra.serializers.event.statsbomb.specification as SB

Expand Down Expand Up @@ -616,7 +616,6 @@ def test_correct_normalized_deserialization(self):
coordinates,
) in pass_event.freeze_frame.players_coordinates.items():
coordinates_per_team[player.team.name].append(coordinates)
print(coordinates_per_team)
assert coordinates_per_team == {
"Belgium": [
Point(x=0.30230680550305883, y=0.5224074534269804),
Expand Down Expand Up @@ -1264,3 +1263,61 @@ def test_player_position(self, base_dir):
PositionType.LeftMidfield,
)
]


class TestStatsBombAsTrackingDataset:
"""Tests related to deserializing 34/Tactical Shift events"""

def test_convert_to_tracking(self, dataset: EventDataset):
sb_tracking = dataset.to_tracking_data()
assert len(sb_tracking) == 3346

with pytest.raises(
AttributeError,
match=r"'NoneType' object has no attribute 'player_id'",
):
sb_tracking.to_df(layout="wide")

with pytest.raises(
KloppyParameterError,
match=r"Row-wise format is only supported for tracking datasets, got DatasetType.EVENT",
):
dataset.to_df(layout="long")

df = sb_tracking.to_df(layout="long")
assert list(df.columns) == [
"period_id",
"timestamp",
"frame_id",
"ball_state",
"ball_owning_team_id",
"visible_area",
"event_id",
"team_id",
"player_id",
"x",
"y",
"z",
"d",
"s",
]

assert (
len(
df[df["frame_id"] == 37].drop_duplicates(
subset=["period_id", "frame_id", "event_id"]
)
)
== 2
)

assert (
len(
df[df["frame_id"] == 37].drop_duplicates(
subset=["period_id", "frame_id", "player_id"]
)
)
== 40
)

assert len(df) == 54540
Loading