Skip to content

Commit f6fbe7c

Browse files
committed
feat :: test ai model
1 parent 3b1dad7 commit f6fbe7c

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

scripts/create_test_model.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#!/usr/bin/env python3
2+
"""Create a simple ONNX model for testing airML."""
3+
4+
import torch
5+
import torch.nn as nn
6+
import os
7+
8+
class SimpleModel(nn.Module):
9+
"""Simple CNN for image classification (10 classes)."""
10+
def __init__(self):
11+
super().__init__()
12+
self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
13+
self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
14+
self.pool = nn.AdaptiveAvgPool2d(1)
15+
self.fc = nn.Linear(32, 10)
16+
17+
def forward(self, x):
18+
x = torch.relu(self.conv1(x))
19+
x = torch.relu(self.conv2(x))
20+
x = self.pool(x)
21+
x = x.view(x.size(0), -1)
22+
x = self.fc(x)
23+
return x
24+
25+
def main():
26+
# Create model
27+
model = SimpleModel()
28+
model.eval()
29+
30+
# Dummy input (batch=1, channels=3, height=224, width=224)
31+
dummy_input = torch.randn(1, 3, 224, 224)
32+
33+
# Export to ONNX using dynamo=False (legacy export, single file)
34+
output_path = "models/simple_cnn.onnx"
35+
36+
# Remove old files
37+
for f in [output_path, output_path + ".data"]:
38+
if os.path.exists(f):
39+
os.remove(f)
40+
41+
torch.onnx.export(
42+
model,
43+
dummy_input,
44+
output_path,
45+
input_names=["input"],
46+
output_names=["output"],
47+
opset_version=13,
48+
dynamo=False, # Use legacy export (single file)
49+
)
50+
51+
print(f"Model saved to {output_path}")
52+
print(f"Input shape: [1, 3, 224, 224]")
53+
print(f"Output shape: [1, 10]")
54+
print(f"File size: {os.path.getsize(output_path) / 1024:.1f} KB")
55+
56+
if __name__ == "__main__":
57+
main()

0 commit comments

Comments
 (0)