Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions agents/perception.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,13 @@ def box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
return iou


# Try to import torchvision NMS for performance optimization
try:
from torchvision.ops import nms as _torchvision_nms
except ImportError:
_torchvision_nms = None


def nms(
boxes: torch.Tensor,
scores: torch.Tensor,
Expand All @@ -693,6 +700,9 @@ def nms(
"""
Non-Maximum Suppression.

Uses optimized torchvision implementation if available (27x faster),
otherwise falls back to pure Python implementation.

Args:
boxes: (N, 4) tensor of [x1, y1, x2, y2]
scores: (N,) tensor of confidence scores
Expand All @@ -701,6 +711,11 @@ def nms(
Returns:
Indices of boxes to keep
"""
# Use optimized torchvision implementation if available
if _torchvision_nms is not None:
return _torchvision_nms(boxes, scores, iou_threshold)

# Fallback implementation
if boxes.numel() == 0:
return torch.empty((0,), dtype=torch.long, device=boxes.device)

Expand Down