-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathconvert_backbone_tensorrt.py
More file actions
409 lines (314 loc) · 12.6 KB
/
convert_backbone_tensorrt.py
File metadata and controls
409 lines (314 loc) · 12.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
#!/usr/bin/env python3
"""
Convert DINOv3 Backbone to TensorRT.
Usage:
python convert_backbone_tensorrt.py --export_onnx
python convert_backbone_tensorrt.py --convert_trt
python convert_backbone_tensorrt.py --benchmark
python convert_backbone_tensorrt.py --all
Backbone interface (the TensorRT engine built by this script uses FP16 input/output):
Input: [B, 3, 512, 512] RGB image (normalized)
Output: [B, 1280, 32, 32] feature map
Instructions:
Step 1: Export and convert to TensorRT
# All-in-one: export ONNX + convert TensorRT + benchmark
python convert_backbone_tensorrt.py --all
# Or run steps individually:
python convert_backbone_tensorrt.py --export_onnx # Export ONNX
python convert_backbone_tensorrt.py --convert_trt # Convert to TensorRT
python convert_backbone_tensorrt.py --benchmark # Performance comparison
Step 2: Run inference with TensorRT
# Set environment variable to enable TensorRT backbone
USE_TRT_BACKBONE=1 python profile_nsight.py --image_path ./notebook/images/dancing.jpg --detector yolo --detector_model ./checkpoints/yolo/yolo11m.engine
# Or specify a custom engine path
USE_TRT_BACKBONE=1 TRT_BACKBONE_PATH=/path/to/backbone.engine python demo_human.py ...
"""
import argparse
import os
import sys
import time
import torch
import torch.nn as nn
# Put this script's directory at the front of sys.path so project packages
# located next to it (presumably ``sam_3d_body``, imported lazily by
# load_backbone) can be imported when the script is run from elsewhere.
parent_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, parent_dir)

# ---- Artifact locations ---------------------------------------------------
CHECKPOINT_DIR = os.path.join(parent_dir, "checkpoints", "sam-3d-body-dinov3")
TRT_OUTPUT_DIR = os.path.join(CHECKPOINT_DIR, "backbone_trt")
ONNX_PATH = os.path.join(TRT_OUTPUT_DIR, "backbone_dinov3.onnx")
TRT_PATH_BF16, TRT_PATH_FP16 = (
    os.path.join(TRT_OUTPUT_DIR, "backbone_dinov3_bf16.engine"),
    os.path.join(TRT_OUTPUT_DIR, "backbone_dinov3_fp16.engine"),
)
# Engine written by the conversion step and loaded by the benchmark; FP16
# is chosen because TensorRT optimizes it better than BF16.
TRT_PATH = TRT_PATH_FP16

