diff --git a/kloppy/domain/services/state_builder/__init__.py b/kloppy/domain/services/state_builder/__init__.py
index c80ddad68..0bb4bae1b 100644
--- a/kloppy/domain/services/state_builder/__init__.py
+++ b/kloppy/domain/services/state_builder/__init__.py
@@ -47,4 +47,8 @@ def add_state(dataset: EventDataset, *builder_keys: list[str]) -> EventDataset:
             for builder_key, builder in builders.items()
         }
 
+    # Let every builder finalize the states once all events are processed.
+    for builder in builders.values():
+        builder.post_process(events)
+
     return replace(dataset, records=events)
diff --git a/kloppy/domain/services/state_builder/builder.py b/kloppy/domain/services/state_builder/builder.py
index 06f7a9474..34173af29 100644
--- a/kloppy/domain/services/state_builder/builder.py
+++ b/kloppy/domain/services/state_builder/builder.py
@@ -20,3 +20,10 @@ def reduce_before(self, state: T, event: Event) -> T:
     @abstractmethod
     def reduce_after(self, state: T, event: Event) -> T:
         pass
+
+    def post_process(self, events: list[Event]) -> None:
+        """Optional hook called with the full list of events.
+
+        Builders can override this to adjust the ``state`` of the events
+        in place. The default implementation does nothing.
+        """
diff --git a/kloppy/domain/services/state_builder/builders/sequence.py b/kloppy/domain/services/state_builder/builders/sequence.py
index 34cc23886..4e9fedc04 100644
--- a/kloppy/domain/services/state_builder/builders/sequence.py
+++ b/kloppy/domain/services/state_builder/builders/sequence.py
@@ -1,42 +1,122 @@
 from dataclasses import dataclass, replace
+from typing import Optional
 
 from kloppy.domain import (
     BallOutEvent,
+    CardEvent,
     CarryEvent,
+    ClearanceEvent,
+    DuelEvent,
+    DuelResult,
     Event,
     EventDataset,
+    FormationChangeEvent,
     FoulCommittedEvent,
+    GenericEvent,
+    GoalkeeperActionType,
+    GoalkeeperEvent,
+    GoalkeeperQualifier,
+    InterceptionEvent,
+    InterceptionResult,
     PassEvent,
+    PlayerOffEvent,
+    PlayerOnEvent,
     RecoveryEvent,
     SetPieceQualifier,
     ShotEvent,
+    SubstitutionEvent,
+    TakeOnEvent,
     Team,
 )
 
 from ..builder import StateBuilder
 
-OPEN_SEQUENCE = (PassEvent, CarryEvent, RecoveryEvent)
-CLOSE_SEQUENCE = (BallOutEvent, FoulCommittedEvent, ShotEvent)
-
 
 @dataclass
 class Sequence:
