Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field, replace
from dataclasses import dataclass, field, fields, replace
from datetime import datetime, timedelta
from enum import Enum, Flag
import sys
Expand Down Expand Up @@ -1667,6 +1667,18 @@ def __post_init__(self):
T = TypeVar("T", bound="DataRecord")


class FilteredDataset:
    """Marker mixin identifying datasets produced by a ``filter`` call.

    Carries no state or behavior of its own; dynamically created
    ``Filtered<Name>`` classes inherit from it so filtered datasets can be
    recognized via ``isinstance(ds, FilteredDataset)``.
    """


# Helper to cache dynamic classes so we don't recreate them every time
# Maps an original Dataset subclass -> its dynamically created
# "Filtered<Name>" subclass, so repeated filter() calls reuse a single type
# (keeps isinstance/type identity stable across filter chains).
_FILTERED_CLASS_CACHE: dict = {}


@dataclass
class Dataset(ABC, Generic[T]):
"""
Expand Down Expand Up @@ -1757,10 +1769,39 @@ def filter(self, filter_: str | Callable[[T], bool]):
>>> dataset = dataset.filter(lambda event: event.event_type == EventType.PASS)
>>> dataset = dataset.filter('pass')
"""
return replace(
self,
records=self.find_all(filter_),
)
# 1. Perform filtering
filtered_records = self.find_all(filter_)

# 2. Determine the target class
current_class = self.__class__

if isinstance(self, FilteredDataset):
# Already a filtered class, keep using it
target_class = current_class
else:
# Need to create or retrieve the dynamic filtered subclass
if current_class not in _FILTERED_CLASS_CACHE:
# Dynamically create: class FilteredEventDataset(FilteredDataset, EventDataset)
new_cls_name = f"Filtered{current_class.__name__}"

# We inherit from FilteredDataset first, then the original class
_FILTERED_CLASS_CACHE[current_class] = type(
new_cls_name, (FilteredDataset, current_class), {}
)
target_class = _FILTERED_CLASS_CACHE[current_class]

# 3. Gather arguments to instantiate the new class
# We only want fields that are defined in __init__
init_kwargs = {
f.name: getattr(self, f.name) for f in fields(self) if f.init
}

# 4. Update the records in the arguments
init_kwargs["records"] = filtered_records

# 5. Instantiate and return the new class
# This runs __init__ and __post_init__ of the new class naturally
return target_class(**init_kwargs)

def map(self, mapper):
return replace(
Expand Down
10 changes: 3 additions & 7 deletions kloppy/infra/serializers/event/datafactory/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ class DatafactoryDeserializer(EventDataDeserializer[DatafactoryInputs]):
def provider(self) -> Provider:
return Provider.DATAFACTORY

def deserialize(self, inputs: DatafactoryInputs) -> EventDataset:
def _deserialize(self, inputs: DatafactoryInputs) -> EventDataset:
transformer = self.get_transformer()

with performance_logging("load data", logger=logger):
Expand Down Expand Up @@ -590,13 +590,9 @@ def deserialize(self, inputs: DatafactoryInputs) -> EventDataset:
result=None,
qualifiers=None,
)
if self.should_include_event(event):
events.append(
transformer.transform_event(ball_out_event)
)
events.append(transformer.transform_event(ball_out_event))

if self.should_include_event(event):
events.append(transformer.transform_event(event))
events.append(transformer.transform_event(event))

