diff --git a/docs/user-guide/exporting-data/dataframes.md b/docs/user-guide/exporting-data/dataframes.md index 5820d4675..b89c676fe 100644 --- a/docs/user-guide/exporting-data/dataframes.md +++ b/docs/user-guide/exporting-data/dataframes.md @@ -86,24 +86,63 @@ print(f""" ### Tracking data -For a [`TrackingDataset`][kloppy.domain.TrackingDataset], the output columns include: +The [`TrackingDataset`][kloppy.domain.TrackingDataset] supports two different layouts: **Wide** (default) and **Long**. -| Column | Description | -|-------------------------------------------------|---------------------------------------| -| frame_id | Frame number | -| period_id | Match period | -| timestamp | Frame timestamp | -| ball_x, ball_y, ball_z, ball_speed | Ball position and speed | -| _x, _y, _d, _s | Player coordinates, distance (since previous frame), and speed | -| ball_state | Current state of the ball | -| ball_owning_team | Which team owns the ball | +#### Wide layout + +In the wide layout, each row represents a single frame. Player data is flattened into columns, with a specific column for each player's x/y coordinates, speed, etc. + +**Common Columns:** + +| Column | Description | +| :--- | :--- | +| `frame_id` | Frame number | +| `period_id` | Match period | +| `timestamp` | Frame timestamp | +| `ball_state` | Current state of the ball | +| `ball_owning_team` | Which team owns the ball | +| `ball_x`, `ball_y`, `ball_z` | Ball coordinates | +| `ball_speed` | Ball speed | +| `_x`, `_y` | Player coordinates | +| `_d` | Player distance covered (since previous frame) | +| `_s` | Player speed | + +**Example:** + +```python exec="true" html="true" session="export-df" +# Default is layout="wide" +print(f""" +
+{tracking_dataset.to_df(layout="wide").head(n=3).to_html(index=False, border="0")} +
+""") +``` + +#### Long layout + +In the long layout, each frame is "melted" into multiple rows: one for the ball and one for each player present in that frame. + +**Common Columns:** + +| Column | Description | +| :--- | :--- | +| `frame_id` | Frame number | +| `period_id` | Match period | +| `timestamp` | Frame timestamp | +| `team_id` | Team identifier (or "ball") | +| `player_id` | Player identifier (or "ball") | +| `x`, `y`, `z` | Entity coordinates | +| `d` | Entity distance covered since previous frame (players only; absent for the ball row) | +| `s` | Entity speed | +| `ball_state` | Current state of the ball (repeated for all rows in frame) | +| `ball_owning_team_id` | Which team owns the ball (repeated for all rows in frame) | **Example:** ```python exec="true" html="true" session="export-df" print(f""" <div>
-{tracking_dataset.to_df().head(n=3).to_html(index=False, border="0")} +{tracking_dataset.to_df(layout="long").head(n=6).to_html(index=False, border="0")}
""") ``` diff --git a/kloppy/domain/models/common.py b/kloppy/domain/models/common.py index 64470d0f8..153aa0507 100644 --- a/kloppy/domain/models/common.py +++ b/kloppy/domain/models/common.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass, field, replace from datetime import datetime, timedelta from enum import Enum, Flag +from itertools import chain import sys from typing import ( TYPE_CHECKING, @@ -679,12 +679,16 @@ def to_mplsoccer(self): dim = BaseDims( left=self.pitch_dimensions.x_dim.min, right=self.pitch_dimensions.x_dim.max, - bottom=self.pitch_dimensions.y_dim.min - if not invert_y - else self.pitch_dimensions.y_dim.max, - top=self.pitch_dimensions.y_dim.max - if not invert_y - else self.pitch_dimensions.y_dim.min, + bottom=( + self.pitch_dimensions.y_dim.min + if not invert_y + else self.pitch_dimensions.y_dim.max + ), + top=( + self.pitch_dimensions.y_dim.max + if not invert_y + else self.pitch_dimensions.y_dim.min + ), width=self.pitch_dimensions.x_dim.max - self.pitch_dimensions.x_dim.min, length=self.pitch_dimensions.y_dim.max @@ -733,14 +737,16 @@ def to_mplsoccer(self): - self.pitch_dimensions.x_dim.min ), pad_multiplier=1, - aspect_equal=False - if self.pitch_dimensions.unit == Unit.NORMED - else True, + aspect_equal=( + False if self.pitch_dimensions.unit == Unit.NORMED else True + ), pitch_width=pitch_width, pitch_length=pitch_length, - aspect=pitch_width / pitch_length - if self.pitch_dimensions.unit == Unit.NORMED - else 1.0, + aspect=( + pitch_width / pitch_length + if self.pitch_dimensions.unit == Unit.NORMED + else 1.0 + ), ) return dim @@ -1823,6 +1829,7 @@ def to_records( self, *columns: Unpack[tuple["Column"]], as_list: Literal[True] = True, + layout: Optional[str] = None, **named_columns: "NamedColumns", ) -> list[dict[str, Any]]: ... 
@@ -1830,6 +1837,7 @@ def to_records( def to_records( self, *columns: Unpack[tuple["Column"]], + layout: Optional[str] = None, as_list: Literal[False] = False, **named_columns: "NamedColumns", ) -> Iterable[dict[str, Any]]: ... @@ -1837,15 +1845,16 @@ def to_records( def to_records( self, *columns: Unpack[tuple["Column"]], + layout: Optional[str] = None, as_list: bool = True, **named_columns: "NamedColumns", ) -> Union[list[dict[str, Any]], Iterable[dict[str, Any]]]: from ..services.transformers.data_record import get_transformer_cls - transformer = get_transformer_cls(self.dataset_type)( + transformer = get_transformer_cls(self.dataset_type, layout=layout)( *columns, **named_columns ) - iterator = map(transformer, self.records) + iterator = chain.from_iterable(map(transformer, self.records)) if as_list: return list(iterator) else: @@ -1855,23 +1864,36 @@ def to_dict( self, *columns: Unpack[tuple["Column"]], orient: Literal["list"] = "list", + layout: Optional[str] = None, **named_columns: "NamedColumns", ) -> dict[str, list[Any]]: if orient == "list": from ..services.transformers.data_record import get_transformer_cls - transformer = get_transformer_cls(self.dataset_type)( + transformer = get_transformer_cls(self.dataset_type, layout=layout)( *columns, **named_columns ) - c = len(self.records) - items = defaultdict(lambda: [None] * c) - for i, record in enumerate(self.records): - item = transformer(record) - for k, v in item.items(): - items[k][i] = v + result = {} + for record_idx, record in enumerate(self.records): + items = transformer(record) + + for item_idx, item in enumerate(items): + seen_keys = set() + for k, v in item.items(): + # If this is a new key, backfill previous rows. NOTE(review): record_idx + item_idx counts records, not emitted rows; when a transformer yields multiple rows per record (e.g. layout="long"), this under-pads the new column and misaligns lists — track a running row count instead. + if k not in result: + result[k] = [None] * (record_idx + item_idx) - return items + result[k].append(v) + seen_keys.add(k) + + # Pad keys that were not seen in this record + for k in result: + if k not in seen_keys: + result[k].append(None) + + return dict(result) else: raise 
KloppyParameterError( f"Orient {orient} is not supported. Only orient='list' is supported" @@ -1881,8 +1903,43 @@ def to_df( self, *columns: Unpack[tuple["Column"]], engine: Optional[Literal["polars", "pandas", "pandas[pyarrow]"]] = None, + layout: Optional[str] = None, **named_columns: "NamedColumns", ): + """Converts the dataset's records into a DataFrame. + + This method extracts data from the internal records and formats them into + a tabular structure using the specified dataframe engine (Pandas or Polars). + + Args: + *columns: Column names to include in the output. + - If not provided, a default set of columns is returned. + - Supports wildcards (e.g., "*coordinates*"). + - Supports callables for custom extraction logic. + engine: The dataframe engine to use. + - 'pandas': Returns a standard pandas DataFrame. + - 'pandas[pyarrow]': Returns a pandas DataFrame backed by PyArrow. + - 'polars': Returns a Polars DataFrame. + - None: Defaults to the `dataframe.engine` configuration value. + layout: The layout structure of the output. + - For Event data: Default is a flat list of events. + - For Tracking data: + - 'wide' (default): One row per frame, with players as columns. + - 'long': One row per entity (player/ball) per frame ("tidy" data). + **named_columns: Additional columns to create, where the key is the + column name and the value is a literal or a callable applied to + each record. 
+ + Examples: + Basic conversion to Pandas: + >>> df = dataset.to_df() + + Using Polars and selecting specific columns: + >>> df = dataset.to_df("period_id", "timestamp", "player_id", "coordinates_*", engine="polars") + + Tracking data in long format: + >>> df = tracking_dataset.to_df(layout="long") + """ from kloppy.config import get_config if not engine: @@ -1913,7 +1970,9 @@ def to_df( ) table = pa.Table.from_pydict( - self.to_dict(*columns, orient="list", **named_columns) + self.to_dict( + *columns, orient="list", layout=layout, **named_columns + ) ) return table.to_pandas(types_mapper=types_mapper) @@ -1927,7 +1986,9 @@ def to_df( ) return DataFrame.from_dict( - self.to_dict(*columns, orient="list", **named_columns) + self.to_dict( + *columns, orient="list", layout=layout, **named_columns + ) ) elif engine == "polars": try: @@ -1939,7 +2000,9 @@ def to_df( ) return from_dict( - self.to_dict(*columns, orient="list", **named_columns) + self.to_dict( + *columns, orient="list", layout=layout, **named_columns + ) ) else: raise KloppyParameterError(f"Engine {engine} is not valid") diff --git a/kloppy/domain/models/event.py b/kloppy/domain/models/event.py index 2079e0e3c..1509fa8eb 100644 --- a/kloppy/domain/models/event.py +++ b/kloppy/domain/models/event.py @@ -40,6 +40,7 @@ from ..services.transformers.data_record import NamedColumns from .tracking import Frame + QualifierValueType = TypeVar("QualifierValueType") EnumQualifierType = TypeVar("EnumQualifierType", bound=Enum) diff --git a/kloppy/domain/services/__init__.py b/kloppy/domain/services/__init__.py index d38a26cb2..9cd5b6608 100644 --- a/kloppy/domain/services/__init__.py +++ b/kloppy/domain/services/__init__.py @@ -3,7 +3,7 @@ from kloppy.domain import AttackingDirection, Frame, Ground, Period from .event_factory import EventFactory, create_event -from .transformers import DatasetTransformer, DatasetTransformerBuilder +from .transformers.dataset import DatasetTransformer, DatasetTransformerBuilder # NOT 
YET: from .enrichers import TrackingPossessionEnricher diff --git a/kloppy/domain/services/transformers/__init__.py b/kloppy/domain/services/transformers/__init__.py index 18ea0c75c..4b8b66a67 100644 --- a/kloppy/domain/services/transformers/__init__.py +++ b/kloppy/domain/services/transformers/__init__.py @@ -1,3 +1,3 @@ -from .dataset import DatasetTransformer, DatasetTransformerBuilder +from . import attribute, data_record, dataset -__all__ = ["DatasetTransformer", "DatasetTransformerBuilder"] +__all__ = ["dataset", "data_record", "attribute"] diff --git a/kloppy/domain/services/transformers/attribute.py b/kloppy/domain/services/transformers/attribute.py index 798295b30..d27d7fee2 100644 --- a/kloppy/domain/services/transformers/attribute.py +++ b/kloppy/domain/services/transformers/attribute.py @@ -1,28 +1,51 @@ +"""Event Attribute Transformation. + +This module provides tools to extract, calculate, and encode features from +individual `Event` objects. These transformers are designed to enrich event data +with derived metrics (like distance to goal) or categorical encodings (like +one-hot encoded body parts) for downstream analysis or machine learning tasks. + +Examples: + **1. Calculating Distances and Angles** + Compute spatial metrics for an event relative to the goal. + + >>> from kloppy.domain.models.event import ShotEvent + >>> # event is a ShotEvent derived from a dataset with ACTION_EXECUTING_TEAM orientation + >>> + >>> dist_transformer = DistanceToGoalTransformer() + >>> angle_transformer = AngleToGoalTransformer() + >>> + >>> features = {} + >>> features.update(dist_transformer(event)) + >>> features.update(angle_transformer(event)) + >>> # features: {'distance_to_goal': 16.5, 'angle_to_goal': 25.4} + + **2. Encoding Qualifiers (Body Parts)** + Convert categorical body part qualifiers into one-hot encoded columns. 
+ + >>> from kloppy.domain import BodyPartQualifier + >>> # event has a qualifier BodyPartQualifier(value=BodyPart.HEAD) + >>> + >>> transformer = BodyPartTransformer() + >>> encoded = transformer(event) + >>> # encoded: {'is_body_part_head': True, 'is_body_part_foot_right': False, ...} +""" + from abc import ABC, abstractmethod import math import sys -from typing import Any, Optional, Union +from typing import Any, Union from kloppy.domain import ( BodyPartQualifier, - Code, Event, - Frame, Orientation, Point, - QualifierMixin, - ResultMixin, ) from kloppy.domain.models.event import ( - CardEvent, - CarryEvent, EnumQualifier, - EventType, - PassEvent, - ShotEvent, ) from kloppy.exceptions import ( - KloppyParameterError, OrientationError, UnknownEncoderError, ) @@ -159,234 +182,4 @@ def __call__(self, event: Event) -> dict[str, Any]: return _Transformer -class DefaultEventTransformer(EventAttributeTransformer): - def __init__( - self, - *include: str, - exclude: Optional[list[str]] = None, - ): - if include and exclude: - raise KloppyParameterError("Cannot specify both include as exclude") - - self.exclude = exclude or [] - self.include = include or [] - - def __call__(self, event: Event) -> dict[str, Any]: - row = dict( - event_id=event.event_id, - event_type=( - event.event_type.value - if event.event_type != EventType.GENERIC - else f"GENERIC:{event.event_name}" - ), - period_id=event.period.id, - timestamp=event.timestamp, - end_timestamp=None, - ball_state=event.ball_state.value if event.ball_state else None, - ball_owning_team=( - event.ball_owning_team.team_id - if event.ball_owning_team - else None - ), - team_id=event.team.team_id if event.team else None, - player_id=event.player.player_id if event.player else None, - coordinates_x=event.coordinates.x if event.coordinates else None, - coordinates_y=event.coordinates.y if event.coordinates else None, - ) - if isinstance(event, PassEvent): - row.update( - { - "end_timestamp": event.receive_timestamp, - 
"end_coordinates_x": ( - event.receiver_coordinates.x - if event.receiver_coordinates - else None - ), - "end_coordinates_y": ( - event.receiver_coordinates.y - if event.receiver_coordinates - else None - ), - "receiver_player_id": ( - event.receiver_player.player_id - if event.receiver_player - else None - ), - } - ) - elif isinstance(event, CarryEvent): - row.update( - { - "end_timestamp": event.end_timestamp, - "end_coordinates_x": ( - event.end_coordinates.x - if event.end_coordinates - else None - ), - "end_coordinates_y": ( - event.end_coordinates.y - if event.end_coordinates - else None - ), - } - ) - elif isinstance(event, ShotEvent): - row.update( - { - "end_coordinates_x": ( - event.result_coordinates.x - if event.result_coordinates - else None - ), - "end_coordinates_y": ( - event.result_coordinates.y - if event.result_coordinates - else None - ), - } - ) - elif isinstance(event, CardEvent): - row.update( - { - "card_type": ( - event.card_type.value if event.card_type else None - ) - } - ) - - if isinstance(event, QualifierMixin) and event.qualifiers: - for qualifier in event.qualifiers: - row.update(qualifier.to_dict()) - - if isinstance(event, ResultMixin) and event.result is not None: - row.update( - { - "result": event.result.value, - "success": event.result.is_success, - } - ) - else: - row.update( - { - "result": None, - "success": None, - } - ) - - if self.include: - return {k: row[k] for k in self.include} - elif self.exclude: - return {k: v for k, v in row.items() if k not in self.exclude} - else: - return row - - -class DefaultFrameTransformer: - def __init__( - self, - *include: str, - exclude: Optional[list[str]] = None, - ): - if include and exclude: - raise KloppyParameterError("Cannot specify both include as exclude") - - self.exclude = exclude or [] - self.include = include or [] - - def __call__(self, frame: Frame) -> dict[str, Any]: - row = dict( - period_id=frame.period.id if frame.period else None, - timestamp=frame.timestamp, - 
frame_id=frame.frame_id, - ball_state=frame.ball_state.value if frame.ball_state else None, - ball_owning_team_id=( - frame.ball_owning_team.team_id - if frame.ball_owning_team - else None - ), - ball_x=( - frame.ball_coordinates.x if frame.ball_coordinates else None - ), - ball_y=( - frame.ball_coordinates.y if frame.ball_coordinates else None - ), - ball_z=( - getattr(frame.ball_coordinates, "z", None) - if frame.ball_coordinates - else None - ), - ball_speed=frame.ball_speed, - ) - for player, player_data in frame.players_data.items(): - row.update( - { - f"{player.player_id}_x": ( - player_data.coordinates.x - if player_data.coordinates - else None - ), - f"{player.player_id}_y": ( - player_data.coordinates.y - if player_data.coordinates - else None - ), - f"{player.player_id}_d": player_data.distance, - f"{player.player_id}_s": player_data.speed, - } - ) - - if player_data.other_data: - for name, value in player_data.other_data.items(): - row.update( - { - f"{player.player_id}_{name}": value, - } - ) - - if frame.other_data: - for name, value in frame.other_data.items(): - row.update( - { - name: value, - } - ) - - if self.include: - return {k: row[k] for k in self.include} - elif self.exclude: - return {k: v for k, v in row.items() if k not in self.exclude} - else: - return row - - -class DefaultCodeTransformer: - def __init__( - self, - *include: str, - exclude: Optional[list[str]] = None, - ): - if include and exclude: - raise KloppyParameterError("Cannot specify both include as exclude") - - self.exclude = exclude or [] - self.include = include or [] - - def __call__(self, code: Code) -> dict[str, Any]: - row = dict( - code_id=code.code_id, - period_id=code.period.id if code.period else None, - timestamp=code.timestamp, - end_timestamp=code.end_timestamp, - code=code.code, - ) - row.update(code.labels) - - if self.include: - return {k: row[k] for k in self.include} - elif self.exclude: - return {k: v for k, v in row.items() if k not in self.exclude} - 
else: - return row - - BodyPartTransformer = create_transformer_from_qualifier(BodyPartQualifier) diff --git a/kloppy/domain/services/transformers/data_record.py b/kloppy/domain/services/transformers/data_record.py index a62687197..61ef94346 100644 --- a/kloppy/domain/services/transformers/data_record.py +++ b/kloppy/domain/services/transformers/data_record.py @@ -1,106 +1,496 @@ +"""Data Record Transformation. + +This module provides tools for transforming kloppy DataRecord objects (such as +`Event`, `Frame`, and `Code`) into alternative formats like dictionaries, JSON +strings, or custom data structures. + +It separates **data extraction** (getting values from the object) from +**formatting** (structuring the output). This allows the same underlying +extraction logic to support multiple layouts (e.g., wide vs. long for tracking +data formats) and target types (e.g., `dict` vs. `str`). + +Key Components: + TransformerRegistry: Maps `DatasetType` and layout names to transformer classes. + DataRecordTransformer: Generic base class for extraction, filtering, and formatting. + register_data_record_transformer: Decorator to register new transformers. + +Examples: + **1. Basic Transformation** + Get a transformer for a specific DatasetType and convert a record to a list of dicts. + + >>> from kloppy.domain import DatasetType + >>> from kloppy.domain.services.transformers.data_record import get_transformer_cls + >>> # Default layout is implied + >>> cls = get_transformer_cls(DatasetType.EVENT) + >>> transformer = cls() + >>> data = transformer(event) + >>> # Result: [{'event_id': '...', 'timestamp': 0.1, ...}] + + **2. Column Selection & Wildcards** + Filter the output using specific column names or wildcards. + + >>> # Select 'event_id' and any column containing 'coordinates' + >>> transformer = cls("event_id", "*coordinates*") + >>> data = transformer(event) + + **3. Custom Formatting (e.g., JSON)** + Change the output type by providing a `formatter` callable. 
+ + >>> import json + >>> # Returns a list of JSON strings instead of dicts + >>> transformer = cls(formatter=json.dumps) + >>> json_data = transformer(event) +""" + from abc import ABC, abstractmethod from fnmatch import fnmatch import sys -from typing import Any, Callable, Generic, TypeVar, Union +from typing import ( + Any, + Callable, + Generic, + Optional, + TypeVar, + Union, +) if sys.version_info >= (3, 11): from typing import Unpack else: from typing_extensions import Unpack -from kloppy.domain import Code, DataRecord, DatasetType, Event, Frame -from kloppy.domain.services.transformers.attribute import ( - DefaultCodeTransformer, - DefaultEventTransformer, - DefaultFrameTransformer, +from kloppy.domain import ( + Code, + DataRecord, + DatasetType, + Event, + Frame, + QualifierMixin, + ResultMixin, +) +from kloppy.domain.models.event import ( + CardEvent, + CarryEvent, + EventType, + PassEvent, + ShotEvent, ) from kloppy.exceptions import KloppyError -T = TypeVar("T", bound=DataRecord) -Column = Union[str, Callable[[T], Any]] -NamedColumns = dict[str, Column] +# --- Type Definitions --- +RecordT = TypeVar("RecordT", bound=DataRecord) +OutputT = TypeVar("OutputT") + +# A column can be a name (str) or a logical function +ColumnSelector = Union[str, Callable[[RecordT], Any]] +NamedColumnDefinitions = dict[str, Union[Any, Callable[[RecordT], Any]]] + + +# --- Registry System --- + + +class TransformerRegistry: + """Central registry for DataRecord transformers.""" + + def __init__(self): + # Structure: DatasetType -> Layout -> TransformerClass + self._registry: dict[ + DatasetType, dict[str, type[DataRecordTransformer]] + ] = {} + + def register( + self, + dataset_type: DatasetType, + layout: Union[str, list[str]] = "default", + ): + """Decorator to register a transformer class.""" + layouts = [layout] if isinstance(layout, str) else layout + + def wrapper(cls): + if dataset_type not in self._registry: + self._registry[dataset_type] = {} + + for layout_name in 
layouts: + current_map = self._registry[dataset_type] + + # Idempotency check + if layout_name in current_map: + existing = current_map[layout_name] + if existing != cls: + raise KloppyError( + f"Conflict: '{layout_name}' for {dataset_type} is already " + f"registered to {existing.__name__}." + ) + + current_map[layout_name] = cls + return cls + + return wrapper + + def get_class( + self, dataset_type: DatasetType, layout: Optional[str] = "default" + ) -> type["DataRecordTransformer"]: + """Retrieve a transformer class by type and layout.""" + layout = layout or "default" + + if dataset_type not in self._registry: + raise KloppyError( + f"No transformers registered for dataset type: {dataset_type}" + ) + + available = self._registry[dataset_type] + if layout not in available: + raise KloppyError( + f"Layout '{layout}' not found for {dataset_type}. " + f"Available: {list(available.keys())}" + ) + return available[layout] + + +# Global instance +_REGISTRY = TransformerRegistry() + +# Public API aliases for explicit naming +register_transformer = _REGISTRY.register +get_transformer_cls = _REGISTRY.get_class -class DataRecordToDictTransformer(ABC, Generic[T]): - @abstractmethod - def default_transformer(self) -> Callable[[T], dict]: ... + +# --- Base Transformer --- + + +class DataRecordTransformer(ABC, Generic[RecordT, OutputT]): + """ + Base class for transforming DataRecords into a specific OutputT. + + This class orchestrates: + 1. Extraction (to a canonical dict) + 2. Filtering (selecting specific columns) + 3. Augmentation (adding named columns) + 4. 
Formatting (converting dict to OutputT) + """ def __init__( self, - *columns: Unpack[tuple[Column]], - **named_columns: NamedColumns, + *columns: Unpack[tuple[ColumnSelector]], + formatter: Optional[Callable[[dict[str, Any]], OutputT]] = None, + **named_columns: NamedColumnDefinitions, ): - if not columns and not named_columns: - converter = self.default_transformer() - else: - default = self.default_transformer() - has_string_columns = any(not callable(column) for column in columns) + """ + Args: + *columns: Fields to select/compute. + formatter: Optional function to convert the final dict to OutputT. + If None, the output remains a dict (OutputT = dict). + **named_columns: New columns to append. + """ + self.columns = columns + self.named_columns = named_columns + self.formatter = formatter - def converter(data_record: T) -> dict[str, Any]: - if has_string_columns: - default_row = default(data_record) - else: - default_row = {} - - row = {} - for column in columns: - if callable(column): - res = column(data_record) - if not isinstance(res, dict): - raise KloppyError( - "A function column should return a dictionary" - ) - row.update(res) - else: - if column == "*": - row.update(default_row) - elif "*" in column: - row.update( - { - k: v - for k, v in default_row.items() - if fnmatch(k, column) - } - ) - elif column in default_row: - row[column] = default_row[column] - else: - row[column] = getattr(data_record, column, None) - - for name, column in named_columns.items(): - row[name] = ( - column(data_record) if callable(column) else column + def transform_record(self, record: RecordT) -> list[OutputT]: + """Public API to transform a record.""" + # 1. Extract canonical data (List of Dictionaries) + canonical_rows = self._extract_canonical(record) + + # 2. Process rows (Filter & Augment) + processed_rows = [ + self._process_row(row, record) for row in canonical_rows + ] + + # 3. 
Format output + if self.formatter: + return [self.formatter(row) for row in processed_rows] + + # If no formatter, we assume OutputT is dict + return processed_rows # type: ignore + + @abstractmethod + def _extract_canonical(self, record: RecordT) -> list[dict[str, Any]]: + """ + Implementation specific logic to extract raw data from the record. + Must return a list of flat dictionaries. + """ + pass + + def _process_row( + self, base_row: dict[str, Any], record: RecordT + ) -> dict[str, Any]: + """Applies column selection (filtering) and named column augmentation.""" + + # Optimization: If no columns specified, keep everything + if not self.columns: + row = base_row.copy() + else: + row = {} + for col in self.columns: + if callable(col): + # Callables merge their result into the row + res = col(record) + if not isinstance(res, dict): + raise KloppyError( + "Callable columns must return a dictionary." + ) + row.update(res) + elif col == "*": + row.update(base_row) + elif "*" in col: + # Wildcard match + row.update( + {k: v for k, v in base_row.items() if fnmatch(k, col)} ) + elif col in base_row: + row[col] = base_row[col] + else: + # Fallback to record attribute + row[col] = getattr(record, col, None) + + # Apply named columns + for name, value_or_func in self.named_columns.items(): + if callable(value_or_func): + row[name] = value_or_func(record) + else: + row[name] = value_or_func + + return row + + def __call__(self, record: RecordT) -> list[OutputT]: + return self.transform_record(record) + + +# --- Concrete Implementations --- + + +@register_transformer(DatasetType.EVENT, layout="default") +class EventTransformer(DataRecordTransformer[Event, Any]): + """Transformer for Event data.""" + + def _extract_canonical(self, record: Event) -> list[dict[str, Any]]: + row: dict[str, Any] = dict( + event_id=record.event_id, + event_type=( + record.event_type.value + if record.event_type != EventType.GENERIC + else f"GENERIC:{record.event_name}" + ), + 
period_id=record.period.id, + timestamp=record.timestamp, + end_timestamp=None, + ball_state=record.ball_state.value if record.ball_state else None, + ball_owning_team=( + record.ball_owning_team.team_id + if record.ball_owning_team + else None + ), + team_id=record.team.team_id if record.team else None, + player_id=record.player.player_id if record.player else None, + coordinates_x=record.coordinates.x if record.coordinates else None, + coordinates_y=record.coordinates.y if record.coordinates else None, + ) + + # Event-specific logic + if isinstance(record, PassEvent): + row.update( + { + "end_timestamp": record.receive_timestamp, + "end_coordinates_x": record.receiver_coordinates.x + if record.receiver_coordinates + else None, + "end_coordinates_y": record.receiver_coordinates.y + if record.receiver_coordinates + else None, + "receiver_player_id": record.receiver_player.player_id + if record.receiver_player + else None, + } + ) + elif isinstance(record, CarryEvent): + row.update( + { + "end_timestamp": record.end_timestamp, + "end_coordinates_x": record.end_coordinates.x + if record.end_coordinates + else None, + "end_coordinates_y": record.end_coordinates.y + if record.end_coordinates + else None, + } + ) + elif isinstance(record, ShotEvent): + row.update( + { + "end_coordinates_x": record.result_coordinates.x + if record.result_coordinates + else None, + "end_coordinates_y": record.result_coordinates.y + if record.result_coordinates + else None, + } + ) + elif isinstance(record, CardEvent): + row.update( + { + "card_type": record.card_type.value + if record.card_type + else None + } + ) + + if isinstance(record, QualifierMixin) and record.qualifiers: + for qualifier in record.qualifiers: + row.update(qualifier.to_dict()) + + if isinstance(record, ResultMixin): + row.update( + { + "result": record.result.value if record.result else None, + "success": record.result.is_success + if record.result + else None, + } + ) + else: + row.update( + { + "result": None, + 
"success": None, + } + ) + + return [row] + + +@register_transformer(DatasetType.TRACKING, layout=["wide", "default"]) +class TrackingWideTransformer(DataRecordTransformer[Frame, Any]): + """Wide-format transformer for Tracking data.""" + + def _extract_canonical(self, record: Frame) -> list[dict[str, Any]]: + row: dict[str, Any] = dict( + period_id=record.period.id if record.period else None, + timestamp=record.timestamp, + frame_id=record.frame_id, + ball_state=record.ball_state.value if record.ball_state else None, + ball_owning_team_id=( + record.ball_owning_team.team_id + if record.ball_owning_team + else None + ), + ball_x=record.ball_coordinates.x + if record.ball_coordinates + else None, + ball_y=record.ball_coordinates.y + if record.ball_coordinates + else None, + ball_z=getattr(record.ball_coordinates, "z", None) + if record.ball_coordinates + else None, + ball_speed=record.ball_speed, + ) + + for player, player_data in record.players_data.items(): + # Flatten player data into columns + prefix = f"{player.player_id}" + row.update( + { + f"{prefix}_x": player_data.coordinates.x + if player_data.coordinates + else None, + f"{prefix}_y": player_data.coordinates.y + if player_data.coordinates + else None, + f"{prefix}_d": player_data.distance, + f"{prefix}_s": player_data.speed, + } + ) + if player_data.other_data: + for k, v in player_data.other_data.items(): + row[f"{prefix}_{k}"] = v - return row + if record.other_data: + row.update(record.other_data) - self.converter = converter + return [row] - def __call__(self, data_record: T) -> dict[str, Any]: - return self.converter(data_record) +@register_transformer(DatasetType.TRACKING, layout="long") +class TrackingLongTransformer(DataRecordTransformer[Frame, Any]): + """Long-format transformer for Tracking data.""" -class EventToDictTransformer(DataRecordToDictTransformer[Event]): - def default_transformer(self) -> Callable[[Event], dict]: - return DefaultEventTransformer() + def _extract_canonical(self, 
record: Frame) -> list[dict[str, Any]]: + rows = [] + base_data = { + "period_id": record.period.id if record.period else None, + "timestamp": record.timestamp, + "frame_id": record.frame_id, + "ball_state": record.ball_state.value + if record.ball_state + else None, + "ball_owning_team_id": ( + record.ball_owning_team.team_id + if record.ball_owning_team + else None + ), + } + if record.other_data: + base_data.update(record.other_data) + # Ball + ball_row = base_data.copy() + ball_row.update( + { + "team_id": "ball", + "player_id": "ball", + "x": record.ball_coordinates.x + if record.ball_coordinates + else None, + "y": record.ball_coordinates.y + if record.ball_coordinates + else None, + "z": getattr(record.ball_coordinates, "z", None) + if record.ball_coordinates + else None, + "s": record.ball_speed, + } + ) + rows.append(ball_row) -class FrameToDictTransformer(DataRecordToDictTransformer[Frame]): - def default_transformer(self) -> Callable[[Frame], dict]: - return DefaultFrameTransformer() + # Players + for player, player_data in record.players_data.items(): + p_row = base_data.copy() + p_row.update( + { + "team_id": player.team.team_id if player.team else None, + "player_id": player.player_id, + "x": player_data.coordinates.x + if player_data.coordinates + else None, + "y": player_data.coordinates.y + if player_data.coordinates + else None, + "z": getattr(player_data.coordinates, "z", None) + if player_data.coordinates + else None, + "d": player_data.distance, + "s": player_data.speed, + } + ) + if player_data.other_data: + p_row.update(player_data.other_data) + rows.append(p_row) + return rows -class CodeToDictTransformer(DataRecordToDictTransformer[Code]): - def default_transformer(self) -> Callable[[Code], dict]: - return DefaultCodeTransformer() +@register_transformer(DatasetType.CODE, layout="default") +class CodeTransformer(DataRecordTransformer[Code, Any]): + """Transformer for Code data.""" -def get_transformer_cls( - dataset_type: DatasetType, -) -> 
type[DataRecordToDictTransformer]: - if dataset_type == DatasetType.EVENT: - return EventToDictTransformer - elif dataset_type == DatasetType.TRACKING: - return FrameToDictTransformer - elif dataset_type == DatasetType.CODE: - return CodeToDictTransformer + def _extract_canonical(self, record: Code) -> list[dict[str, Any]]: + row = dict( + code_id=record.code_id, + period_id=record.period.id if record.period else None, + timestamp=record.timestamp, + end_timestamp=record.end_timestamp, + code=record.code, + ) + row.update(record.labels) + return [row] diff --git a/kloppy/domain/services/transformers/dataset.py b/kloppy/domain/services/transformers/dataset.py index 5de757e54..d0ab6b671 100644 --- a/kloppy/domain/services/transformers/dataset.py +++ b/kloppy/domain/services/transformers/dataset.py @@ -1,3 +1,63 @@ +"""Dataset Transformation. + +This module provides the machinery for transforming the spatial representation of +kloppy datasets (`EventDataset` and `TrackingDataset`). It addresses three key +aspects of spatial data normalization: + +1. **Coordinate Systems**: Converting data from a provider-specific system + (e.g., Opta, Wyscout) to a standardized system (e.g., Kloppy Standard, + Metric). +2. **Pitch Dimensions**: Scaling coordinates from normalized (0-1) ranges to + metric values (meters), or vice-versa, based on specific pitch dimensions. +3. **Orientation**: Flipping coordinates to ensure a consistent attacking + direction (e.g., ensuring the Home team always attacks to the right). + +Key Components: + DatasetTransformer: The core class capable of transforming entire datasets, + individual frames, or specific events. + DatasetTransformerBuilder: A helper factory to construct transformers based + on configuration or provider names. + +Examples: + **1. Standardizing a Dataset** + Convert a dataset to the standard Kloppy coordinate system (Metric, origin + at center, x pointing to opposition goal). 
+
+    >>> from kloppy.domain import KloppyCoordinateSystem
+    >>> # dataset is an EventDataset or TrackingDataset
+    >>> new_dataset = DatasetTransformer.transform_dataset(
+    ...     dataset,
+    ...     to_coordinate_system=KloppyCoordinateSystem()
+    ... )
+
+    **2. Enforcing Orientation**
+    Force the dataset orientation so the Home team always attacks to the Right
+    (and the Away team attacks Left), flipping coordinates for the second half
+    if necessary.
+
+    >>> from kloppy.domain import Orientation
+    >>> new_dataset = DatasetTransformer.transform_dataset(
+    ...     dataset,
+    ...     to_orientation=Orientation.HOME_AWAY
+    ... )
+
+    **3. Scaling to Specific Pitch Dimensions**
+    If the source data is normalized (0-1), you can project it onto a real pitch size.
+
+    >>> from kloppy.domain import MetricPitchDimensions, Dimension
+    >>> dims = MetricPitchDimensions(
+    ...     x_dim=Dimension(0, 105),
+    ...     y_dim=Dimension(0, 68),
+    ...     pitch_length=105,
+    ...     pitch_width=68,
+    ...     standardized=False,
+    ... )
+    >>> new_dataset = DatasetTransformer.transform_dataset(
+    ...     dataset,
+    ...     to_pitch_dimensions=dims
+    ...
) +""" + from dataclasses import fields, replace from typing import Optional, Union import warnings diff --git a/kloppy/tests/test_helpers.py b/kloppy/tests/test_helpers.py index f761ae26a..783f57589 100644 --- a/kloppy/tests/test_helpers.py +++ b/kloppy/tests/test_helpers.py @@ -25,6 +25,7 @@ TrackingDataset, ) from kloppy.domain.services.frame_factory import create_frame +from kloppy.exceptions import KloppyError class TestHelpers: @@ -411,7 +412,7 @@ def test_transform_event_data_freeze_frame(self, base_dir): assert coordinates.x == 1 - coordinates_transformed.x assert coordinates.y == 1 - coordinates_transformed.y - def test_to_pandas(self): + def test_to_pandas_wide_layout(self): tracking_data = self._get_tracking_dataset() data_frame = tracking_data.to_df(engine="pandas") @@ -437,6 +438,48 @@ def test_to_pandas(self): ) assert_frame_equal(data_frame, expected_data_frame, check_like=True) + def test_to_pandas_long_layout(self): + tracking_data = self._get_tracking_dataset() + + # Specify layout="long" + data_frame = tracking_data.to_df(engine="pandas", layout="long") + + expected_data_frame = DataFrame.from_dict( + { + # Row 0: Frame 1 - Ball + # Row 1: Frame 2 - Ball + # Row 2: Frame 2 - Player 'home_1' + "frame_id": [1, 2, 2], + "period_id": [1, 2, 2], + "timestamp": [0.1, 0.2, 0.2], + "ball_state": [None, None, None], + "ball_owning_team_id": ["home", "away", "away"], + # Identifiers + "team_id": ["ball", "ball", "home"], + "player_id": ["ball", "ball", "home_1"], + # Coordinates & Metrics (Unified columns) + "x": [100.0, 0.0, 15.0], + "y": [-50.0, 50.0, 35.0], + "z": [0.0, 1.0, None], # Player has no Z in wide test + "d": [None, None, 0.03], + "s": [None, None, 10.5], + # Metadata + # Note: Frame-level 'extra_data' (value 1 in frame 2) propagates to all rows in that frame + "extra_data": [None, 1, 1], + } + ) + + # check_like=True ignores column order + assert_frame_equal( + data_frame, expected_data_frame, check_like=True, check_dtype=False + ) + + def 
test_to_pandas_invalid_layout(self): + tracking_data = self._get_tracking_dataset() + + with pytest.raises(KloppyError, match="Layout 'wrong' not found"): + tracking_data.to_df(engine="pandas", layout="wrong") + def test_to_pandas_generic_events(self, base_dir): dataset = opta.load( f7_data=base_dir / "files/opta_f7.xml", diff --git a/kloppy/tests/test_metadata.py b/kloppy/tests/test_metadata.py index a6be8e737..c7c2120ef 100644 --- a/kloppy/tests/test_metadata.py +++ b/kloppy/tests/test_metadata.py @@ -11,7 +11,7 @@ Point3D, Unit, ) -from kloppy.domain.services.transformers import DatasetTransformer +from kloppy.domain.services.transformers.dataset import DatasetTransformer class TestPitchdimensions: diff --git a/kloppy/tests/test_statsbomb.py b/kloppy/tests/test_statsbomb.py index d850a396e..3aec36dab 100644 --- a/kloppy/tests/test_statsbomb.py +++ b/kloppy/tests/test_statsbomb.py @@ -616,7 +616,7 @@ def test_correct_normalized_deserialization(self): coordinates, ) in pass_event.freeze_frame.players_coordinates.items(): coordinates_per_team[player.team.name].append(coordinates) - print(coordinates_per_team) + assert coordinates_per_team == { "Belgium": [ Point(x=0.30230680550305883, y=0.5224074534269804),