Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions preprocessing/sports/space_data/space_class.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
class Space_data:
# Modified the sports list to only include fully supported providers
basketball_data_provider = ['SportVU_NBA']
soccer_data_provider = ['fifa_wc_2022']
basketball_data_provider = ["SportVU_NBA"]
soccer_data_provider = ["fifa_wc_2022"]
ultimate_data_provider = ["UltimateTrack", "UFA"]

def __new__(cls, data_provider, *args, **kwargs):
if data_provider in cls.basketball_data_provider:
from .basketball.basketball_space_class import Basketball_space_data

# If the data_provider is in the supported list, return an instance of Basketball_space_data
return Basketball_space_data(data_provider, *args, **kwargs)
elif data_provider in cls.soccer_data_provider:
from .soccer.soccer_space_class import Soccer_space_data

# If the data_provider is in the supported list, return an instance of Soccer_space_data
return Soccer_space_data(data_provider, *args, **kwargs)
elif data_provider in cls.ultimate_data_provider:
from .ultimate.ultimate_space_class import Ultimate_space_data

# If the data_provider is in the supported list, return an instance of Ultimate_space_data
return Ultimate_space_data(data_provider, *args, **kwargs)
else:
# If the data_provider is unrecognized, raise a ValueError
raise ValueError(f'Unknown data provider: {data_provider}')
raise ValueError(f"Unknown data provider: {data_provider}")
Empty file.
93 changes: 93 additions & 0 deletions preprocessing/sports/space_data/ultimate/ultimate_space_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os

import pandas as pd
from tqdm import tqdm


class Ultimate_space_data:
def __init__(
self,
data_provider,
tracking_data_path,
out_path=None,
testing_mode=False,
):
self.data_provider = data_provider
self.tracking_path = tracking_data_path
self.testing_mode = testing_mode
self.out_path = out_path
if self.data_provider == "UltimateTrack":
self.tracking_herz = 15
elif self.data_provider == "UFA":
self.tracking_herz = 10

def get_files(self):
if os.path.isdir(self.tracking_path):
data_files = [
os.path.join(self.tracking_path, f)
for f in os.listdir(self.tracking_path)
if f.endswith(".csv")
]
elif os.path.isfile(self.tracking_path) and self.tracking_path.endswith(".csv"):
data_files = [self.tracking_path]
else:
raise ValueError(f"Invalid data path: {self.tracking_path}")
return data_files

def preprocessing(self):
tracking_files = self.get_files()
if self.testing_mode:
tracking_files = tracking_files[:2]
print("Running in testing mode. Limited files will be processed.")

from .ultimate_space_preprocessing import (
convert_to_metrica_format,
create_intermediate_file,
)

home_tracking_dict = {}
away_tracking_dict = {}
event_data_dict = {}
for tracking_path_i in tqdm(
tracking_files, total=len(tracking_files), desc="Processing tracking files"
):
match_i = os.path.splitext(
os.path.splitext(os.path.basename(tracking_path_i))[0]
)[0]
match_tracking_df = pd.read_csv(tracking_path_i)
print(match_tracking_df)
# Create intermediate DataFrame with all required columns
intermidiate_df = create_intermediate_file(match_tracking_df)

# Convert to Metrica format
home_df, away_df, events_df = convert_to_metrica_format(
intermidiate_df, self.tracking_herz
)

home_tracking_dict[match_i] = home_df
away_tracking_dict[match_i] = away_df
event_data_dict[match_i] = events_df

if self.out_path:
# create output directory if not exists
os.makedirs(self.out_path + "/event", exist_ok=True)
os.makedirs(self.out_path + "/home_tracking", exist_ok=True)
os.makedirs(self.out_path + "/away_tracking", exist_ok=True)

for match_id, df in event_data_dict.items():
df.to_csv(
os.path.join(self.out_path, "event", f"{match_id}.csv"),
index=False,
)
for match_id, df in home_tracking_dict.items():
df.to_csv(
os.path.join(self.out_path, "home_tracking", f"{match_id}.csv"),
index=False,
)
for match_id, df in away_tracking_dict.items():
df.to_csv(
os.path.join(self.out_path, "away_tracking", f"{match_id}.csv"),
index=False,
)

return event_data_dict, home_tracking_dict, away_tracking_dict
Loading