# ---- Model geometry (dinov3_vith16plus) -----------------------------------
IMAGE_SIZE = (512, 512)  # (H, W) of the network input
EMBED_DIM = 1280  # channel count of the output feature map
PATCH_SIZE = 16
# One output cell per PATCH_SIZE x PATCH_SIZE input patch -> (32, 32).
OUTPUT_SIZE = tuple(side // PATCH_SIZE for side in IMAGE_SIZE)
class BackboneWrapper(nn.Module):
    """Give the DINOv3 encoder a single-tensor ``forward`` for ONNX export.

    ONNX export traces ``forward``. This module turns the encoder's
    ``get_intermediate_layers`` call into that plain ``tensor -> tensor``
    interface.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, x):
        """Return the last feature map from ``encoder.get_intermediate_layers``.

        Args:
            x: input image batch; per the module docstring, a normalized
               [B, 3, H, W] RGB tensor.

        Returns:
            The final element of the list returned by
            ``get_intermediate_layers(x, n=1, reshape=True, norm=True)``,
            i.e. a [B, C, H', W'] feature map.
        """
        feature_maps = self.encoder.get_intermediate_layers(
            x,
            n=1,
            reshape=True,
            norm=True,
        )
        last_map = feature_maps[-1]
        return last_map
def load_backbone():
    """Return the eval-mode backbone of the SAM-3D-Body checkpoint in CHECKPOINT_DIR.

    Builds the full model from ``model.ckpt`` and the MHR asset
    ``assets/mhr_model.pt``, keeps only ``model.backbone``, and prints a few
    backbone properties as a sanity check.
    """
    # Deferred import: the project package is only needed once a step
    # actually has to load weights.
    from sam_3d_body.build_models import load_sam_3d_body

    weights_file = os.path.join(CHECKPOINT_DIR, "model.ckpt")
    mhr_asset_file = os.path.join(CHECKPOINT_DIR, "assets", "mhr_model.pt")
    full_model, cfg = load_sam_3d_body(
        checkpoint_path=weights_file,
        mhr_path=mhr_asset_file,
    )

    backbone = full_model.backbone
    backbone.eval()  # switch to inference mode for export / benchmarking

    print(f"Backbone type: {cfg.MODEL.BACKBONE.TYPE}")
    print(f"Embed dim: {backbone.embed_dim}")
    print(f"Patch size: {backbone.patch_size}")
    return backbone
def step1_export_onnx(backbone, batch_sizes=(1, 2, 4)):
    """Export the backbone's encoder to ONNX with a dynamic batch dimension.

    The encoder is wrapped in BackboneWrapper, then cast to FP32 and moved to
    CUDA. ``nn.Module.float()`` and ``.cuda()`` act in place, so
    ``backbone.encoder`` stays FP32 on CUDA afterwards. It is traced with a
    batch-1 dummy input and written to ONNX_PATH. TensorRT does the FP16
    conversion later, when the engine is built.

    Args:
        backbone: loaded backbone exposing an ``encoder`` attribute.
        batch_sizes: not used by the export itself, because the batch axis
            is exported as dynamic; the concrete range is fixed when the
            TensorRT engine is built. Accepted for call compatibility.

    Returns:
        True once the ONNX file has been written. Export errors propagate.
    """
    print("=" * 60)
    print("Step 1: Export Backbone to ONNX")
    print("=" * 60)

    os.makedirs(TRT_OUTPUT_DIR, exist_ok=True)
    print(f" Output directory: {TRT_OUTPUT_DIR}")

    # Export in FP32; TensorRT handles the FP16 conversion during the build.
    wrapper = BackboneWrapper(backbone.encoder)
    wrapper.eval()
    wrapper.float()
    wrapper.cuda()

    dummy_input = torch.randn(1, 3, *IMAGE_SIZE, device="cuda", dtype=torch.float32)

    # Run the wrapper once so shape problems surface before the (slow) export.
    with torch.no_grad():
        output = wrapper(dummy_input)
    print(f" Input shape: {dummy_input.shape}")
    print(f" Output shape: {output.shape}")

    print(" Exporting to ONNX...")
    # Only the batch axis is dynamic; spatial size is fixed to IMAGE_SIZE.
    dynamic_axes = {
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    }
    torch.onnx.export(
        wrapper,
        dummy_input,
        ONNX_PATH,
        input_names=["input"],
        output_names=["output"],
        opset_version=17,
        do_constant_folding=True,
        dynamic_axes=dynamic_axes,
    )
    print(f" [SUCCESS] Saved to: {ONNX_PATH}")
    # NOTE: reports the .onnx graph file only. For very large models the
    # exporter may place weights in external-data files next to it.
    print(f" File size: {os.path.getsize(ONNX_PATH) / 1024 / 1024:.1f} MB")

    # Best-effort structural check: failures are reported, never fatal.
    # Pass the *path*, not a loaded ModelProto. The in-memory form rejects
    # models over 2 GB, which this FP32 ViT-H+ export is likely to exceed,
    # and loading it would pull all the weights into host memory.
    try:
        import onnx
    except ImportError:
        print(" Warning: onnx not installed, skipping ONNX verification")
    else:
        try:
            onnx.checker.check_model(ONNX_PATH)
            print(" ONNX model verified!")
        except Exception as e:
            print(f" Warning: ONNX verification failed: {e}")
    return True
def step2_convert_tensorrt(batch_sizes=(1, 2, 4)):
    """Build an FP16 TensorRT engine from ONNX_PATH and save it to TRT_PATH.

    The engine uses FP16 for compute and for its input/output tensors. It has
    one optimization profile whose batch dimension spans the smallest to the
    largest requested size. The optimum is the middle entry of the sorted,
    de-duplicated sizes.

    Args:
        batch_sizes: iterable of positive batch sizes the engine must accept.

    Returns:
        True if the engine was written. False on any reported failure:
        invalid batch sizes, TensorRT not installed, ONNX file missing or
        unparseable, or engine build failure.
    """
    print("\n" + "=" * 60)
    print("Step 2: Convert to TensorRT (FP16)")
    print("=" * 60)

    # Validate up front: min()/max() of an empty list would raise, and a
    # non-positive batch would only fail deep inside the engine build.
    sizes = sorted(set(batch_sizes))
    if not sizes or sizes[0] < 1:
        print(f" [ERROR] Invalid batch_sizes {list(batch_sizes)}: need at least one positive integer")
        return False

    try:
        import tensorrt as trt
    except ImportError:
        print(" [ERROR] TensorRT not installed")
        print(" Install with: pip install tensorrt")
        return False

    if not os.path.exists(ONNX_PATH):
        print(f" [ERROR] ONNX file not found: {ONNX_PATH} (run --export_onnx first)")
        return False

    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    # EXPLICIT_BATCH is deprecated (explicit batch is the default) in newer
    # TensorRT releases; only pass it while the enum member still exists.
    explicit_batch = getattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None)
    network_flags = 0 if explicit_batch is None else 1 << int(explicit_batch)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)

    print(" Parsing ONNX...")
    print(f" ONNX file: {ONNX_PATH}")
    # parse_from_file (unlike parse(bytes)) can resolve external-data tensors
    # stored alongside the .onnx file.
    if not parser.parse_from_file(ONNX_PATH):
        for i in range(parser.num_errors):
            print(f" [ERROR] {parser.get_error(i)}")
        return False

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    print(f" Network inputs: {network.num_inputs}")
    for tensor in inputs:
        print(f" {tensor.name}: {tensor.shape}")
    print(f" Network outputs: {network.num_outputs}")
    for tensor in outputs:
        print(f" {tensor.name}: {tensor.shape}")

    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4 GiB
    # FP16 kernels for internal compute (better optimized than BF16 in TensorRT).
    config.set_flag(trt.BuilderFlag.FP16)
    # Make the engine's I/O tensors FP16 as well, so callers exchange
    # half-precision buffers with it.
    for tensor in inputs + outputs:
        tensor.dtype = trt.float16
    print(" Using FP16 precision (compute + I/O)")

    min_batch, opt_batch, max_batch = sizes[0], sizes[len(sizes) // 2], sizes[-1]
    chw = (3, IMAGE_SIZE[0], IMAGE_SIZE[1])
    profile = builder.create_optimization_profile()
    profile.set_shape(
        "input",
        (min_batch, *chw),  # min
        (opt_batch, *chw),  # opt
        (max_batch, *chw),  # max
    )
    config.add_optimization_profile(profile)
    print(f" Batch size range: [{min_batch}, {opt_batch}, {max_batch}]")

    print(" Building TensorRT engine (this may take several minutes)...")
    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        print(" [ERROR] Engine build failed")
        return False

    os.makedirs(os.path.dirname(TRT_PATH), exist_ok=True)
    with open(TRT_PATH, "wb") as f:
        f.write(serialized_engine)
    print(f" [SUCCESS] Saved to: {TRT_PATH}")
    print(f" File size: {os.path.getsize(TRT_PATH) / 1024 / 1024:.1f} MB")
    return True
class TRTBackbone:
    """Run the serialized backbone TensorRT engine on PyTorch CUDA tensors.

    Inference is enqueued on PyTorch's current CUDA stream for the input's
    device, so it is ordered with the surrounding PyTorch work on that stream.
    """

    def __init__(self, engine_path):
        """Deserialize the engine at ``engine_path`` and create an execution context.

        Raises:
            RuntimeError: if TensorRT cannot deserialize the engine or create
                an execution context.
            TypeError: if the engine's I/O tensors are not FP32/FP16/BF16.
        """
        import tensorrt as trt

        self.logger = trt.Logger(trt.Logger.WARNING)
        self.runtime = trt.Runtime(self.logger)
        with open(engine_path, "rb") as f:
            self.engine = self.runtime.deserialize_cuda_engine(f.read())
        # These TensorRT calls signal failure by returning None, not by raising.
        if self.engine is None:
            raise RuntimeError(f"Failed to deserialize TensorRT engine: {engine_path}")
        self.context = self.engine.create_execution_context()
        if self.context is None:
            raise RuntimeError(f"Failed to create TensorRT execution context: {engine_path}")

        # Tensor names fixed by the ONNX export (input_names / output_names).
        self.input_name = "input"
        self.output_name = "output"
        # Use the engine's declared I/O precision instead of assuming FP16, so
        # engines built with other I/O dtypes also get correctly typed buffers.
        self.input_dtype = self._to_torch_dtype(trt, self.engine.get_tensor_dtype(self.input_name))
        self.output_dtype = self._to_torch_dtype(trt, self.engine.get_tensor_dtype(self.output_name))

    @staticmethod
    def _to_torch_dtype(trt, trt_dtype):
        """Map a TensorRT floating-point DataType to the matching torch dtype."""
        candidates = [(trt.float32, torch.float32), (trt.float16, torch.float16)]
        if hasattr(trt, "bfloat16"):  # BF16 only exists in newer TensorRT releases
            candidates.append((trt.bfloat16, torch.bfloat16))
        for trt_type, torch_type in candidates:
            if trt_dtype == trt_type:
                return torch_type
        raise TypeError(f"Unsupported TensorRT I/O dtype: {trt_dtype}")

    def __call__(self, x):
        """Run TensorRT inference on ``x`` and return a newly allocated output.

        Args:
            x: [B, 3, H, W] CUDA tensor of any floating dtype. It is cast to
               the engine's input dtype (FP16 for engines built by this
               script) and made contiguous before being handed to TensorRT.

        Returns:
            [B, C, H', W'] feature map in the engine's output dtype (FP16 for
            engines built by this script). It is computed asynchronously on
            the current stream.

        Raises:
            ValueError: if ``x`` is not on a CUDA device.
            RuntimeError: if TensorRT rejects the input shape (e.g. a batch
                outside the optimization profile) or fails to enqueue.
        """
        if not x.is_cuda:
            raise ValueError("TRTBackbone expects a CUDA tensor")

        # TensorRT reads raw device memory, so the buffer must be dense and in
        # the engine's dtype. Both conversions are no-ops when already satisfied.
        # If a temporary is created, PyTorch's caching allocator only reuses its
        # memory for later work on this same stream, which runs after the enqueue.
        x_in = x.to(self.input_dtype).contiguous()

        in_shape = tuple(x_in.shape)
        if not self.context.set_input_shape(self.input_name, in_shape):
            raise RuntimeError(f"TensorRT rejected input shape {in_shape}")

        # Output shape resolved for this batch, e.g. [B, 1280, 32, 32] with the
        # default configuration described in the module docstring.
        out_shape = tuple(self.context.get_tensor_shape(self.output_name))
        output = torch.empty(out_shape, device=x_in.device, dtype=self.output_dtype)

        self.context.set_tensor_address(self.input_name, x_in.data_ptr())
        self.context.set_tensor_address(self.output_name, output.data_ptr())

        stream = torch.cuda.current_stream(x_in.device)
        if not self.context.execute_async_v3(stream.cuda_stream):
            raise RuntimeError("TensorRT execute_async_v3 failed")
        return output
def _time_cuda_ms(fn, warmup, iters):
    """Return the mean wall-clock ms of ``fn()`` over ``iters`` calls, after ``warmup`` untimed calls."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    # Wait for queued GPU work so the measurement covers execution, not just launch.
    torch.cuda.synchronize()
    return (time.perf_counter() - start) * 1000 / iters