# only consider as a previous_event a ball-in-play event
if e_class not in (
Expand Down
52 changes: 50 additions & 2 deletions kloppy/infra/serializers/event/deserializer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar, Union
from dataclasses import fields, replace
from typing import Any, Generic, Optional, TypeVar, Union
import warnings

from kloppy.domain import (
DatasetTransformer,
Expand Down Expand Up @@ -64,5 +66,51 @@ def provider(self) -> Provider:
raise NotImplementedError

@abstractmethod
def deserialize(self, inputs: T) -> EventDataset:
def _deserialize(self, inputs: T) -> EventDataset:
raise NotImplementedError

def deserialize(
    self, inputs: T, additional_metadata: Optional[dict[str, Any]] = None
) -> EventDataset:
    """
    Deserialize ``inputs`` into an :class:`EventDataset`, optionally merging
    extra metadata into the result.

    Keys of ``additional_metadata`` that match a dataclass field of the
    dataset's ``Metadata`` are set directly; unknown keys are merged into
    ``metadata.attributes`` (with a warning) instead of being dropped.

    Args:
        inputs: provider-specific input bundle passed to ``_deserialize``.
        additional_metadata: optional mapping of metadata overrides.

    Returns:
        The deserialized dataset; when ``self.event_types`` is set, a
        filtered dataset containing only the configured event types.
    """
    dataset = self._deserialize(inputs)

    # Check for additional metadata to merge
    if additional_metadata:
        # Only names that are actual dataclass fields of Metadata may be
        # assigned directly via dataclasses.replace.
        valid_fields = {f.name for f in fields(dataset.metadata)}

        # Split additional_metadata into known and unknown keys
        known_updates = {}
        unknown_updates = {}
        for key, value in additional_metadata.items():
            if key in valid_fields:
                known_updates[key] = value
            else:
                unknown_updates[key] = value

        # Handle unknown keys (put them into 'attributes' and warn)
        if unknown_updates:
            warnings.warn(
                f"The following metadata keys are not supported fields and will be "
                f"added to 'attributes': {list(unknown_updates.keys())}"
            )

            # Merge precedence: existing attributes < caller-supplied
            # 'attributes' dict < unknown keys. Build a fresh dict to avoid
            # mutating a dict that may be shared with the original metadata.
            new_attributes = dict(dataset.metadata.attributes or {})
            # Bug fix: an explicitly passed 'attributes' entry was previously
            # overwritten here and silently lost — merge it instead.
            new_attributes.update(known_updates.get("attributes") or {})
            new_attributes.update(unknown_updates)

            known_updates["attributes"] = new_attributes

        # Apply updates immutably via dataclasses.replace
        if known_updates:
            updated_metadata = replace(dataset.metadata, **known_updates)
            dataset = replace(dataset, metadata=updated_metadata)

    # Honour any event-type filter configured on this deserializer; filtering
    # happens after full deserialization so synthetic events are preserved.
    if self.event_types:
        return dataset.filter(self.should_include_event)

    return dataset
9 changes: 4 additions & 5 deletions kloppy/infra/serializers/event/impect/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class ImpectDeserializer(EventDataDeserializer[ImpectInputs]):
def provider(self) -> Provider:
return Provider.IMPECT

def deserialize(self, inputs: ImpectInputs) -> EventDataset:
def _deserialize(self, inputs: ImpectInputs) -> EventDataset:
# Initialize coordinate system transformer
self.transformer = self.get_transformer()

Expand Down Expand Up @@ -134,10 +134,9 @@ def deserialize(self, inputs: ImpectInputs) -> EventDataset:
periods, teams, impect_events
).deserialize(self.event_factory, teams)
for event in new_events:
if self.should_include_event(event):
# Transform event to the coordinate system
event = self.transformer.transform_event(event)
events.append(event)
# Transform event to the coordinate system
event = self.transformer.transform_event(event)
events.append(event)

self.mark_events_as_assists(events)
substitution_events = self.parse_substitutions(
Expand Down
9 changes: 3 additions & 6 deletions kloppy/infra/serializers/event/metrica/json_deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ class MetricaJsonEventDataDeserializer(
def provider(self) -> Provider:
return Provider.METRICA

def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset:
def _deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset:
with performance_logging("load data", logger=logger):
raw_events = json.load(inputs.event_data)
metadata = load_metadata(
Expand Down Expand Up @@ -370,8 +370,7 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset:
**generic_event_kwargs,
)

if self.should_include_event(event):
events.append(transformer.transform_event(event))
events.append(transformer.transform_event(event))

# Checks if the event ended out of the field and adds a synthetic out event
if (
Expand All @@ -393,9 +392,7 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset:
qualifiers=None,
**generic_event_kwargs,
)

if self.should_include_event(event):
events.append(transformer.transform_event(event))
events.append(transformer.transform_event(event))

return EventDataset(
metadata=replace(
Expand Down
9 changes: 1 addition & 8 deletions kloppy/infra/serializers/event/sportec/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ class SportecEventDataDeserializer(
def provider(self) -> Provider:
return Provider.SPORTEC

def deserialize(self, inputs: SportecEventDataInputs) -> EventDataset:
def _deserialize(self, inputs: SportecEventDataInputs) -> EventDataset:
with performance_logging("load data", logger=logger):
match_root = objectify.fromstring(inputs.meta_data.read())
event_root = objectify.fromstring(inputs.event_data.read())
Expand Down Expand Up @@ -682,13 +682,6 @@ def deserialize(self, inputs: SportecEventDataInputs) -> EventDataset:
else:
event.receiver_coordinates = events[i + 1].coordinates

events = list(
filter(
self.should_include_event,
events,
)
)

metadata = Metadata(
teams=teams,
periods=periods,
Expand Down
12 changes: 4 additions & 8 deletions kloppy/infra/serializers/event/statsbomb/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ class StatsBombDeserializer(EventDataDeserializer[StatsBombInputs]):
def provider(self) -> Provider:
return Provider.STATSBOMB

def deserialize(
self, inputs: StatsBombInputs, additional_metadata
) -> EventDataset:
def _deserialize(self, inputs: StatsBombInputs) -> EventDataset:
# Intialize coordinate system transformer
self.transformer = self.get_transformer()

Expand Down Expand Up @@ -76,10 +74,9 @@ def deserialize(
.deserialize(self.event_factory)
)
for event in new_events:
if self.should_include_event(event):
# Transform event to the coordinate system
event = self.transformer.transform_event(event)
events.append(event)
# Transform event to the coordinate system
event = self.transformer.transform_event(event)
events.append(event)

metadata = Metadata(
teams=teams,
Expand All @@ -91,7 +88,6 @@ def deserialize(
score=None,
provider=Provider.STATSBOMB,
coordinate_system=self.transformer.get_to_coordinate_system(),
**additional_metadata,
)
dataset = EventDataset(metadata=metadata, records=events)
for event in dataset:
Expand Down
5 changes: 2 additions & 3 deletions kloppy/infra/serializers/event/statsperform/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ class StatsPerformDeserializer(EventDataDeserializer[StatsPerformInputs]):
def provider(self) -> Provider:
return Provider.OPTA

def deserialize(self, inputs: StatsPerformInputs) -> EventDataset:
def _deserialize(self, inputs: StatsPerformInputs) -> EventDataset:
transformer = self.get_transformer(
pitch_length=inputs.pitch_length, pitch_width=inputs.pitch_width
)
Expand Down Expand Up @@ -984,8 +984,7 @@ def deserialize(self, inputs: StatsPerformInputs) -> EventDataset:
event_name=_get_event_type_name(raw_event.type_id),
)

if self.should_include_event(event):
events.append(transformer.transform_event(event))
events.append(transformer.transform_event(event))

metadata = Metadata(
teams=list(teams),
Expand Down
5 changes: 2 additions & 3 deletions kloppy/infra/serializers/event/wyscout/deserializer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ class WyscoutDeserializerV2(EventDataDeserializer[WyscoutInputs]):
def provider(self) -> Provider:
return Provider.WYSCOUT

def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
def _deserialize(self, inputs: WyscoutInputs) -> EventDataset:
transformer = self.get_transformer()

with performance_logging("load data", logger=logger):
Expand Down Expand Up @@ -710,8 +710,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
new_events.insert(i, interception_event)

for new_event in new_events:
if self.should_include_event(new_event):
events.append(transformer.transform_event(new_event))
events.append(transformer.transform_event(new_event))

metadata = Metadata(
teams=[home_team, away_team],
Expand Down
11 changes: 6 additions & 5 deletions kloppy/infra/serializers/event/wyscout/deserializer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,7 @@ class WyscoutDeserializerV3(EventDataDeserializer[WyscoutInputs]):
def provider(self) -> Provider:
return Provider.WYSCOUT

def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
def _deserialize(self, inputs: WyscoutInputs) -> EventDataset:
transformer = self.get_transformer()

with performance_logging("load data", logger=logger):
Expand Down Expand Up @@ -811,6 +811,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
events = []

next_pass_is_kickoff = False
event = None
for idx, raw_event in enumerate(raw_events["events"]):
next_event = None
ball_owning_team = None
Expand Down Expand Up @@ -961,7 +962,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
)
# We already append event to events
# as we potentially have a card and foul event for one raw event
if event and self.should_include_event(event):
if event:
events.append(transformer.transform_event(event))
continue
if (
Expand All @@ -972,7 +973,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
event = self.event_factory.build_card(
**card_event_args, **generic_event_args
)
if event and self.should_include_event(event):
if event:
events.append(transformer.transform_event(event))
continue
elif "carry" in secondary_event_types:
Expand All @@ -991,7 +992,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
**generic_event_args,
)

if event and self.should_include_event(event):
if event:
events.append(transformer.transform_event(event))

if next_event:
Expand Down Expand Up @@ -1023,7 +1024,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
**formation_change_event_kwargs,
**generic_event_args,
)
if event and self.should_include_event(event):
if event:
events.append(transformer.transform_event(event))

metadata = Metadata(
Expand Down
13 changes: 12 additions & 1 deletion kloppy/tests/test_event.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from kloppy import statsbomb
from kloppy.domain import EventDataset
from kloppy.domain import EventDataset, FilteredDataset


class TestEvent:
Expand Down Expand Up @@ -51,15 +51,26 @@ def test_filter(self, dataset: EventDataset):
"""
Test filtering allows simple 'css selector' (<event_type>.<result>)
"""
# Perform the filter
goals_dataset = dataset.filter("shot.goal")

# Assert data correctness
df = goals_dataset.to_df(engine="pandas")
assert df["event_id"].to_list() == [
"4c7c4ab1-6b9f-4504-a237-249c2e0c549f",
"683c6752-13bc-4892-94ed-22e1c938f1f7",
"55d71847-9511-4417-aea9-6f415e279011",
]

# Assert type correctness
assert isinstance(goals_dataset, FilteredDataset)
assert isinstance(goals_dataset, EventDataset)
assert type(goals_dataset).__name__ == "FilteredEventDataset"

# Filtering again should not break the class structure
subset = goals_dataset.filter(lambda x: True)
assert type(subset).__name__ == "FilteredEventDataset"

def test_map(self, dataset: EventDataset):
"""
Test the `map` method on a Dataset to allow chaining (filter and map)
Expand Down
Loading