Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions kloppy/_providers/cdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Optional

from kloppy.domain import TrackingDataset
from kloppy.infra.serializers.tracking.cdf import (
CDFTrackingDataInputs,
CDFTrackingDeserializer,
)
from kloppy.io import FileLike, open_as_file


def load_tracking(
meta_data: FileLike,
raw_data: FileLike,
sample_rate: Optional[float] = None,
limit: Optional[int] = None,
coordinates: Optional[str] = None,
include_empty_frames: Optional[bool] = False,
only_alive: Optional[bool] = True,
) -> TrackingDataset:
"""
Load Common Data Format broadcast tracking data.

Args:
meta_data: A JSON feed containing the meta data.
raw_data: A JSONL feed containing the raw tracking data.
sample_rate: Sample the data at a specific rate.
limit: Limit the number of frames to load to the first `limit` frames.
coordinates: The coordinate system to use.
include_empty_frames: Include frames in which no objects were tracked.
only_alive: Only include frames in which the game is not paused.

Returns:
The parsed tracking data.
"""
deserializer = CDFTrackingDeserializer(
sample_rate=sample_rate,
limit=limit,
coordinate_system=coordinates,
include_empty_frames=include_empty_frames,
only_alive=only_alive,
)
with (
open_as_file(meta_data) as meta_data_fp,
open_as_file(raw_data) as raw_data_fp,
):
return deserializer.deserialize(
inputs=CDFTrackingDataInputs(
meta_data=meta_data_fp, raw_data=raw_data_fp
)
)


# def load_event(
# event_data: FileLike,
# meta_data: FileLike,
# event_types: Optional[list[str]] = None,
# coordinates: Optional[str] = None,
# event_factory: Optional[EventFactory] = None,
# ) -> EventDataset:
# """
# Load Common Data Format event data.

# Args:
# event_data: JSON feed with the raw event data of a game.
# meta_data: JSON feed with the corresponding lineup information of the game.
# event_types: A list of event types to load.
# coordinates: The coordinate system to use.
# event_factory: A custom event factory.

# Returns:
# The parsed event data.
# """
# deserializer = StatsBombDeserializer(
# event_types=event_types,
# coordinate_system=coordinates,
# event_factory=event_factory
# or get_config("event_factory")
# or StatsBombEventFactory(),
# )
# with (
# open_as_file(event_data) as event_data_fp,
# open_as_file(meta_data) as meta_data_fp,
# ):
# return deserializer.deserialize(
# inputs=StatsBombInputs(
# event_data=event_data_fp,
# lineup_data=meta_data_fp,
# )
# )
5 changes: 3 additions & 2 deletions kloppy/_providers/sportscode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from kloppy.infra.serializers.code.sportscode import (
SportsCodeDeserializer,
SportsCodeInputs,
SportsCodeOutputs,
SportsCodeSerializer,
)
from kloppy.io import FileLike, open_as_file
Expand Down Expand Up @@ -31,6 +32,6 @@ def save(dataset: CodeDataset, output_filename: str) -> None:
dataset: The SportsCode dataset to save.
output_filename: The output filename.
"""
with open(output_filename, "wb") as fp:
with open_as_file(output_filename, "wb") as data_fp:
serializer = SportsCodeSerializer()
fp.write(serializer.serialize(dataset))
serializer.serialize(dataset, outputs=SportsCodeOutputs(data=data_fp))
5 changes: 5 additions & 0 deletions kloppy/cdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Functions for loading SkillCorner broadcast tracking data."""

from ._providers.cdf import load_tracking

