-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprocess_batch.py
More file actions
217 lines (177 loc) · 7.22 KB
/
process_batch.py
File metadata and controls
217 lines (177 loc) · 7.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""
Process batch data from folder structure: batch/<node>/<slot>.json
Output: output/<date>/<slot>.npy containing [response_of_node1, response_of_node2, ...]
"""
import numpy as np
from pathlib import Path
from ultralytics import YOLO
from models.request import Request
from models.response import Response
import argparse
import json
from collections import defaultdict
from tqdm import tqdm
from multiprocessing import Pool, cpu_count, Manager
from datetime import datetime, timedelta
def get_prev_vehicles_count(
    prev_counts, slot_time: str, camera_id: str, output_dir: str, date: str
):
    """
    Look up a fallback vehicles_count for a camera when the current slot
    has no usable data.

    Strategy:
      1. Scan the in-memory shared store (keys are (camera_id, slot_time)
         tuples) for earlier slots of the same day and use the latest one.
      2. Otherwise, walk the previous day's saved .npy outputs from newest
         to oldest and return the first count recorded for this camera.

    Returns the count, or None when no history is available at all.
    """
    # --- 1) Same-day history from the shared store ----------------------
    best_slot = None
    best_count = None
    for (cam, slot), count in prev_counts.items():
        if cam != camera_id or slot >= slot_time:
            continue
        if best_slot is None or slot > best_slot:
            best_slot, best_count = slot, count
    if best_slot is not None:
        return best_count

    # --- 2) Cross-day fallback: previous day's outputs on disk ---------
    try:
        current_day = datetime.strptime(date, "%Y-%m-%d").date()
    except ValueError:
        # Unexpected date format: cannot compute the previous day.
        return None

    previous_dir = Path(output_dir) / (current_day - timedelta(days=1)).isoformat()
    if not previous_dir.exists():
        return None

    # Slot files are named <slot_time>.npy, so lexical order is time order.
    for slot_file in sorted(previous_dir.glob("*.npy"), reverse=True):
        try:
            for record in np.load(slot_file, allow_pickle=True):
                if record.get("camera_id") == camera_id:
                    return record.get("vehicles_count")
        except Exception:
            # Corrupt/unreadable file: keep walking further back in time.
            continue
    return None
def process_slot(args):
    """Process a single slot across all nodes.

    Args:
        args: Tuple of (slot_time, node_files, model_path, output_dir, date,
            prev_counts) — packed into one tuple because Pool.imap passes a
            single argument. node_files is a list of (node_id, slot_file Path)
            pairs; prev_counts is a Manager-backed dict shared across workers.

    Returns:
        slot_time, whether or not any responses were saved (keeps the
        parent progress bar advancing).
    """
    slot_time, node_files, model_path, output_dir, date, prev_counts = args

    # Each worker process loads its own model instance — YOLO objects are
    # not shareable across processes.
    model = YOLO(model_path, verbose=False)
    classes = [1, 2, 3, 5, 7]  # vehicle classes — presumably COCO ids; TODO confirm

    responses = []
    for node_id, slot_file in node_files:
        try:
            # Load request from JSON file
            with open(slot_file, "r") as f:
                req = Request(**json.load(f))

            image_refs = [frame.image_ref for frame in req.frames]
            if not image_refs:
                # No images: reuse the most recent count known for this
                # camera (same-day store first, then previous day on disk).
                vehicles_count = get_prev_vehicles_count(
                    prev_counts=prev_counts,
                    slot_time=slot_time,
                    camera_id=req.camera_id,
                    output_dir=output_dir,
                    date=date,
                )
                if vehicles_count is None:
                    # Nothing to fall back to — skip this node entirely.
                    continue
            else:
                # Run YOLO on all frames and total the detected boxes.
                results = model(image_refs, classes=classes, verbose=False)
                vehicles_count = sum(len(r.boxes) for r in results)

            # Record the count so later slots (and other workers) can use
            # it as a fallback.
            prev_counts[(req.camera_id, slot_time)] = vehicles_count

            response = Response(
                camera_id=req.camera_id,
                slot=req.slot,
                generated_at=req.generated_at,
                duration_sec=req.duration_sec,
                vehicles_count=vehicles_count,
                speed=req.speed,
                weather=req.weather,
            )
            responses.append(response.model_dump())
        except Exception as e:
            # Best-effort batch processing: log and move on to the next node.
            print(f"\nError processing {node_id}/{slot_file.name}: {e}")
            continue

    # Save all node responses for this slot
    if responses:
        output_date_dir = Path(output_dir) / date
        # Fix: create the date folder here too — process_batch pre-creates
        # it, but np.save would otherwise fail when this function is called
        # standalone (e.g. from a test or a retry script).
        output_date_dir.mkdir(parents=True, exist_ok=True)
        np.save(output_date_dir / f"{slot_time}.npy", responses, allow_pickle=True)
    return slot_time
def process_batch(
    date: str,
    batch_dir: str = "batch",
    output_dir: str = "output",
    model_path: str = "yolo11n.pt",
    workers: int = None,
):
    """Process all slots for a given date across all nodes.

    Reads batch/<node>/batch_<slot_time>.json files, runs detection per
    slot in parallel worker processes, and writes one
    output/<date>/<slot_time>.npy file per slot containing the list of
    per-node response dicts.

    Args:
        date: Date to process, formatted YYYY-MM-DD.
        batch_dir: Root directory holding one subdirectory per node.
        output_dir: Root directory for per-date result folders.
        model_path: Path to the YOLO weights each worker loads.
        workers: Worker-process count; defaults to min(cpu_count, #slots).
    """
    batch_path = Path(batch_dir)
    # Fix: guard against a missing batch root — iterdir() would raise
    # FileNotFoundError instead of printing the friendly error.
    if not batch_path.is_dir():
        print(f"Error: No node directories found in {batch_dir}")
        return
    node_dirs = [d for d in batch_path.iterdir() if d.is_dir()]
    if not node_dirs:
        print(f"Error: No node directories found in {batch_dir}")
        return
    print(f"Found {len(node_dirs)} nodes")

    # Collect all slots grouped by slot timestamp
    slots_by_time = defaultdict(list)  # {slot_time: [(node, file_path), ...]}
    prefix = "batch_"
    for node_dir in node_dirs:
        node_id = node_dir.name
        # Find slot files for this date
        for slot_file in node_dir.glob(f"{prefix}{date}*.json"):
            # Extract slot time from filename:
            # batch_2025-11-12T11:10:00.json -> 2025-11-12T11:10:00
            # Fix: strip only the leading prefix (str.replace would remove
            # every occurrence of "batch_" inside the name).
            slot_time = slot_file.stem[len(prefix):]
            slots_by_time[slot_time].append((node_id, slot_file))

    if not slots_by_time:
        print(f"No slots found for date {date}")
        return
    print(f"Found {len(slots_by_time)} unique slots")

    # Create output directory
    output_date_dir = Path(output_dir) / date
    output_date_dir.mkdir(parents=True, exist_ok=True)

    # Shared store for previous vehicles_count values across workers
    manager = Manager()
    prev_counts = manager.dict()

    # Prepare arguments for parallel processing (sorted so earlier slots
    # tend to populate prev_counts before later ones need them)
    tasks = [
        (slot_time, node_files, model_path, output_dir, date, prev_counts)
        for slot_time, node_files in sorted(slots_by_time.items())
    ]

    # Determine number of workers
    if workers is None:
        workers = min(cpu_count(), len(tasks))
    print(f"Processing with {workers} workers")

    # Process slots in parallel
    with Pool(processes=workers) as pool:
        list(
            tqdm(
                pool.imap(process_slot, tasks),
                total=len(tasks),
                desc=f"Processing {date}",
            )
        )
    print(f"\nResults saved to: {output_dir}/{date}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process batch data by date")
parser.add_argument("--date", required=True, help="Date to process (YYYY-MM-DD)")
parser.add_argument("--batch-dir", default="batch", help="Batch input directory")
parser.add_argument("--output-dir", default="output", help="Output directory")
parser.add_argument("--model", default="yolo11n.pt", help="YOLO model path")
parser.add_argument(
"--workers",
type=int,
default=None,
help="Number of parallel workers (default: auto)",
)
args = parser.parse_args()
process_batch(args.date, args.batch_dir, args.output_dir, args.model, args.workers)