def step3_benchmark(backbone, batch_sizes=(1, 2), warmup=5, iters=50):
    """Compare the PyTorch backbone (BF16 autocast) with the TensorRT engine.

    For each batch size this prints the mean latency of the PyTorch path.
    When the engine at TRT_PATH exists, it also prints the engine latency,
    the speedup, and the max/mean absolute output difference. For that
    comparison PyTorch runs under FP16 autocast.

    Args:
        backbone: loaded backbone module, called as ``backbone(x)``.
        batch_sizes: batch sizes to benchmark. Each must be inside the
            engine's optimization profile.
        warmup: untimed calls before each measurement.
        iters: timed calls averaged per measurement.
    """
    print("\n" + "=" * 60)
    print("Step 3: Benchmark")
    print("=" * 60)

    # Deserialize the engine once; it serves every batch size.
    trt_backbone = TRTBackbone(TRT_PATH) if os.path.exists(TRT_PATH) else None

    for batch_size in batch_sizes:
        print(f"\n Batch size: {batch_size}")
        x = torch.randn(batch_size, 3, *IMAGE_SIZE, device="cuda", dtype=torch.float32)

        print(" [PyTorch BF16]")
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            pt_time = _time_cuda_ms(lambda inp=x: backbone(inp), warmup, iters)
        print(f" Time: {pt_time:.3f} ms/call")

        if trt_backbone is None:
            print(" [TensorRT] Engine not found, skipping...")
            continue

        # The engine built by step 2 runs FP16 compute and I/O.
        print(" [TensorRT FP16]")
        trt_time = _time_cuda_ms(lambda inp=x: trt_backbone(inp), warmup, iters)
        print(f" Time: {trt_time:.3f} ms/call")
        print(f" Speedup: {pt_time / trt_time:.2f}x")

        # Accuracy check at the engine's precision.
        # NOTE(review): assumes backbone(x) returns the same [B, C, H', W']
        # tensor that BackboneWrapper exported (encoder's last intermediate
        # layer) — confirm against the backbone's forward.
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
            pt_out = backbone(x)
            trt_out = trt_backbone(x)
        diff = (pt_out.float() - trt_out.float()).abs()
        print(f" Max diff: {diff.max().item():.6f}")
        print(f" Mean diff: {diff.mean().item():.6f}")