__all__ = ["load_tracking"]
78 changes: 65 additions & 13 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class Provider(Enum):
HAWEKEYE (Provider):
SPORTVU (Provider):
IMPECT (Provider):
CDF (Provider):
OTHER (Provider):
"""

Expand All @@ -128,8 +129,9 @@ class Provider(Enum):
STATSPERFORM = "statsperform"
HAWKEYE = "hawkeye"
SPORTVU = "sportvu"
SIGNALITY = "signality"
IMPECT = "impect"
CDF = "common_data_format"
SIGNALITY = "signality"
OTHER = "other"

def __str__(self):
Expand Down Expand Up @@ -679,12 +681,16 @@ def to_mplsoccer(self):
dim = BaseDims(
left=self.pitch_dimensions.x_dim.min,
right=self.pitch_dimensions.x_dim.max,
bottom=self.pitch_dimensions.y_dim.min
if not invert_y
else self.pitch_dimensions.y_dim.max,
top=self.pitch_dimensions.y_dim.max
if not invert_y
else self.pitch_dimensions.y_dim.min,
bottom=(
self.pitch_dimensions.y_dim.min
if not invert_y
else self.pitch_dimensions.y_dim.max
),
top=(
self.pitch_dimensions.y_dim.max
if not invert_y
else self.pitch_dimensions.y_dim.min
),
width=self.pitch_dimensions.x_dim.max
- self.pitch_dimensions.x_dim.min,
length=self.pitch_dimensions.y_dim.max
Expand Down Expand Up @@ -733,14 +739,16 @@ def to_mplsoccer(self):
- self.pitch_dimensions.x_dim.min
),
pad_multiplier=1,
aspect_equal=False
if self.pitch_dimensions.unit == Unit.NORMED
else True,
aspect_equal=(
False if self.pitch_dimensions.unit == Unit.NORMED else True
),
pitch_width=pitch_width,
pitch_length=pitch_length,
aspect=pitch_width / pitch_length
if self.pitch_dimensions.unit == Unit.NORMED
else 1.0,
aspect=(
pitch_width / pitch_length
if self.pitch_dimensions.unit == Unit.NORMED
else 1.0
),
)
return dim

Expand Down Expand Up @@ -1184,6 +1192,45 @@ def pitch_dimensions(self) -> PitchDimensions:
)


class CDFCoordinateSystem(ProviderCoordinateSystem):
"""
CDFCoordinateSystem coordinate system.

Uses a pitch with the origin at the center and the y-axis oriented
from bottom to top. The coordinates are in meters.
"""

@property
def provider(self) -> Provider:
return Provider.CDF

@property
def origin(self) -> Origin:
return Origin.CENTER

@property
def vertical_orientation(self) -> VerticalOrientation:
return VerticalOrientation.BOTTOM_TO_TOP

@property
def pitch_dimensions(self) -> PitchDimensions:
return NormalizedPitchDimensions(
x_dim=Dimension(
-1 * self._pitch_length / 2, self._pitch_length / 2
),
y_dim=Dimension(-1 * self._pitch_width / 2, self._pitch_width / 2),
pitch_length=self._pitch_length,
pitch_width=self._pitch_width,
standardized=False,
)

def __init__(self, base_coordinate_system: ProviderCoordinateSystem):
self._pitch_length = (
base_coordinate_system.pitch_dimensions.pitch_length
)
self._pitch_width = base_coordinate_system.pitch_dimensions.pitch_width


class SignalityCoordinateSystem(ProviderCoordinateSystem):
@property
def provider(self) -> Provider:
Expand Down Expand Up @@ -1414,6 +1461,7 @@ def build_coordinate_system(
Provider.SPORTVU: SportVUCoordinateSystem,
Provider.SIGNALITY: SignalityCoordinateSystem,
Provider.IMPECT: ImpectCoordinateSystem,
Provider.CDF: CDFCoordinateSystem,
}

if provider in coordinate_systems:
Expand Down Expand Up @@ -1943,6 +1991,10 @@ def to_df(
else:
raise KloppyParameterError(f"Engine {engine} is not valid")

def to_cdf(self):
if self.dataset_type != DatasetType.TRACKING:
raise ValueError("to_cdf() is only supported for TrackingDataset")

def __repr__(self):
return f"<{self.__class__.__name__} record_count={len(self.records)}>"

Expand Down
75 changes: 74 additions & 1 deletion kloppy/domain/models/tracking.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

from kloppy.domain.models.common import DatasetType
from kloppy.utils import (
deprecated,
docstring_inherit_attributes,
)

if TYPE_CHECKING:
from kloppy.io import FileLike

from .common import DataRecord, Dataset, Player
from .pitch import Point, Point3D

if TYPE_CHECKING:
from cdf.domain import CdfMetaDataSchema
from pandas import DataFrame


Expand Down Expand Up @@ -122,5 +126,74 @@ def generic_record_converter(frame: Frame):
map(generic_record_converter, self.records)
)

# Update the to_cdf method in Dataset class
def to_cdf(
self,
metadata_output_file: "FileLike",
tracking_output_file: "FileLike",
additional_metadata: Optional[Union[dict, "CdfMetaDataSchema"]] = None,
) -> None:
"""
Export dataset to Common Data Format (CDF).

