htil/onnx-model-conversion-demo

 
 

ONNX Framework Interoperability Demo

Proof of concept: Bidirectional model transfer between TensorFlow.js (JavaScript) and PyTorch (Python) using ONNX as the universal exchange format.

What This Proves

A model trained in one framework can be loaded and run in the other, with predictions that agree to within floating-point tolerance (maximum error below 1e-5). ONNX acts as the bridge between the JavaScript and Python ML ecosystems.

Direction 1:  TF.js (Node.js)  -->  weight JSON  -->  PyTorch  -->  ONNX file
Direction 2:  PyTorch           -->  ONNX file    -->  Node.js (onnxruntime-node)

Task

Simple aerospace regression: predict atmospheric density (kg/m³) from altitude (km) using the barometric formula with synthetic data.
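
As a rough illustration of the kind of data involved, the sketch below uses the isothermal form of the barometric formula, rho(h) = rho_0 * exp(-h / H). The sea-level density of 1.225 kg/m³ is the standard-atmosphere value; the 8.5 km scale height, the 0 to 50 km altitude range, and the function names are assumptions made for this example and may differ from what generate_data.py actually does (it may, for instance, add noise).

```python
import math

RHO_0 = 1.225          # sea-level density in kg/m^3 (standard atmosphere)
SCALE_HEIGHT_KM = 8.5  # assumed scale height; not taken from the repository


def density(altitude_km: float) -> float:
    """Isothermal barometric formula: density decays exponentially with altitude."""
    return RHO_0 * math.exp(-altitude_km / SCALE_HEIGHT_KM)


def make_samples(n: int, max_altitude_km: float = 50.0) -> list[tuple[float, float]]:
    """Return n evenly spaced (altitude_km, density) pairs from 0 to max_altitude_km."""
    step = max_altitude_km / (n - 1)
    return [(i * step, density(i * step)) for i in range(n)]


samples = make_samples(6)
for altitude, rho in samples:
    print(f"{altitude:5.1f} km  {rho:.4f} kg/m^3")
```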

Model architecture (identical in both frameworks):

Input(1) -> Dense(32, ReLU) -> Dense(16, ReLU) -> Dense(1)
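
For a sense of scale, the layer sizes above imply a very small network. The following arithmetic sketch counts trainable parameters from the layer widths alone (one weight per input-output pair plus one bias per output unit); the helper name is made up for this example.

```python
LAYER_SIZES = [1, 32, 16, 1]  # Input(1) -> Dense(32) -> Dense(16) -> Dense(1)


def count_parameters(sizes: list[int]) -> int:
    """Sum weights (fan_in * fan_out) and biases (fan_out) over consecutive layers."""
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in zip(sizes, sizes[1:]))


total = count_parameters(LAYER_SIZES)
print(total)  # 64 + 528 + 17 = 609
```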

Quick Start

1. Create a virtual environment (recommended)

cd C:\Users\Amon\Desktop\onnx_simple_demo
python -m venv onnx_env
onnx_env\Scripts\activate

The path and activation command above are for Windows; change the path to wherever you cloned the repository. On macOS or Linux, activate the environment with source onnx_env/bin/activate instead.

2. Install Python dependencies

pip install "numpy>=1.26.0"
pip install "torch>=2.2.0"
pip install onnx==1.20.1 --only-binary=onnx
pip install onnxruntime==1.24.1

Why --only-binary=onnx? This forces pip to use a pre-built wheel instead of compiling from C++ source. ONNX 1.20.1 ships abi3 wheels that cover Python 3.12+, including 3.13. Older versions lacked cp313 wheels and would attempt (and fail) a source build on Windows due to protobuf/CMake path-length issues.

3. Install Node.js dependencies

npm install

4. Run the demo

python run_demo.py

Or run steps manually:

python generate_data.py
node tfjs_train.js
python pytorch_train.py
python tfjs_to_pytorch.py
python pytorch_to_onnx.py
node validate_onnx_node.js
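
The manual sequence can also be scripted. The sketch below shows one way an orchestrator could run the six commands in order and stop at the first failure. It is an illustration, not the contents of run_demo.py, and the run_steps helper and its runner parameter are hypothetical.

```python
import subprocess
import sys

# Same order as the manual steps listed above.
STEPS = [
    [sys.executable, "generate_data.py"],
    ["node", "tfjs_train.js"],
    [sys.executable, "pytorch_train.py"],
    [sys.executable, "tfjs_to_pytorch.py"],
    [sys.executable, "pytorch_to_onnx.py"],
    ["node", "validate_onnx_node.js"],
]


def run_steps(steps, runner=subprocess.run):
    """Run each command in order; check=True makes subprocess.run raise on a non-zero exit."""
    completed = []
    for command in steps:
        print("Running:", " ".join(command))
        runner(command, check=True)
        completed.append(command[-1])
    return completed


# To run the whole pipeline from the repository root:
# run_steps(STEPS)
```

The runner argument exists only so the function can be exercised without launching real processes; in normal use the default subprocess.run is what executes each step.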

Files

File                   Language  Purpose
generate_data.py       Python    Creates synthetic atmosphere CSV
tfjs_train.js          Node.js   Trains model in TF.js, exports weights as JSON
pytorch_train.py       Python    Trains same architecture in PyTorch
tfjs_to_pytorch.py     Python    Loads TF.js weights into PyTorch, exports ONNX
pytorch_to_onnx.py     Python    Exports PyTorch model to ONNX, validates
validate_onnx_node.js  Node.js   Loads ONNX files in JavaScript, cross-validates
run_demo.py            Python    Orchestrates all steps

Dependency Rationale

What was removed vs. the previous version

The previous project required TensorFlow (Python), tf2onnx, tensorflowjs, and tf-keras, which added roughly 2 GB of dependencies and caused version conflicts (notably with the end-of-life tensorflow-addons package). This version eliminates all TensorFlow Python dependencies in two ways:

  1. Direction 1 (TF.js → PyTorch): Instead of TF.js → tensorflowjs_converter → TF SavedModel → tf2onnx → ONNX → PyTorch, we now export weights as plain JSON from TF.js and load them directly into PyTorch with a simple transpose. No TensorFlow Python needed.

  2. Direction 2 (PyTorch → JS): Instead of trying to convert ONNX back to TensorFlow format (which required the deprecated onnx-tf), we use onnxruntime-node to run the ONNX file directly in JavaScript. Running the file with ONNX Runtime, rather than converting it again, is also the usual way ONNX models are deployed.

Python packages (4 total, down from 10)

Package      Version  Why this version
torch        ≥2.2.0   Stable ONNX export via torch.onnx.export
numpy        ≥1.26.0  Required by torch; 1.26+ has cp313 wheels
onnx         1.20.1   Ships abi3 wheels (works on Python 3.12–3.13+ without compiling)
onnxruntime  1.24.1   Native cp313-win_amd64 wheel; validates ONNX files

Node.js packages (3 total)

Package                Purpose
@tensorflow/tfjs       TF.js core (model building/training)
@tensorflow/tfjs-node  C++ backend for Node.js (fast training)
onnxruntime-node       Run ONNX models in JavaScript

Weight Transfer Convention

The critical insight enabling Direction 1 without TensorFlow:

TF.js Dense kernel:    shape [in_features, out_features]
PyTorch Linear weight: shape [out_features, in_features]

Conversion: pytorch_weight = transpose(tfjs_kernel)
Biases:     same shape in both frameworks, no conversion needed
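
A minimal sketch of why the transpose is sufficient, written in plain Python so it runs without either framework. A TF.js Dense layer computes y = x·K + b with K of shape [in_features, out_features], while a PyTorch Linear layer computes y = W·x + b with W of shape [out_features, in_features], so setting W to the transpose of K gives the same output. The helper names and the example numbers are made up for illustration and do not come from tfjs_to_pytorch.py.

```python
def transpose(matrix):
    """Swap rows and columns: [in_features][out_features] -> [out_features][in_features]."""
    return [list(row) for row in zip(*matrix)]


def tfjs_dense(x, kernel, bias):
    """TF.js convention: y[j] = sum_i x[i] * kernel[i][j] + bias[j]."""
    return [sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
            for j in range(len(bias))]


def pytorch_linear(x, weight, bias):
    """PyTorch convention: y[j] = sum_i weight[j][i] * x[i] + bias[j]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


# Example layer with 2 inputs and 3 outputs (values chosen arbitrarily).
tfjs_kernel = [[0.5, -1.0, 2.0],
               [1.5, 0.25, -0.5]]  # shape [2, 3]
bias = [0.1, 0.2, 0.3]             # shape [3], copied unchanged

pytorch_weight = transpose(tfjs_kernel)  # shape [3, 2]

x = [2.0, 4.0]
print(tfjs_dense(x, tfjs_kernel, bias))
print(pytorch_linear(x, pytorch_weight, bias))
```

Both calls print the same three values, which is the property the Direction 1 validation relies on, layer by layer.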

Expected Output

Direction 1 (TF.js -> PyTorch): Max error < 1e-5
Direction 2 (PyTorch -> ONNX):  Max error < 1e-6
Cross-language (Node.js ONNX):  Max error < 1e-5

The remaining errors are floating-point rounding differences between the runtimes (TF.js, PyTorch, and ONNX Runtime), not differences in the model itself.
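
The check behind each of these lines is a maximum absolute difference between two sets of predictions for the same inputs. A minimal sketch, with a hypothetical helper name and made-up sample values:

```python
def max_abs_error(reference, candidate):
    """Largest absolute difference between paired predictions."""
    if len(reference) != len(candidate):
        raise ValueError("prediction lists must have the same length")
    return max(abs(r - c) for r, c in zip(reference, candidate))


# Made-up densities (kg/m^3) standing in for the outputs of two runtimes.
pytorch_preds = [1.225000, 0.680100, 0.377500]
onnx_preds = [1.225001, 0.680098, 0.377500]

error = max_abs_error(pytorch_preds, onnx_preds)
print(f"max error = {error:.2e}, within 1e-5: {error < 1e-5}")
```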

Prerequisites

  • Python 3.10+ (tested on 3.13)
  • Node.js 18+ LTS
  • Windows, macOS, or Linux

About

This project demonstrates interoperability between TensorFlow.js and PyTorch ML models.