-    sequence_id: int
-    team: Team
+    sequence_id: Optional[int]
+    team: Optional[Team]
+
+
+EXCLUDED_OFF_BALL_EVENTS = (
+    GenericEvent,
+    SubstitutionEvent,
+    CardEvent,
+    PlayerOnEvent,
+    PlayerOffEvent,
+    FormationChangeEvent,
+)
+
+CLOSE_SEQUENCE = (BallOutEvent, FoulCommittedEvent, ShotEvent)
+
+
+def is_ball_winning_defensive_action(event: Event) -> bool:
+    """Return True if ``event`` is a won duel or a clearance."""
+    return (
+        isinstance(event, DuelEvent) and event.result == DuelResult.WON
+    ) or isinstance(event, ClearanceEvent)
+
+
+def is_possessing_event(event: Event) -> bool:
+    """Return True if ``event`` counts as possessing the ball.
+
+    This holds for passes, carries, recoveries and take-ons, for goalkeeper
+    pick-ups and claims, and for successful interceptions.
+    """
+    if isinstance(event, (PassEvent, CarryEvent, RecoveryEvent, TakeOnEvent)):
+        return True
+    if isinstance(event, GoalkeeperEvent) and event.get_qualifier_value(
+        GoalkeeperQualifier
+    ) in (
+        GoalkeeperActionType.PICK_UP,
+        GoalkeeperActionType.CLAIM,
+    ):
+        return True
+    return (
+        isinstance(event, InterceptionEvent)
+        and event.result == InterceptionResult.SUCCESS
+    )
+
+
+def should_open_sequence(
+    event: Event, next_event: Optional[Event], state: Optional[Sequence] = None
+) -> bool:
+    """Decide whether ``event`` starts a new sequence.
+
+    ``event`` must be a possessing event, or a ball-winning defensive
+    action whose ``next_event`` is a possessing event of the same team.
+    If ``state`` is given, the event must also be by a different team than
+    ``state.team`` or carry a ``SetPieceQualifier`` value.
+    """
+    can_open_sequence = is_possessing_event(event) or (
+        is_ball_winning_defensive_action(event)
+        and next_event is not None
+        and next_event.team == event.team
+        and is_possessing_event(next_event)
+    )
+    if not can_open_sequence:
+        return False
+    if state is None or state.team != event.team:
+        return True
+    # get_qualifier_value() yields a qualifier value, not a bool; coerce
+    # it so the annotated return type holds.
+    return bool(event.get_qualifier_value(SetPieceQualifier))
+
+
+def should_close_sequence(event: Event) -> bool:
+    """Return True if ``event`` has one of the CLOSE_SEQUENCE types."""
+    return isinstance(event, CLOSE_SEQUENCE)
 
 
 class SequenceStateBuilder(StateBuilder):
     def initial_state(self, dataset: EventDataset) -> Sequence:
         for event in dataset.events:
-            if isinstance(event, OPEN_SEQUENCE):
+            if should_open_sequence(event, event.next_record):
                 return Sequence(sequence_id=0, team=event.team)
         return Sequence(sequence_id=0, team=None)
 
     def reduce_before(self, state: Sequence, event: Event) -> Sequence:
-        if isinstance(event, OPEN_SEQUENCE) and (
-            state.team != event.team
-            or event.get_qualifier_value(SetPieceQualifier)
-        ):
+        if should_open_sequence(event, event.next_record, state):
             state = replace(
                 state, sequence_id=state.sequence_id + 1, team=event.team
             )
@@ -48,3 +128,30 @@ def reduce_after(self, state: Sequence, event: Event) -> Sequence:
             state = replace(state, sequence_id=state.sequence_id + 1, team=None)
 
         return state
+
+    def post_process(self, events: list[Event]) -> None:
+        """Renumber the sequences of ``events`` in place.
+
+        Events of an excluded off-ball type, or whose sequence has no team,
+        get an empty ``Sequence``. The remaining sequence ids are replaced
+        by consecutive ids starting at 1, in order of first appearance.
+        """
+        sequence_id_mapping = {}
+
+        for event in events:
+            sequence = event.state["sequence"]
+
+            if (
+                isinstance(event, EXCLUDED_OFF_BALL_EVENTS)
+                or sequence.team is None
+            ):
+                event.state["sequence"] = Sequence(sequence_id=None, team=None)
+            elif sequence.sequence_id is not None:
+                # The first time an old id is seen it gets the next free
+                # consecutive id; later events of that sequence reuse it.
+                new_sequence_id = sequence_id_mapping.setdefault(
+                    sequence.sequence_id, len(sequence_id_mapping) + 1
+                )
+                event.state["sequence"] = Sequence(
+                    sequence_id=new_sequence_id, team=sequence.team
+                )
diff --git a/kloppy/tests/test_state_builder.py b/kloppy/tests/test_state_builder.py
index 10326d534..7be865bde 100644
--- a/kloppy/tests/test_state_builder.py
+++ b/kloppy/tests/test_state_builder.py
@@ -1,6 +1,7 @@
+from collections import defaultdict
 from itertools import groupby
 
-from kloppy import statsbomb
+from kloppy import statsbomb, statsperform
 from kloppy.domain import Event, EventDataset, EventType, FormationType
 from kloppy.domain.services.state_builder.builder import StateBuilder
 from kloppy.utils import performance_logging