Args:
metadata_output_file: File path or file-like object for metadata JSON output.
Must have .json extension if a string path.
tracking_output_file: File path or file-like object for tracking JSONL output.
Must have .jsonl extension if a string path.
additional_metadata: Additional metadata to include in the CDF output.
Can be a complete CdfMetaDataSchema TypedDict or a partial dict.
Supported top-level keys: 'competition', 'season', 'stadium', 'meta', 'match'.
Supports nested updates like {'stadium': {'id': '123'}}.

Raises:
KloppyError: If the dataset is not a TrackingDataset.
ValueError: If file extensions are invalid.

Examples:
>>> # Export to local files
>>> dataset.to_cdf(
... metadata_output_file='metadata.json',
... tracking_output_file='tracking.jsonl'
... )

>>> # Export to S3
>>> dataset.to_cdf(
... metadata_output_file='s3://bucket/metadata.json',
... tracking_output_file='s3://bucket/tracking.jsonl'
... )

>>> # Export with partial metadata updates
>>> dataset.to_cdf(
... metadata_output_file='metadata.json',
... tracking_output_file='tracking.jsonl',
... additional_metadata={
... 'competition': {'id': '123'},
... 'season': {'id': '2024'},
... 'stadium': {'id': '456', 'name': 'Stadium Name'}
... }
... )
"""
from kloppy.infra.serializers.tracking.cdf import (
CDFOutputs,
CDFTrackingSerializer,
)
from kloppy.io import open_as_file

serializer = CDFTrackingSerializer()

# TODO: write files but also support non-local files, similar to how open_as_file supports non-local files

# Use open_as_file with mode="wb" for writing
with (
open_as_file(metadata_output_file, mode="wb") as metadata_fp,
open_as_file(tracking_output_file, mode="wb") as tracking_fp,
):
serializer.serialize(
dataset=self,
outputs=CDFOutputs(meta_data=metadata_fp, raw_data=tracking_fp),
additional_metadata=additional_metadata,
)


__all__ = ["Frame", "TrackingDataset", "PlayerData"]
22 changes: 20 additions & 2 deletions kloppy/infra/io/adapters/adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import BinaryIO

from kloppy.infra.io.buffered_stream import BufferedStream


class Adapter(ABC):
Expand All @@ -16,9 +17,26 @@ def is_file(self, url: str) -> bool:
pass

@abstractmethod
def read_to_stream(self, url: str, output: BinaryIO):
def read_to_stream(self, url: str, output: BufferedStream):
pass

def write_from_stream(self, url: str, input: BufferedStream, mode: str): # noqa: A002
"""
Write content from BufferedStream to the given URL.

Args:
url: The destination URL
input: BufferedStream to read from
mode: Write mode ('wb' for write/overwrite or 'ab' for append)

Raises:
NotImplementedError: If write operations are not supported by this adapter
"""
raise NotImplementedError(
f"Write operations not supported for {url}. "
f"Adapter {self.__class__.__name__} does not implement write_from_stream."
)

@abstractmethod
def list_directory(self, url: str, recursive: bool = True) -> list[str]:
pass
Expand Down
Loading
Loading