def main():
    """Command-line entry point: parse flags and run the requested steps in order.

    Steps run as export -> convert -> benchmark. If ONNX export or TensorRT
    conversion reports failure, the process exits with status 1 instead of
    continuing, so a stale engine from an earlier run is never benchmarked.
    A malformed ``--batch_sizes`` value is reported as an argparse usage
    error (exit status 2) rather than a traceback.
    """
    parser = argparse.ArgumentParser(description="Convert DINOv3 Backbone to TensorRT")
    parser.add_argument("--export_onnx", action="store_true", help="Export to ONNX")
    parser.add_argument("--convert_trt", action="store_true", help="Convert ONNX to TensorRT")
    parser.add_argument("--benchmark", action="store_true", help="Benchmark PyTorch vs TensorRT")
    parser.add_argument("--all", action="store_true", help="Run all steps")
    parser.add_argument("--batch_sizes", type=str, default="1,2,4", help="Batch sizes for optimization")
    args = parser.parse_args()

    # Tolerate stray spaces and empty entries ("1, 2,4,"), but reject anything
    # else here instead of failing later inside the TensorRT build.
    try:
        batch_sizes = [int(tok) for tok in args.batch_sizes.split(",") if tok.strip()]
    except ValueError:
        parser.error(f"--batch_sizes must be comma-separated integers, got {args.batch_sizes!r}")
    if not batch_sizes or min(batch_sizes) < 1:
        parser.error("--batch_sizes must contain at least one positive integer")

    if args.all:
        args.export_onnx = args.convert_trt = args.benchmark = True

    if not (args.export_onnx or args.convert_trt or args.benchmark):
        parser.print_help()
        return

    # Conversion alone works from the ONNX file; only export/benchmark need weights.
    backbone = None
    if args.export_onnx or args.benchmark:
        print("Loading backbone...")
        backbone = load_backbone()

    if args.export_onnx and not step1_export_onnx(backbone, batch_sizes):
        sys.exit(1)
    if args.convert_trt and not step2_convert_tensorrt(batch_sizes):
        sys.exit(1)
    if args.benchmark:
        step3_benchmark(backbone)


if __name__ == "__main__":
    main()