diff --git a/kloppy/domain/models/common.py b/kloppy/domain/models/common.py index 78ebef84c..b568428aa 100644 --- a/kloppy/domain/models/common.py +++ b/kloppy/domain/models/common.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import Iterable -from dataclasses import dataclass, field, replace +from dataclasses import dataclass, field, fields, replace from datetime import datetime, timedelta from enum import Enum, Flag import sys @@ -1667,6 +1667,18 @@ def __post_init__(self): T = TypeVar("T", bound="DataRecord") +class FilteredDataset: + """ + A marker class to identify datasets that have been filtered. + """ + + pass + + +# Helper to cache dynamic classes so we don't recreate them every time +_FILTERED_CLASS_CACHE = {} + + @dataclass class Dataset(ABC, Generic[T]): """ @@ -1757,10 +1769,39 @@ def filter(self, filter_: str | Callable[[T], bool]): >>> dataset = dataset.filter(lambda event: event.event_type == EventType.PASS) >>> dataset = dataset.filter('pass') """ - return replace( - self, - records=self.find_all(filter_), - ) + # 1. Perform filtering + filtered_records = self.find_all(filter_) + + # 2. Determine the target class + current_class = self.__class__ + + if isinstance(self, FilteredDataset): + # Already a filtered class, keep using it + target_class = current_class + else: + # Need to create or retrieve the dynamic filtered subclass + if current_class not in _FILTERED_CLASS_CACHE: + # Dynamically create: class FilteredEventDataset(FilteredDataset, EventDataset) + new_cls_name = f"Filtered{current_class.__name__}" + + # We inherit from FilteredDataset first, then the original class + _FILTERED_CLASS_CACHE[current_class] = type( + new_cls_name, (FilteredDataset, current_class), {} + ) + target_class = _FILTERED_CLASS_CACHE[current_class] + + # 3. Gather arguments to instantiate the new class + # We only want fields that are defined in __init__ + init_kwargs = { + f.name: getattr(self, f.name) for f in fields(self) if f.init + } + + # 4. Update the records in the arguments + init_kwargs["records"] = filtered_records + + # 5. Instantiate and return the new class + # This runs __init__ and __post_init__ of the new class naturally + return target_class(**init_kwargs) def map(self, mapper): return replace( diff --git a/kloppy/infra/serializers/event/datafactory/deserializer.py b/kloppy/infra/serializers/event/datafactory/deserializer.py index c733381be..2a960da6d 100644 --- a/kloppy/infra/serializers/event/datafactory/deserializer.py +++ b/kloppy/infra/serializers/event/datafactory/deserializer.py @@ -346,7 +346,7 @@ class DatafactoryDeserializer(EventDataDeserializer[DatafactoryInputs]): def provider(self) -> Provider: return Provider.DATAFACTORY - def deserialize(self, inputs: DatafactoryInputs) -> EventDataset: + def _deserialize(self, inputs: DatafactoryInputs) -> EventDataset: transformer = self.get_transformer() with performance_logging("load data", logger=logger): @@ -590,13 +590,9 @@ def deserialize(self, inputs: DatafactoryInputs) -> EventDataset: result=None, qualifiers=None, ) - if self.should_include_event(event): - events.append( - transformer.transform_event(ball_out_event) - ) + events.append(transformer.transform_event(ball_out_event)) - if self.should_include_event(event): - events.append(transformer.transform_event(event)) + events.append(transformer.transform_event(event)) # only consider as a previous_event a ball-in-play event if e_class not in ( diff --git a/kloppy/infra/serializers/event/deserializer.py b/kloppy/infra/serializers/event/deserializer.py index 60662a396..2f84b1866 100644 --- a/kloppy/infra/serializers/event/deserializer.py +++ b/kloppy/infra/serializers/event/deserializer.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod -from typing import Generic, Optional, TypeVar, Union +from dataclasses import fields, replace +from typing import Any, Generic, Optional, TypeVar, Union +import warnings from kloppy.domain import ( DatasetTransformer, @@ -64,5 +66,51 @@ def provider(self) -> Provider: raise NotImplementedError @abstractmethod - def deserialize(self, inputs: T) -> EventDataset: + def _deserialize(self, inputs: T) -> EventDataset: raise NotImplementedError + + def deserialize( + self, inputs: T, additional_metadata: Optional[dict[str, Any]] = None + ) -> EventDataset: + dataset = self._deserialize(inputs) + + # Check for additional metadata to merge + if additional_metadata: + # Identify valid fields in the Metadata class + valid_fields = {f.name for f in fields(dataset.metadata)} + + # Split additional_metadata into known and unknown keys + known_updates = {} + unknown_updates = {} + + for key, value in additional_metadata.items(): + if key in valid_fields: + known_updates[key] = value + else: + unknown_updates[key] = value + + # Handle unknown keys (put them into 'attributes' and warn) + if unknown_updates: + warnings.warn( + f"The following metadata keys are not supported fields and will be " + f"added to 'attributes': {list(unknown_updates.keys())}" + ) + + # specific logic to merge with existing attributes safely + current_attributes = dataset.metadata.attributes or {} + # Create a new dict to avoid mutating the original if it's shared + new_attributes = current_attributes.copy() + new_attributes.update(unknown_updates) + + known_updates["attributes"] = new_attributes + + # Apply updates + if known_updates: + updated_metadata = replace(dataset.metadata, **known_updates) + dataset = replace(dataset, metadata=updated_metadata) + + # Check if we need to return a FilteredEventDataset + if self.event_types: + return dataset.filter(self.should_include_event) + + return dataset diff --git a/kloppy/infra/serializers/event/impect/deserializer.py b/kloppy/infra/serializers/event/impect/deserializer.py index daf5687ef..ab71b785c 100644 --- a/kloppy/infra/serializers/event/impect/deserializer.py +++ b/kloppy/infra/serializers/event/impect/deserializer.py @@ -94,7 +94,7 @@ class ImpectDeserializer(EventDataDeserializer[ImpectInputs]): def provider(self) -> Provider: return Provider.IMPECT - def deserialize(self, inputs: ImpectInputs) -> EventDataset: + def _deserialize(self, inputs: ImpectInputs) -> EventDataset: # Initialize coordinate system transformer self.transformer = self.get_transformer() @@ -134,10 +134,9 @@ def deserialize(self, inputs: ImpectInputs) -> EventDataset: periods, teams, impect_events ).deserialize(self.event_factory, teams) for event in new_events: - if self.should_include_event(event): - # Transform event to the coordinate system - event = self.transformer.transform_event(event) - events.append(event) + # Transform event to the coordinate system + event = self.transformer.transform_event(event) + events.append(event) self.mark_events_as_assists(events) substitution_events = self.parse_substitutions( diff --git a/kloppy/infra/serializers/event/metrica/json_deserializer.py b/kloppy/infra/serializers/event/metrica/json_deserializer.py index 8ac736084..64a7156f8 100644 --- a/kloppy/infra/serializers/event/metrica/json_deserializer.py +++ b/kloppy/infra/serializers/event/metrica/json_deserializer.py @@ -252,7 +252,7 @@ class MetricaJsonEventDataDeserializer( def provider(self) -> Provider: return Provider.METRICA - def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset: + def _deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset: with performance_logging("load data", logger=logger): raw_events = json.load(inputs.event_data) metadata = load_metadata( @@ -370,8 +370,7 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset: **generic_event_kwargs, ) - if self.should_include_event(event): - events.append(transformer.transform_event(event)) + events.append(transformer.transform_event(event)) # Checks if the event ended out of the field and adds a synthetic out event if ( @@ -393,9 +392,7 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset: qualifiers=None, **generic_event_kwargs, ) - - if self.should_include_event(event): - events.append(transformer.transform_event(event)) + events.append(transformer.transform_event(event)) return EventDataset( metadata=replace( diff --git a/kloppy/infra/serializers/event/sportec/deserializer.py b/kloppy/infra/serializers/event/sportec/deserializer.py index c19a50e4e..f17fd4a10 100644 --- a/kloppy/infra/serializers/event/sportec/deserializer.py +++ b/kloppy/infra/serializers/event/sportec/deserializer.py @@ -461,7 +461,7 @@ class SportecEventDataDeserializer( def provider(self) -> Provider: return Provider.SPORTEC - def deserialize(self, inputs: SportecEventDataInputs) -> EventDataset: + def _deserialize(self, inputs: SportecEventDataInputs) -> EventDataset: with performance_logging("load data", logger=logger): match_root = objectify.fromstring(inputs.meta_data.read()) event_root = objectify.fromstring(inputs.event_data.read()) @@ -682,13 +682,6 @@ def deserialize(self, inputs: SportecEventDataInputs) -> EventDataset: else: event.receiver_coordinates = events[i + 1].coordinates - events = list( - filter( - self.should_include_event, - events, - ) - ) - metadata = Metadata( teams=teams, periods=periods, diff --git a/kloppy/infra/serializers/event/statsbomb/deserializer.py b/kloppy/infra/serializers/event/statsbomb/deserializer.py index d36d4f747..1105b0911 100644 --- a/kloppy/infra/serializers/event/statsbomb/deserializer.py +++ b/kloppy/infra/serializers/event/statsbomb/deserializer.py @@ -37,9 +37,7 @@ class StatsBombDeserializer(EventDataDeserializer[StatsBombInputs]): def provider(self) -> Provider: return Provider.STATSBOMB - def deserialize( - self, inputs: StatsBombInputs, additional_metadata - ) -> EventDataset: + def _deserialize(self, inputs: StatsBombInputs) -> EventDataset: # Intialize coordinate system transformer self.transformer = self.get_transformer() @@ -76,10 +74,9 @@ def deserialize( .deserialize(self.event_factory) ) for event in new_events: - if self.should_include_event(event): - # Transform event to the coordinate system - event = self.transformer.transform_event(event) - events.append(event) + # Transform event to the coordinate system + event = self.transformer.transform_event(event) + events.append(event) metadata = Metadata( teams=teams, @@ -91,7 +88,6 @@ def deserialize( score=None, provider=Provider.STATSBOMB, coordinate_system=self.transformer.get_to_coordinate_system(), - **additional_metadata, ) dataset = EventDataset(metadata=metadata, records=events) for event in dataset: diff --git a/kloppy/infra/serializers/event/statsperform/deserializer.py b/kloppy/infra/serializers/event/statsperform/deserializer.py index 7910a655b..6b5e56735 100644 --- a/kloppy/infra/serializers/event/statsperform/deserializer.py +++ b/kloppy/infra/serializers/event/statsperform/deserializer.py @@ -711,7 +711,7 @@ class StatsPerformDeserializer(EventDataDeserializer[StatsPerformInputs]): def provider(self) -> Provider: return Provider.OPTA - def deserialize(self, inputs: StatsPerformInputs) -> EventDataset: + def _deserialize(self, inputs: StatsPerformInputs) -> EventDataset: transformer = self.get_transformer( pitch_length=inputs.pitch_length, pitch_width=inputs.pitch_width ) @@ -984,8 +984,7 @@ def deserialize(self, inputs: StatsPerformInputs) -> EventDataset: event_name=_get_event_type_name(raw_event.type_id), ) - if self.should_include_event(event): - events.append(transformer.transform_event(event)) + events.append(transformer.transform_event(event)) metadata = Metadata( teams=list(teams), diff --git a/kloppy/infra/serializers/event/wyscout/deserializer_v2.py b/kloppy/infra/serializers/event/wyscout/deserializer_v2.py index ad7e1f99f..2542fa589 100644 --- a/kloppy/infra/serializers/event/wyscout/deserializer_v2.py +++ b/kloppy/infra/serializers/event/wyscout/deserializer_v2.py @@ -465,7 +465,7 @@ class WyscoutDeserializerV2(EventDataDeserializer[WyscoutInputs]): def provider(self) -> Provider: return Provider.WYSCOUT - def deserialize(self, inputs: WyscoutInputs) -> EventDataset: + def _deserialize(self, inputs: WyscoutInputs) -> EventDataset: transformer = self.get_transformer() with performance_logging("load data", logger=logger): @@ -710,8 +710,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset: new_events.insert(i, interception_event) for new_event in new_events: - if self.should_include_event(new_event): - events.append(transformer.transform_event(new_event)) + events.append(transformer.transform_event(new_event)) metadata = Metadata( teams=[home_team, away_team], diff --git a/kloppy/infra/serializers/event/wyscout/deserializer_v3.py b/kloppy/infra/serializers/event/wyscout/deserializer_v3.py index c91c6dc06..9d34623dd 100644 --- a/kloppy/infra/serializers/event/wyscout/deserializer_v3.py +++ b/kloppy/infra/serializers/event/wyscout/deserializer_v3.py @@ -759,7 +759,7 @@ class WyscoutDeserializerV3(EventDataDeserializer[WyscoutInputs]): def provider(self) -> Provider: return Provider.WYSCOUT - def deserialize(self, inputs: WyscoutInputs) -> EventDataset: + def _deserialize(self, inputs: WyscoutInputs) -> EventDataset: transformer = self.get_transformer() with performance_logging("load data", logger=logger): @@ -811,6 +811,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset: events = [] next_pass_is_kickoff = False + event = None for idx, raw_event in enumerate(raw_events["events"]): next_event = None ball_owning_team = None @@ -961,7 +962,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset: ) # We already append event to events # as we potentially have a card and foul event for one raw event - if event and self.should_include_event(event): + if event: events.append(transformer.transform_event(event)) continue if ( @@ -972,7 +973,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset: event = self.event_factory.build_card( **card_event_args, **generic_event_args ) - if event and self.should_include_event(event): + if event: events.append(transformer.transform_event(event)) continue elif "carry" in secondary_event_types: @@ -991,7 +992,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset: **generic_event_args, ) - if event and self.should_include_event(event): + if event: events.append(transformer.transform_event(event)) if next_event: @@ -1023,7 +1024,7 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset: **formation_change_event_kwargs, **generic_event_args, ) - if event and self.should_include_event(event): + if event: events.append(transformer.transform_event(event)) metadata = Metadata( diff --git a/kloppy/tests/test_event.py b/kloppy/tests/test_event.py index c89c214bf..2a492ea7a 100644 --- a/kloppy/tests/test_event.py +++ b/kloppy/tests/test_event.py @@ -1,7 +1,7 @@ import pytest from kloppy import statsbomb -from kloppy.domain import EventDataset +from kloppy.domain import EventDataset, FilteredDataset class TestEvent: @@ -51,8 +51,10 @@ def test_filter(self, dataset: EventDataset): """ Test filtering allows simple 'css selector' (.) """ + # Perform the filter goals_dataset = dataset.filter("shot.goal") + # Assert data correctness df = goals_dataset.to_df(engine="pandas") assert df["event_id"].to_list() == [ "4c7c4ab1-6b9f-4504-a237-249c2e0c549f", @@ -60,6 +62,15 @@ def test_filter(self, dataset: EventDataset): "55d71847-9511-4417-aea9-6f415e279011", ] + # Assert type correctness + assert isinstance(goals_dataset, FilteredDataset) + assert isinstance(goals_dataset, EventDataset) + assert type(goals_dataset).__name__ == "FilteredEventDataset" + + # Filtering again should not break the class structure + subset = goals_dataset.filter(lambda x: True) + assert type(subset).__name__ == "FilteredEventDataset" + def test_map(self, dataset: EventDataset): """ Test the `map` method on a Dataset to allow chaining (filter and map)