Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ __pycache__/
build/
develop-eggs/
dist/

# Keep visdet runtime dist package tracked
!visdet/engine/dist/
!visdet/engine/dist/**
visdet/engine/dist/__pycache__/

downloads/
eggs/
.eggs/
Expand Down
40 changes: 40 additions & 0 deletions scripts/train_auto_ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/usr/bin/env python3
"""Train with automatic single-node DDP (no torchrun).

This is a thin wrapper around `visdet.engine.runner.auto_train.auto_train`.

Usage:
python scripts/train_auto_ddp.py path/to/config.py

Notes:
- The config must be compatible with `visdet.engine.runner.Runner.from_cfg()`.
- If multiple GPUs are available, one worker process is spawned per GPU.
"""

import argparse
import functools

from visdet.engine.config import Config
from visdet.engine.runner import auto_train

# Module-level fallback so legacy callers that set this global directly keep
# working. Only consulted when no explicit path is bound into the builder.
_CONFIG_PATH: str | None = None


def _config_builder(
    _rank: int, _world_size: int, config_path: str | None = None
) -> tuple[Config, dict]:
    """Build the runner config for one DDP worker.

    Args:
        _rank: Worker rank (unused; the config is rank-independent).
        _world_size: Total worker count (unused).
        config_path: Explicit path to the config file. Falls back to the
            module-level ``_CONFIG_PATH`` for backward compatibility.

    Returns:
        Tuple of the parsed ``Config`` and a metadata dict recording the
        config path.
    """
    path = config_path if config_path is not None else _CONFIG_PATH
    assert path is not None, "no config path was provided to the builder"
    cfg = Config.fromfile(path)
    return cfg, {"config": path}


def main() -> None:
    """Parse CLI arguments and launch auto-DDP training."""
    parser = argparse.ArgumentParser(description="visdet auto-DDP training")
    parser.add_argument("config", help="Path to a Runner.from_cfg config")
    args = parser.parse_args()

    # Keep the global assignment for backward compatibility, but ALSO bind the
    # path into the builder explicitly: module globals assigned here are not
    # visible in workers created with the "spawn" start method, which
    # re-import this module with _CONFIG_PATH reset to None. A partial of a
    # module-level function remains picklable, so it survives spawn.
    global _CONFIG_PATH
    _CONFIG_PATH = args.config

    auto_train(functools.partial(_config_builder, config_path=args.config))


if __name__ == "__main__":
    main()
48 changes: 48 additions & 0 deletions tests/test_runtime/test_dist_env_fallback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os

from visdet.engine import dist


def test_get_rank_world_size_env_fallback(monkeypatch):
    """RANK/WORLD_SIZE env vars back the accessors when no group exists."""
    # Sanity: this unit test must run without an initialized process group.
    assert not dist.is_distributed()

    monkeypatch.setenv("WORLD_SIZE", "8")
    monkeypatch.setenv("RANK", "3")

    rank, world = dist.get_dist_info()
    assert (rank, world) == (3, 8)
    assert dist.get_rank() == rank
    assert dist.get_world_size() == world


def test_infer_launcher_env(monkeypatch):
    """infer_launcher picks the launcher from well-known env markers."""
    cases = [
        ("WORLD_SIZE", "2", "pytorch"),
        ("SLURM_NTASKS", "2", "slurm"),
        ("OMPI_COMM_WORLD_LOCAL_RANK", "0", "mpi"),
    ]
    previous = None
    for var, value, expected in cases:
        # Clear the previous marker so each probe sees exactly one hint.
        if previous is not None:
            monkeypatch.delenv(previous, raising=False)
        monkeypatch.setenv(var, value)
        assert dist.infer_launcher() == expected
        previous = var

    monkeypatch.delenv(previous, raising=False)
    assert dist.infer_launcher() == "none"


def test_master_only_decorator(monkeypatch):
    """master_only skips the call on worker ranks and runs it on rank 0."""
    calls = []

    @dist.master_only
    def _record():
        calls.append(True)

    monkeypatch.setenv("RANK", "1")
    _record()
    assert calls == []  # suppressed on a non-zero rank

    monkeypatch.setenv("RANK", "0")
    _record()
    assert calls == [True]  # executed on the master rank
288 changes: 72 additions & 216 deletions visdet/engine/dist/__init__.py
Original file line number Diff line number Diff line change
@@ -1,221 +1,77 @@
# ruff: noqa
# type: ignore
# Copyright (c) OpenMMLab. All rights reserved.
"""Distributed utilities for visdet."""

import functools
import os
import pickle
import warnings
from typing import Any, List, Optional

import torch
import torch.distributed as dist_lib


def _is_dist_available_and_initialized():
"""Check if distributed training is available and initialized."""
return dist_lib.is_available() and dist_lib.is_initialized()


def get_dist_info():
    """Return ``(rank, world_size)``; ``(0, 1)`` when not distributed.

    Returns:
        tuple: rank, world_size
    """
    if not _is_dist_available_and_initialized():
        return 0, 1
    return dist_lib.get_rank(), dist_lib.get_world_size()


def get_rank():
    """Rank of the current process; 0 when not distributed."""
    return dist_lib.get_rank() if _is_dist_available_and_initialized() else 0


def get_world_size():
    """Total number of processes; 1 when not distributed."""
    return dist_lib.get_world_size() if _is_dist_available_and_initialized() else 1


def is_distributed():
    """Check whether a default process group has been initialized."""
    # Inlined check: distributed support must be compiled in AND a group set up.
    return dist_lib.is_available() and dist_lib.is_initialized()


def is_main_process():
    """True on rank 0 (always True in non-distributed runs)."""
    return not get_rank()


def master_only(func):
    """Restrict *func* to rank 0; calls on other ranks return ``None``."""
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        if not is_main_process():
            return None
        return func(*args, **kwargs)
    return wrapped


def barrier():
    """Block until every rank reaches this point; no-op when not distributed."""
    if not _is_dist_available_and_initialized():
        return
    dist_lib.barrier()


def broadcast(data: Any, src: int = 0, group: Any | None = None) -> Any:
    """Broadcast data from src rank to all ranks.

    Tensors are broadcast in place and returned as-is. Non-tensor data is
    round-tripped through ``torch.tensor``: 0-dim results come back as a
    Python scalar via ``.item()``, anything else comes back as a tensor —
    NOT the caller's original Python type.
    """
    # Non-distributed run: nothing to synchronize, hand the data back.
    if not _is_dist_available_and_initialized():
        return data

    if isinstance(data, torch.Tensor):
        dist_lib.broadcast(data, src, group=group)
        return data
    else:
        # For non-tensor data, convert to tensor, broadcast, then convert back
        # NOTE(review): non-src ranks build the receive buffer from their own
        # local `data` via zeros_like, so every rank must already hold a
        # same-shaped/same-dtype value — confirm callers guarantee this, or
        # prefer broadcast_object_list for arbitrary objects.
        if get_rank() == src:
            data_tensor = torch.tensor(data, device='cuda' if torch.cuda.is_available() else 'cpu')
        else:
            data_tensor = torch.zeros_like(torch.tensor(data, device='cuda' if torch.cuda.is_available() else 'cpu'))
        dist_lib.broadcast(data_tensor, src, group=group)
        return data_tensor.item() if data_tensor.dim() == 0 else data_tensor


def broadcast_object_list(obj_list, src=0, group=None):
    """Broadcast a list of picklable objects from *src*, in place.

    Returns the (possibly updated) list; untouched when not distributed.
    """
    if _is_dist_available_and_initialized():
        dist_lib.broadcast_object_list(obj_list, src, group=group)
    return obj_list


def all_reduce_params(model):
    """Sum gradients across all ranks and average them by world size."""
    if not _is_dist_available_and_initialized():
        return

    divisor = get_world_size()
    trainable_grads = (
        p.grad.data
        for p in model.parameters()
        if p.requires_grad and p.grad is not None
    )
    for grad in trainable_grads:
        dist_lib.all_reduce(grad)
        grad.div_(divisor)


def init_dist(launcher: str = "pytorch", backend: str = "nccl", **kwargs: Any) -> tuple[int, int]:
    """Initialize the default process group and return ``(rank, world_size)``.

    Idempotent: if a group is already up, its info is returned unchanged.
    """
    if not _is_dist_available_and_initialized():
        if launcher != 'pytorch':
            raise NotImplementedError(f'Launcher {launcher} is not supported')
        dist_lib.init_process_group(backend=backend, **kwargs)

    return get_dist_info()


def collect_results(result_part, size, tmpdir=None):
    """Collect results from all ranks and merge them on rank 0.

    Each rank pickles its partial results to a file under ``tmpdir``; after a
    barrier, rank 0 loads every part file, deletes it, flattens list parts and
    truncates the merged list to ``size``.

    Args:
        result_part: This rank's results (a list, or a single object).
        size: Expected total number of results. Used to drop padding appended
            by samplers that duplicate tail samples so batches divide evenly.
            Pass ``None`` to keep everything.
        tmpdir: Scratch directory for the part files; it must be visible to
            all ranks (shared filesystem). Defaults to the current working
            directory. Created if missing.

    Returns:
        Merged list on rank 0; ``None`` on every other rank. In
        non-distributed runs, ``result_part`` is returned unchanged.
    """
    rank, world_size = get_dist_info()

    # Non-distributed mode: just return the results directly
    if world_size == 1:
        return result_part

    if tmpdir is None:
        tmpdir = '.'
    else:
        # Tolerate a caller-supplied directory that does not exist yet.
        os.makedirs(tmpdir, exist_ok=True)

    # Dump this rank's part where rank 0 can find it.
    result_file = os.path.join(tmpdir, f'result_rank_{rank}.pkl')
    with open(result_file, 'wb') as f:
        pickle.dump(result_part, f)

    # Ensure every part file is fully written before rank 0 starts reading.
    dist_lib.barrier()

    if rank != 0:
        return None

    results = []
    for i in range(world_size):
        part_file = os.path.join(tmpdir, f'result_rank_{i}.pkl')
        with open(part_file, 'rb') as f:
            results.append(pickle.load(f))
        # Clean up each part as soon as it has been consumed.
        os.remove(part_file)

    # Merge results (flatten list parts; keep scalar parts as single entries).
    merged_results = []
    for result in results:
        if isinstance(result, list):
            merged_results.extend(result)
        else:
            merged_results.append(result)

    # Honor `size`: previously the parameter was accepted but silently
    # ignored, so duplicated tail samples were never dropped.
    if size is not None:
        merged_results = merged_results[:size]
    return merged_results


def sync_random_seed(seed=None, device='cuda'):
    """Broadcast a seed from rank 0 so all ranks share the same value.

    All workers must call this function, otherwise it will deadlock.
    """
    import numpy as np

    if seed is None:
        seed = np.random.randint(2**31)

    rank, world_size = get_dist_info()

    # Single process or no group: nothing to synchronize.
    if world_size == 1 or not _is_dist_available_and_initialized():
        return seed

    # Rank 0 contributes its seed; everyone else supplies a placeholder.
    payload = seed if rank == 0 else 0
    seed_tensor = torch.tensor(payload, dtype=torch.int32, device=device)
    dist_lib.broadcast(seed_tensor, src=0)
    return seed_tensor.item()


def infer_launcher():
    """Infer the launcher from env vars: 'pytorch' when torchrun-style
    RANK/WORLD_SIZE are both set, otherwise ``None``."""
    env = os.environ
    torchrun_markers = ('RANK', 'WORLD_SIZE')
    if all(marker in env for marker in torchrun_markers):
        return 'pytorch'
    return None


# Backward compatibility: some callers still import the decorator via the old
# `dist.utils.master_only` path, so expose a throwaway namespace object that
# mimics the former `utils` submodule.
utils = type('utils', (), {'master_only': master_only})()
from visdet.engine.dist.dist import (
all_gather,
all_gather_object,
all_reduce,
all_reduce_dict,
all_reduce_params,
broadcast,
broadcast_object_list,
collect_results,
collect_results_cpu,
collect_results_gpu,
gather,
gather_object,
sync_random_seed,
)
from visdet.engine.dist.dist_utils import broadcast_from_rank_0, rank_0_only, rank_0_only_method
from visdet.engine.dist.utils import (
barrier,
cast_data_device,
get_backend,
get_comm_device,
get_data_device,
get_default_group,
get_dist_info,
get_local_group,
get_local_rank,
get_local_size,
get_rank,
get_world_size,
infer_launcher,
init_dist,
init_local_group,
is_distributed,
is_main_process,
master_only,
)

# Public API of the dist package. The previous list was a merge artifact:
# every single-quoted entry was duplicated by a double-quoted one. Kept as a
# single sorted, de-duplicated list with one quoting style.
__all__ = [
    "all_gather",
    "all_gather_object",
    "all_reduce",
    "all_reduce_dict",
    "all_reduce_params",
    "barrier",
    "broadcast",
    "broadcast_from_rank_0",
    "broadcast_object_list",
    "cast_data_device",
    "collect_results",
    "collect_results_cpu",
    "collect_results_gpu",
    "gather",
    "gather_object",
    "get_backend",
    "get_comm_device",
    "get_data_device",
    "get_default_group",
    "get_dist_info",
    "get_local_group",
    "get_local_rank",
    "get_local_size",
    "get_rank",
    "get_world_size",
    "infer_launcher",
    "init_dist",
    "init_local_group",
    "is_distributed",
    "is_main_process",
    "master_only",
    "rank_0_only",
    "rank_0_only_method",
    "sync_random_seed",
]
Loading
Loading