Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions img2dataset/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,22 @@ def write_stats(
class LoggerProcess(multiprocessing.context.SpawnProcess):
"""Logger process that reads stats files regularly, aggregates and send to wandb / print to terminal"""

def __init__(self, output_folder, enable_wandb, wandb_project, config_parameters, log_interval=5):
def __init__(
self,
output_folder,
enable_wandb,
wandb_project,
config_parameters,
log_interval=5,
wandb_job_name=None
):
super().__init__()
self.log_interval = log_interval
self.enable_wandb = enable_wandb
self.output_folder = output_folder
self.stats_files = set()
self.wandb_project = wandb_project
self.wandb_job_name = wandb_job_name
self.done_shards = set()
self.config_parameters = config_parameters
ctx = multiprocessing.get_context("spawn")
Expand All @@ -214,7 +223,12 @@ def run(self):
fs, output_path = fsspec.core.url_to_fs(self.output_folder, use_listings_cache=False)

if self.enable_wandb:
self.current_run = wandb.init(project=self.wandb_project, config=self.config_parameters, anonymous="allow")
self.current_run = wandb.init(
project=self.wandb_project,
name=self.wandb_job_name,
config=self.config_parameters,
anonymous="allow",
)
else:
self.current_run = None
self.total_speed_logger = SpeedLogger("total", enable_wandb=self.enable_wandb)
Expand Down
5 changes: 4 additions & 1 deletion img2dataset/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def download(
max_shard_retry: int = 1,
user_agent_token: Optional[str] = None,
disallowed_header_directives: Optional[List[str]] = None,
wandb_job_name: Optional[str] = None,
):
"""Download is the main entry point of img2dataset, it uses multiple processes and download multiple files"""
if disallowed_header_directives is None:
Expand All @@ -127,7 +128,9 @@ def make_path_absolute(path):
output_folder = make_path_absolute(output_folder)
url_list = make_path_absolute(url_list)

logger_process = LoggerProcess(output_folder, enable_wandb, wandb_project, config_parameters)
logger_process = LoggerProcess(
output_folder, enable_wandb, wandb_project, config_parameters, wandb_job_name=wandb_job_name
)

tmp_path = output_folder + "/_tmp"
fs, tmp_dir = fsspec.core.url_to_fs(tmp_path)
Expand Down