@@ -9,14 +10,22 @@
 class TestStateBuilder:
     """"""
 
-    def _load_dataset(self, base_dir, base_filename="statsbomb"):
+    def _load_dataset_statsbomb(self, base_dir, base_filename="statsbomb"):
         return statsbomb.load(
             event_data=base_dir / f"files/{base_filename}_event.json",
             lineup_data=base_dir / f"files/{base_filename}_lineup.json",
         )
 
+    def _load_dataset_statsperform(
+        self, base_dir, base_filename="statsperform"
+    ):
+        return statsperform.load_event(
+            ma1_data=base_dir / f"files/{base_filename}_event_ma1.json",
+            ma3_data=base_dir / f"files/{base_filename}_event_ma3.json",
+        )
+
     def test_score_state_builder(self, base_dir):
-        dataset = self._load_dataset(base_dir)
+        dataset = self._load_dataset_statsbomb(base_dir)
 
         with performance_logging("add_state"):
             dataset_with_state = dataset.add_state("score")
@@ -36,25 +45,44 @@ def test_score_state_builder(self, base_dir):
             "3-1": 2,
         }
 
-    def test_sequence_state_builder(self, base_dir):
-        dataset = self._load_dataset(base_dir)
+    def test_sequence_state_builder_statsbomb(self, base_dir):
+        dataset = self._load_dataset_statsbomb(base_dir)
+
+        with performance_logging("add_state"):
+            dataset_with_state = dataset.add_state("sequence")
+
+        events_per_sequence = defaultdict(int)
+        for sequence_id, events in groupby(
+            dataset_with_state.events,
+            lambda event: event.state["sequence"].sequence_id,
+        ):
+            events = list(events)
+            events_per_sequence[sequence_id] += len(events)
+
+        assert events_per_sequence[1] == 3
+        assert events_per_sequence[72] == 11
+
+    def test_sequence_state_builder_statsperform(self, base_dir):
+        dataset = self._load_dataset_statsperform(base_dir)
 
         with performance_logging("add_state"):
             dataset_with_state = dataset.add_state("sequence")
 
-        events_per_sequence = {}
+        events_per_sequence = defaultdict(int)
         for sequence_id, events in groupby(
             dataset_with_state.events,
             lambda event: event.state["sequence"].sequence_id,
         ):
             events = list(events)
-            events_per_sequence[sequence_id] = len(events)
+            events_per_sequence[sequence_id] += len(events)
 
-        assert events_per_sequence[0] == 4
-        assert events_per_sequence[51] == 10
+        assert events_per_sequence[1] == 5
+        assert events_per_sequence[89] == 12
 
     def test_lineup_state_builder(self, base_dir):
-        dataset = self._load_dataset(base_dir, base_filename="statsbomb_15986")
+        dataset = self._load_dataset_statsbomb(
+            base_dir, base_filename="statsbomb_15986"
+        )
 
         with performance_logging("add_state"):
             dataset_with_state = dataset.add_state("lineup")
@@ -79,7 +107,9 @@ def test_lineup_state_builder(self, base_dir):
         ]
 
     def test_formation_state_builder(self, base_dir):
-        dataset = self._load_dataset(base_dir, base_filename="statsbomb")
+        dataset = self._load_dataset_statsbomb(
+            base_dir, base_filename="statsbomb"
+        )
 
         with performance_logging("add_state"):
             dataset_with_state = dataset.add_state("formation")
@@ -114,7 +144,9 @@ def reduce_before(self, state: int, event: Event) -> int:
             def reduce_after(self, state: int, event: Event) -> int:
                 return state + 1
 
-        dataset = self._load_dataset(base_dir, base_filename="statsbomb_15986")
+        dataset = self._load_dataset_statsbomb(
+            base_dir, base_filename="statsbomb_15986"
+        )
 
         with performance_logging("add_state"):
             dataset_with_state = dataset.add_state("custom")