diff --git a/img2dataset/logger.py b/img2dataset/logger.py index 006cb43..bf8645b 100644 --- a/img2dataset/logger.py +++ b/img2dataset/logger.py @@ -196,13 +196,22 @@ def write_stats( class LoggerProcess(multiprocessing.context.SpawnProcess): """Logger process that reads stats files regularly, aggregates and send to wandb / print to terminal""" - def __init__(self, output_folder, enable_wandb, wandb_project, config_parameters, log_interval=5): + def __init__( + self, + output_folder, + enable_wandb, + wandb_project, + config_parameters, + log_interval=5, + wandb_job_name=None + ): super().__init__() self.log_interval = log_interval self.enable_wandb = enable_wandb self.output_folder = output_folder self.stats_files = set() self.wandb_project = wandb_project + self.wandb_job_name = wandb_job_name self.done_shards = set() self.config_parameters = config_parameters ctx = multiprocessing.get_context("spawn") @@ -214,7 +223,12 @@ def run(self): fs, output_path = fsspec.core.url_to_fs(self.output_folder, use_listings_cache=False) if self.enable_wandb: - self.current_run = wandb.init(project=self.wandb_project, config=self.config_parameters, anonymous="allow") + self.current_run = wandb.init( + project=self.wandb_project, + name=self.wandb_job_name, + config=self.config_parameters, + anonymous="allow", + ) else: self.current_run = None self.total_speed_logger = SpeedLogger("total", enable_wandb=self.enable_wandb) diff --git a/img2dataset/main.py b/img2dataset/main.py index 4150c19..463a665 100644 --- a/img2dataset/main.py +++ b/img2dataset/main.py @@ -108,6 +108,7 @@ def download( max_shard_retry: int = 1, user_agent_token: Optional[str] = None, disallowed_header_directives: Optional[List[str]] = None, + wandb_job_name: Optional[str] = None, ): """Download is the main entry point of img2dataset, it uses multiple processes and download multiple files""" if disallowed_header_directives is None: @@ -127,7 +128,9 @@ def make_path_absolute(path): output_folder = make_path_absolute(output_folder) url_list = make_path_absolute(url_list) - logger_process = LoggerProcess(output_folder, enable_wandb, wandb_project, config_parameters) + logger_process = LoggerProcess( + output_folder, enable_wandb, wandb_project, config_parameters, wandb_job_name=wandb_job_name + ) tmp_path = output_folder + "/_tmp" fs, tmp_dir = fsspec.core.url_to_fs(tmp_path)