
[DEMO] High-Performance Quanvolutional Neural Networks with JAX & Flax #1621

@Spartoons

Description


General information

Name
Aran Oliveras (@Spartoons)

Affiliation (optional)
Computer Vision Center (CVC), QML Group, Universitat Autònoma de Barcelona (UAB)

Twitter (optional)
@aranoliveras

Image (optional)
Image


Demo information

Title
High-Performance Quanvolutional Neural Networks with JAX & Flax

Abstract
This demo presents a high-performance port of the classic Quanvolutional Neural Network tutorial, migrated from TensorFlow to the JAX/Flax ecosystem. Using JAX's vmap and JIT compilation, it shows how to vectorize the execution of quantum circuits across image patches, yielding speedups of over 10,000x in per-image processing time compared with a standard loop-based implementation. It also includes a custom, stateless training loop built with Optax and Flax.
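A minimal sketch of the vectorization idea described in the abstract, written in pure JAX so it runs without a quantum SDK. The quantum circuit is simulated by a tiny hand-written 4-qubit statevector simulator. The gate sequence (RY angle encoding of a 2x2 patch, a ring of CNOTs, fixed RY rotations, Pauli-Z readout) and every helper name (`ry`, `apply_1q`, `apply_cnot`, `expval_z`, `circuit`, `extract_patches`, `quanv`) are illustrative assumptions, not the demo's actual code.

```python
import jax
import jax.numpy as jnp

N_QUBITS = 4  # one qubit per pixel of a 2x2 patch


def ry(theta):
    """Single-qubit RY rotation matrix (real-valued)."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])])


def apply_1q(state, gate, wire):
    """Apply a 2x2 gate to axis `wire` of a (2,)*N statevector tensor."""
    state = jnp.tensordot(gate, state, axes=([1], [wire]))
    return jnp.moveaxis(state, 0, wire)  # tensordot puts the new axis first


def apply_cnot(state, control, target):
    """Flip `target` (Pauli-X) on the slice where `control` is |1>."""
    shape = [2 if i == control else 1 for i in range(state.ndim)]
    ctrl = jnp.arange(2).reshape(shape)
    return jnp.where(ctrl == 1, jnp.flip(state, axis=target), state)


def expval_z(state, wire):
    """<Z> on one wire: P(bit=0) - P(bit=1)."""
    probs = state ** 2  # amplitudes stay real: only RY and CNOT are used
    others = tuple(a for a in range(state.ndim) if a != wire)
    p = jnp.sum(probs, axis=others)
    return p[0] - p[1]


def circuit(patch, weights):
    """Hypothetical quanvolution filter for ONE flattened 2x2 patch -> 4 channels."""
    state = jnp.zeros((2,) * N_QUBITS).at[(0,) * N_QUBITS].set(1.0)  # |0000>
    for w in range(N_QUBITS):  # angle-encode each pixel value in [0, 1]
        state = apply_1q(state, ry(jnp.pi * patch[w]), w)
    for w in range(N_QUBITS):  # entangle with a ring of CNOTs
        state = apply_cnot(state, w, (w + 1) % N_QUBITS)
    for w in range(N_QUBITS):  # fixed (non-trained) rotations
        state = apply_1q(state, ry(weights[w]), w)
    return jnp.stack([expval_z(state, w) for w in range(N_QUBITS)])


def extract_patches(image):
    """(H, W) -> (H/2, W/2, 4) non-overlapping 2x2 patches, stride 2."""
    h, w = image.shape
    p = image.reshape(h // 2, 2, w // 2, 2).transpose(0, 2, 1, 3)
    return p.reshape(h // 2, w // 2, 4)


@jax.jit
def quanv(image, weights):
    """Run `circuit` on every patch of one image in a single vectorized call."""
    patches = extract_patches(image)
    flat = patches.reshape(-1, N_QUBITS)
    out = jax.vmap(circuit, in_axes=(0, None))(flat, weights)  # weights shared
    return out.reshape(patches.shape[0], patches.shape[1], N_QUBITS)


# Usage: random stand-in data shaped like a small batch of MNIST digits.
k_w, k_x = jax.random.split(jax.random.PRNGKey(0))
weights = jax.random.uniform(k_w, (N_QUBITS,), maxval=2 * jnp.pi)
images = jax.random.uniform(k_x, (8, 28, 28))
features = jax.vmap(quanv, in_axes=(0, None))(images, weights)
print(features.shape)  # (8, 14, 14, 4)
```

Because `circuit` is written for a single patch, `jax.vmap` adds the batch dimensions (first over the 196 patches of a 28x28 image, then over the images), and `jax.jit` compiles the whole pipeline into one XLA program instead of one Python call per patch. The hand-rolled simulator only stands in for whatever circuit backend the demo uses; this sketch does not reproduce the speedup figure quoted in the abstract.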

Relevant links

Metadata


Assignees

No one assigned

Labels

demos: Updating the demonstrations/tutorials

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
