Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def preprocessing(self):
from .ultimate_space_preprocessing import (
convert_to_metrica_format,
create_intermediate_file,
format_tracking_headers,
)

home_tracking_dict = {}
Expand All @@ -55,7 +56,7 @@ def preprocessing(self):
os.path.splitext(os.path.basename(tracking_path_i))[0]
)[0]
match_tracking_df = pd.read_csv(tracking_path_i)
print(match_tracking_df)

# Create intermediate DataFrame with all required columns
intermidiate_df = create_intermediate_file(match_tracking_df)

Expand All @@ -64,6 +65,9 @@ def preprocessing(self):
intermidiate_df, self.tracking_herz
)

home_df = format_tracking_headers(home_df, team_prefix="Home")
away_df = format_tracking_headers(away_df, team_prefix="Away")

home_tracking_dict[match_i] = home_df
away_tracking_dict[match_i] = away_df
event_data_dict[match_i] = events_df
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def create_events_metrica(df, tracking_herz):
if not holder_data.empty:
to_id = (
holder_data["id"]
.map(lambda x: offense_ids.index(x) if x in offense_ids else np.nan)
.map(lambda x: offense_ids.index(x) + 1 if x in offense_ids else np.nan)
.reset_index(drop=True)
)
else:
Expand Down Expand Up @@ -353,3 +353,62 @@ def create_tracking_metrica(df, team, tracking_herz):
tracking_df.columns = multi_columns

return tracking_df


def format_tracking_headers(tracking_df, team_prefix="Home"):
"""Convert the multi-index tracking output into a single-header format."""
if tracking_df.empty:
return tracking_df

flattened_columns = []
active_columns = []
seen_columns = set()

for column in tracking_df.columns:
# MultiIndex columns are returned as tuples
level2_name = column[2] if isinstance(column, tuple) else column

if column in seen_columns:
continue

if level2_name == "Frame":
continue

if level2_name == "Period":
flattened_columns.append("Period")
active_columns.append(column)
seen_columns.add(column)
continue

if level2_name == "Time [s]":
flattened_columns.append("Time [s]")
active_columns.append(column)
seen_columns.add(column)
continue

if level2_name == "Disc__":
flattened_columns.append("disc_x")
flattened_columns.append("disc_y")
active_columns.append(column)
seen_columns.add(column)
continue

if (
isinstance(column, tuple)
and column[0] == team_prefix
and level2_name.startswith("Player")
):
player_index = int(level2_name.replace("Player", "")) + 1
for suffix in ["_x", "_y"]:
flattened_columns.append(f"{team_prefix}_{player_index}{suffix}")
active_columns.append(column)
seen_columns.add(column)
continue

formatted_df = tracking_df[active_columns].copy()
formatted_df.columns = flattened_columns

if "Period" in formatted_df.columns and formatted_df["Period"].isna().all():
formatted_df["Period"] = 1

return formatted_df