
JAX-PINN-Tutorials

This repository contains Jupyter notebooks that provide an introduction to JAX (a high-performance numerical computing library) and Physics-Informed Neural Networks (PINNs). The focus is on theoretical concepts, practical implementations, and bridging machine learning with physical simulations.

  • JAX_Tutorial.ipynb: A beginner-friendly guide to JAX fundamentals, including array operations, just-in-time (JIT) compilation, automatic differentiation, vectorization, randomness handling, and training a simple neural network on the Iris dataset. It also explores hardware acceleration using TPUs.
  • PINN.ipynb: An in-depth exploration of PINNs, covering theoretical foundations, mathematical formulations, and a practical implementation for parameter identification in partial differential equations (PDEs), such as the 1D Burgers' equation (with elements of the heat equation for illustration).

The repository emphasizes detailed theoretical explanations to help users understand the underlying principles, making it suitable for researchers, students, and practitioners in scientific computing, machine learning, and physics-based modeling. Visualizations and plots are included in the notebooks.


Table of Contents

  • Introduction to JAX
  • Key JAX Concepts
  • Training a Neural Network in JAX
  • Introduction to Physics-Informed Neural Networks (PINNs)
  • Getting Started

Introduction to JAX

JAX is a Python library developed by Google for high-performance machine learning research and numerical computing. It combines the flexibility of NumPy with automatic differentiation, just-in-time compilation, and hardware acceleration (CPU, GPU, TPU). Conceptually, JAX builds on functional programming and composable transformations: computations expressed as pure functions can be differentiated, vectorized, or compiled efficiently. This makes it powerful for gradient-based tasks such as optimization in machine learning or inverse problems in physics.

In JAX_Tutorial.ipynb, we start with basic array operations and progressively build up to training a multi-layer perceptron on the Iris dataset. The notebook demonstrates how JAX enables composable transformations (e.g., gradients, vectorization) while maintaining immutability and functional purity for better performance and parallelism.


Key JAX Concepts

JAX as Accelerated NumPy

  • JAX provides a NumPy-like API (jax.numpy or jnp) for array operations, with key differences:
    • Arrays are immutable to enable functional programming and avoid side effects.
    • Operations are accelerated on hardware like GPUs/TPUs and optimized for XLA (Accelerated Linear Algebra) compilation.
  • Theory (why it’s fast):
    • JAX leverages XLA to compile numerical computations into optimized kernels by tracing the computation graph, fusing operations to reduce memory bandwidth usage, and applying optimizations like common subexpression elimination.
    • Immutability ensures pure functions, enabling safe parallelism and differentiation without hidden state dependencies.
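
A minimal sketch of the NumPy-like API and of functional (immutable) updates; the variable names are illustrative, not taken from the notebook:

import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)   # device-backed array (float32 by default)
y = jnp.dot(x, x.T)                 # dispatched through XLA on CPU/GPU/TPU

# Arrays are immutable: instead of in-place assignment, use functional updates.
x2 = x.at[0, 0].set(10.0)           # returns a new array; x itself is unchanged
print(y)
print(x2)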

Just-in-Time (JIT) Compilation with jax.jit

  • jax.jit compiles Python functions into optimized XLA code for faster execution, especially on repeated calls.
    • Traces once (warmup) and caches the compiled executable.
    • Use jax.make_jaxpr() to inspect the intermediate representation (JAXPR).
  • Theory:
    • JIT turns dynamic Python into static XLA programs: JAX traces the function with abstract values (shapes and dtypes rather than concrete data), compiles the resulting graph, and caches the executable for reuse on inputs with matching shapes. Prefer traceable control flow (e.g., jnp.where, lax.cond, lax.scan) over Python branching on traced values to avoid errors and retracing.
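
A small illustration of tracing and compilation under jit (a sketch, not code from the notebook):

import jax
import jax.numpy as jnp

def f(x):
    # jnp.where is traceable; a Python `if x > 0:` on a traced value would fail under jit
    return jnp.where(x > 0, x ** 2, -x)

f_jit = jax.jit(f)

print(jax.make_jaxpr(f)(3.0))   # inspect the traced intermediate representation (jaxpr)
print(f_jit(3.0))               # first call with this shape/dtype traces and compiles
print(f_jit(4.0))               # subsequent calls reuse the cached executable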

Automatic Differentiation with jax.grad

  • JAX computes gradients via forward/reverse-mode AD.
    • jax.grad(f) returns a function that computes the gradient of f with respect to its first argument; use argnums to differentiate with respect to other arguments.
    • Supports higher-order derivatives by composition (e.g., jax.grad(jax.grad(f)) for second derivatives of a scalar function, or jax.hessian for full Hessians).
  • Theory:
    • AD decomposes complex functions into primitive operations and applies the chain rule. In reverse mode (efficient for many inputs, one output like loss functions), it propagates adjoints backward. Forward mode suits few inputs, many outputs. JAX’s AD integrates with JIT for compiled gradients, crucial for efficiency in large-scale optimization.
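
A short example of first-order, partial, and higher-order derivatives with jax.grad (values are illustrative):

import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) * y ** 2

df_dx   = jax.grad(f)               # derivative w.r.t. the first argument by default
df_dy   = jax.grad(f, argnums=1)    # derivative w.r.t. y
d2f_dx2 = jax.grad(jax.grad(f))     # second derivative w.r.t. x

print(df_dx(1.0, 2.0))    # cos(1) * 4
print(df_dy(1.0, 2.0))    # 2 * 2 * sin(1)
print(d2f_dx2(1.0, 2.0))  # -sin(1) * 4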

Automatic Vectorization with jax.vmap

  • jax.vmap applies a function across batches in parallel, avoiding explicit Python loops.
    • Transforms a function operating on scalars/vectors into one for batches (e.g., matrix multiplications in neural networks).
  • Theory:
    • Vectorization maps a function over leading axes, enabling batched computations. It’s a higher-order transformation lifting scalar ops to tensor ones, exploiting SIMD parallelism. Composed with grad (e.g., jax.vmap(jax.grad(f))), it computes per-example gradients efficiently, as in mini-batch training.
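
A sketch of batching with vmap and of per-example gradients via vmap(grad(...)) (names and shapes are illustrative):

import jax
import jax.numpy as jnp

def predict(w, x):
    # single-example prediction: x is a feature vector, output is a scalar
    return jnp.dot(w, x)

w  = jnp.ones(3)
xs = jnp.arange(12.0).reshape(4, 3)                       # a batch of 4 examples

batched_predict = jax.vmap(predict, in_axes=(None, 0))    # map over the batch axis of x only
print(batched_predict(w, xs))                             # shape (4,)

per_example_grads = jax.vmap(jax.grad(predict), in_axes=(None, 0))(w, xs)
print(per_example_grads.shape)                            # (4, 3): one gradient w.r.t. w per example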

Randomness Handling in JAX

  • JAX uses explicit PRNG keys (jax.random.key) instead of global seeds for reproducibility and parallelism.
    • Split keys (jax.random.split) to generate independent sub-keys for parallel operations.
  • Theory:
    • Traditional RNGs rely on mutable global state, conflicting with functional purity. JAX’s splittable, counter-based PRNG (Threefry by default) provides deterministic, parallel-safe randomness. Splitting a key yields statistically independent streams, ensuring reproducibility across devices without synchronization overhead.
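
A sketch of explicit key handling (the shapes are arbitrary):

import jax

key = jax.random.key(42)                      # explicit PRNG state instead of a global seed
key, sub1, sub2 = jax.random.split(key, 3)    # independent sub-keys for separate uses

x = jax.random.normal(sub1, (2, 3))
y = jax.random.uniform(sub2, (2, 3))

# The same key always reproduces the same numbers: randomness is a pure function of the key.
assert (jax.random.normal(sub1, (2, 3)) == x).all()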

Optimization with Optax

  • Optax is a gradient-based optimization library for JAX, providing optimizers like Adam, SGD, etc.
    • Used for updating parameters during training (e.g., optax.adam).
  • Theory:
    • Optax modularizes gradient transformations (momentum, adaptive learning rates) that can be chained and composed. Adam combines momentum with per-parameter scaling based on a running estimate of the gradients’ second moment, converging faster on noisy gradients, and is a robust default for non-convex NN training.
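
A minimal Optax loop on a toy quadratic loss, following the standard init/update/apply_updates pattern (the loss and learning rate are illustrative):

import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
    return jnp.sum((params - 3.0) ** 2)      # toy loss with its minimum at params == 3

params = jnp.zeros(4)
optimizer = optax.adam(learning_rate=0.1)
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state):
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

for _ in range(200):
    params, opt_state = step(params, opt_state)
print(params)   # each component approaches 3.0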

Hardware Acceleration (e.g., TPUs)

  • JAX runs on CPU, GPU, and TPU backends; jax.devices() lists the available accelerators and jax.device_put places arrays on a chosen device.
    • Operations like matrix multiplications are optimized and can be sharded across devices via jax.pmap/jax.sharding.
  • Theory:
    • TPUs use systolic arrays for matrix operations, delivering high throughput per watt. XLA performs loop fusion and memory-layout tuning. Sharding distributes data and model parameters across devices, supporting data and model parallelism for scalable computation.
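
A short device-placement sketch (which devices appear depends on the runtime):

import jax
import jax.numpy as jnp

print(jax.devices())                        # available devices (CPU, GPU, or TPU cores)

x = jnp.ones((1024, 1024))
x = jax.device_put(x, jax.devices()[0])     # place the array explicitly on a device
y = (x @ x).block_until_ready()             # dispatch is asynchronous; block before timing
print(y.shape)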

Training a Neural Network in JAX

In JAX_Tutorial.ipynb, we implement a feedforward neural network from scratch:

  • Initialize parameters with random keys.
  • Define forward pass and loss (e.g., cross-entropy with L2 regularization).
  • Use jax.jit for efficient training steps.
  • Train on the Iris dataset, achieving high accuracy.

Theory:

  • Neural networks approximate functions via compositions of linear transformations and non-linear activations. Backpropagation (via AD) computes gradients for weight updates. JAX’s functional style treats parameters as inputs, enabling pure functions for easier composition and parallelism. Regularization (L2) prevents overfitting by penalizing large weights, acting as a smoothness prior.
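
A condensed sketch of this pattern: parameters as an explicit pytree, a pure loss with L2 regularization, and a jitted update step. Layer sizes, the learning rate, and the one-hot label format are illustrative rather than the notebook's exact choices:

import jax
import jax.numpy as jnp

def init_params(key, sizes=(4, 16, 3)):
    # one (weights, bias) pair per layer, initialized from split keys
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) * 0.1, jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def forward(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b                               # logits

def loss_fn(params, x, y, l2=1e-4):
    # y is assumed to be one-hot encoded class labels
    ce  = -jnp.mean(jnp.sum(jax.nn.log_softmax(forward(params, x)) * y, axis=-1))
    reg = sum(jnp.sum(w ** 2) for w, _ in params)  # L2 penalty on the weights
    return ce + l2 * reg

@jax.jit
def train_step(params, x, y, lr=1e-2):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)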

Introduction to Physics-Informed Neural Networks (PINNs)

PINNs integrate physical laws (e.g., PDEs) into neural network training, enabling data-efficient solutions to forward/inverse problems. In PINN.ipynb, we cover theory and implement a PINN for parameter identification in the 1D Burgers' equation (viscous fluid flow), with references to the heat equation.

Theoretical Foundations

  • PINNs approximate PDE solutions using neural networks while enforcing physics via loss terms.
  • Why PINNs?
    • Traditional solvers (e.g., finite elements) require meshes and can struggle with high dimensions or inverse problems.
    • PINNs are mesh-free, leveraging universal approximation theorems (NNs can approximate continuous functions) and physics as an inductive bias.
  • Theory expansion:
    • In inverse problems, PINNs treat unknown parameters as learnable variables, optimizing them jointly with the solution. PDE residuals act as unsupervised regularization, reducing reliance on labeled data. The residual term can also be read as a physics-informed prior, loosely analogous to priors in Bayesian inference.

Mathematical Formulation

For a PDE N[u](t, x; λ) = 0 with boundary/initial conditions u(t, x) = g(t, x):

  • Approximate u(t, x) ≈ u_θ(t, x) with a neural network.
  • Enforce the PDE residual at collocation points.

Example (Burgers' equation): u_t + u u_x - ν u_xx = 0, where ν is viscosity (inferred parameter).

  • Theory expansion:
    • The PDE operator N can be nonlinear or time-dependent. Collocation sampling (e.g., Latin hypercube) ensures coverage. For time-dependent PDEs, PINNs treat time as an input, solving the space-time domain at once. This continuous representation allows analytic differentiation via AD, avoiding grid-based errors.
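
A sketch of forming the Burgers residual with nested AD; u_net below is only a toy stand-in for the neural network u_θ(t, x), so the numbers are purely illustrative:

import jax
import jax.numpy as jnp

def u_net(params, t, x):
    # placeholder for a real network: must return the scalar u_theta(t, x)
    w1, w2 = params
    return jnp.tanh(w1 * t + w2 * x)

def burgers_residual(params, t, x, nu):
    u    = u_net(params, t, x)
    u_t  = jax.grad(u_net, argnums=1)(params, t, x)
    u_x  = jax.grad(u_net, argnums=2)(params, t, x)
    u_xx = jax.grad(jax.grad(u_net, argnums=2), argnums=2)(params, t, x)
    return u_t + u * u_x - nu * u_xx        # vanishes wherever the PDE is satisfied

print(burgers_residual((0.5, 1.0), 0.3, 0.7, nu=0.01))
# A batch of collocation points can be handled with
# jax.vmap(burgers_residual, in_axes=(None, 0, 0, None)).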

Loss Function and Automatic Differentiation

Total loss: L = L_data + L_PDE + L_BC/IC.

  • L_PDE uses AD to compute derivatives (e.g., u_t, u_xx).
  • Training uses gradient descent on collocation/IC/BC points.
  • Theory expansion:
    • AD yields exact derivatives of the network surrogate, avoiding the truncation error of finite differences. Loss balancing (via term weights) addresses scale differences between the terms; adaptive schemes (e.g., NTK-based weighting) adjust them dynamically. For stiff PDEs, vanishing/exploding gradients can occur; mitigations include Fourier feature embeddings and curriculum learning.
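
A sketch of combining the loss terms with explicit weights (the pointwise errors and residuals would come from the network and a residual function such as the one sketched above; the weights are hyperparameters):

import jax.numpy as jnp

def total_loss(data_err, pde_res, bc_err, w_data=1.0, w_pde=1.0, w_bc=1.0):
    # each argument is an array of pointwise errors/residuals; each term is a mean-squared error
    return (w_data * jnp.mean(data_err ** 2)
            + w_pde * jnp.mean(pde_res ** 2)
            + w_bc * jnp.mean(bc_err ** 2))

print(total_loss(jnp.array([0.1, -0.2]), jnp.array([0.05, 0.0, 0.02]), jnp.array([0.0])))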

Implementation for Parameter Identification

  • Use Flax for the NN architecture.

  • Infer parameters (λ1, λ2) in the PDE form u_t + λ1 u u_x - λ2 u_xx = 0.

  • Visualizations: solution snapshots, loss curves, parameter convergence.

  • Theory expansion:

    • In inverse problems, parameters λ are optimized alongside the network weights θ, turning the PDE into a soft constraint embedded in the loss. Convergence depends on the loss landscape; multi-fidelity approaches (combining low- and high-resolution data) improve robustness. Burgers' equation highlights nonlinear dynamics in which λ1 and λ2 control advection and diffusion, respectively.
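
A minimal sketch of this setup with Flax and Optax: the unknown coefficients sit in the same parameter pytree as the network weights, so one optimizer updates both. The architecture, initial guesses, and learning rate are illustrative, not the notebook's exact values:

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class UNet(nn.Module):
    # small MLP approximating u_theta(t, x)
    @nn.compact
    def __call__(self, tx):                       # tx has shape (2,): (t, x)
        h = tx
        for _ in range(3):
            h = jnp.tanh(nn.Dense(32)(h))
        return nn.Dense(1)(h).squeeze(-1)         # scalar u(t, x)

model = UNet()
params = {
    "net": model.init(jax.random.PRNGKey(0), jnp.zeros(2)),
    "lam": jnp.array([0.5, 0.05]),                # learnable (lambda1, lambda2) initial guesses
}

def u_fn(net_params, t, x):
    return model.apply(net_params, jnp.stack([t, x]))

def residual(p, t, x):
    lam1, lam2 = p["lam"]
    u    = u_fn(p["net"], t, x)
    u_t  = jax.grad(u_fn, argnums=1)(p["net"], t, x)
    u_x  = jax.grad(u_fn, argnums=2)(p["net"], t, x)
    u_xx = jax.grad(jax.grad(u_fn, argnums=2), argnums=2)(p["net"], t, x)
    return u_t + lam1 * u * u_x - lam2 * u_xx

# Gradients of the combined data + residual loss flow into both "net" and "lam",
# so a single optimizer (e.g. optax.adam) identifies the PDE parameters during training.
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)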

Advantages and Limitations of PINNs

  • Advantages:
    • Mesh-free and dimension-agnostic, suitable for complex geometries.
    • Handle multi-physics and uncertainties via ensembles.
    • Enable discovery of hidden physics through parameter inference.
  • Limitations:
    • Training instability for advection-dominated PDEs.
    • Spectral bias in NNs (struggle with high frequencies).
    • Scalability challenges for very high dimensions due to computational cost.
  • Theory expansion:
    • Remedies include causal training (sequential time windows), adaptive activation functions, Fourier features, and operator learning (e.g., DeepONet, FNO). PINNs relate to variational methods, implicitly minimizing energy functionals.

Getting Started

Clone the repo

git clone https://github.com/adiManethia/JAX-PINN-Tutorials.git
cd JAX-PINN-Tutorials
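
Install the libraries the notebooks use (a suggested command; the repository does not pin exact versions, and matplotlib/scikit-learn are assumed here for the plots and the Iris dataset):

pip install --upgrade jax flax optax jupyter matplotlib scikit-learn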
