
Core & custom Jaxprs #62

@rlouf

Here are a few ideas for the core that I won't have time to implement before the first release, but that could make MCX more general and greatly simplify the bijectors API.

First, a shortcoming of many PPLs is that they cannot condition on deterministic transformations of random variables. This is because the logpdf function would need to add the change-of-variables correction to conserve volume: for y = f(x), the log absolute determinant of the Jacobian of the inverse transformation f⁻¹ (equivalently, minus that of the forward transformation). This seems like a job for Jaxprs. The idea would be to have the core compile the graph into a function that JAX can trace and manipulate, and to create a "logpdf" Jaxpr transformation that is applied to this function.
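A minimal sketch of the correction in plain JAX, for the scalar case only. `transformed_logpdf` is a hypothetical helper (not an MCX API): it takes the base log-density and the inverse transformation, and adds log|d f⁻¹/dy| computed with `jax.grad`. Conditioning on Y = exp(X) with X ~ Normal(0, 1) should then reproduce the LogNormal(0, 1) log-density.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def transformed_logpdf(base_logpdf, inverse):
    """Hypothetical helper: log-density of Y = f(X) from log p_X and f^-1 (scalar case)."""

    def logpdf(y):
        x = inverse(y)
        # Volume correction: log |d f^-1 / dy|, obtained by autodiff.
        log_det = jnp.log(jnp.abs(jax.grad(inverse)(y)))
        return base_logpdf(x) + log_det

    return logpdf


# Y = exp(X) with X ~ Normal(0, 1), so f^-1 = log and Y ~ LogNormal(0, 1).
logpdf_y = transformed_logpdf(norm.logpdf, jnp.log)

# Inspect the graph JAX traced for this logpdf; a "logpdf" Jaxpr
# transformation would operate on this kind of representation.
print(jax.make_jaxpr(logpdf_y)(2.0))
```

In the proposal the user would only write the forward transformation; here the inverse is passed explicitly to keep the sketch short.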

Then, if this works, we would only need to implement the "forward" part of bijectors: the logpdf Jaxpr would automatically take care of conserving the volume. Writing a Jaxpr that inverts the transformation is possible, if not easy.
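To illustrate that inverting a Jaxpr is feasible, here is a rough sketch assuming the forward transformation is a scalar chain of invertible primitives (`exp`, `log`, `neg`, and `add`/`sub`/`mul`/`div` with a literal constant). `invert_chain` is hypothetical: it walks the equations of `jax.make_jaxpr(f)` backwards and applies each primitive's inverse. It does not handle branching graphs, closed-over arrays, or nested calls such as `pjit`, and it relies on Jaxpr internals that are not a stable public API.

```python
import jax
import jax.numpy as jnp

# Primitives this sketch knows how to invert.
_SUPPORTED = {"exp", "log", "neg", "add", "sub", "mul", "div", "convert_element_type"}


def _is_literal(v):
    # Assumption about JAX internals: inlined constants (Literal) carry `.val`,
    # traced variables (Var) do not.
    return hasattr(v, "val")


def invert_chain(f, example_x):
    """Hypothetical: build the inverse of a scalar chain function from its Jaxpr."""
    jaxpr = jax.make_jaxpr(f)(example_x).jaxpr
    target = jaxpr.outvars[0]
    steps = []  # (primitive name, literal constant, position of the traced operand)
    for eqn in reversed(jaxpr.eqns):
        name = eqn.primitive.name
        if eqn.outvars[0] is not target or name not in _SUPPORTED:
            raise NotImplementedError(f"cannot invert equation: {eqn}")
        traced = [(i, v) for i, v in enumerate(eqn.invars) if not _is_literal(v)]
        consts = [v.val for v in eqn.invars if _is_literal(v)]
        if len(traced) != 1:
            raise NotImplementedError("each step must depend on exactly one variable")
        position, target = traced[0]
        steps.append((name, consts[0] if consts else None, position))
    if target is not jaxpr.invars[0]:
        raise NotImplementedError("output does not trace back to the input")

    def inverse(y):
        # `steps` is already ordered from the last forward operation to the first.
        for name, c, position in steps:
            if name == "exp":
                y = jnp.log(y)
            elif name == "log":
                y = jnp.exp(y)
            elif name == "neg":
                y = -y
            elif name == "add":
                y = y - c
            elif name == "mul":
                y = y / c
            elif name == "sub":
                y = y + c if position == 0 else c - y
            elif name == "div":
                y = y * c if position == 0 else c / y
            # convert_element_type is treated as the identity (float-to-float casts).
        return y

    return inverse


forward = lambda x: jnp.exp(x * 2.0 + 1.0)
inverse = invert_chain(forward, 0.3)
```

A fuller version would likely keep a table of per-primitive inverse rules instead of an `if`/`elif` chain, so new bijector building blocks only need to register their inverse.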

Graph --> JAX-ready logpdf --> logpdf Jaxpr
Graph --> Joint distribution forward sampler
Graph --> Predictive distribution sampler
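For a toy model, the three targets above might look like the following hand-written functions. The model (mu ~ Normal(0, 1), y ~ Normal(mu, 1)) and all function names are hypothetical; in the proposal, the core would generate these from the graph rather than have them written by hand.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy model: mu ~ Normal(0, 1); y ~ Normal(mu, 1)


def logpdf(mu, y):
    """Graph -> JAX-ready logpdf: sum of each node's log-density."""
    return norm.logpdf(mu, 0.0, 1.0) + norm.logpdf(y, mu, 1.0)


def sample_joint(rng_key):
    """Graph -> joint forward sampler: sample nodes in topological order."""
    key_mu, key_y = jax.random.split(rng_key)
    mu = jax.random.normal(key_mu)
    y = mu + jax.random.normal(key_y)
    return mu, y


def sample_predictive(rng_key, mu):
    """Graph -> predictive sampler: fix upstream values, sample the observed node."""
    return mu + jax.random.normal(rng_key)


# The Jaxpr a "logpdf" transformation would then operate on.
logpdf_jaxpr = jax.make_jaxpr(logpdf)(0.0, 0.0)
```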

Forward sampling and predictive sampling are simple enough that they can be left as they are.

As a result, we would have a two-layer core:

  • A "conceptual" graph for reasoning about distributions directly. It makes it possible to identify conjugacy relationships that can be collapsed, transformations to apply to random variables with constrained support, the samplers best suited to each variable, etc.
  • A computational layer that relies on custom JAX primitives to compute the logprob of transformed random variables and simplifies the computation when possible.
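As an illustration of the first layer, here is a sketch of detecting and collapsing one conjugacy relationship: a Normal prior on the location of a Normal likelihood with known scale. The `Node` class and helper names are hypothetical and do not reflect MCX's actual graph representation; the posterior update is the standard closed-form Normal-Normal result.

```python
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass
class Node:
    # Hypothetical node: a parameter is either a number or the name of a parent node.
    name: str
    dist: str
    params: dict
    observed: object = None


def find_normal_normal(nodes):
    """Yield (prior, likelihood) pairs where a Normal prior is the loc of an observed Normal."""
    by_name = {n.name: n for n in nodes}
    for node in nodes:
        loc = node.params.get("loc")
        if (
            node.dist == "Normal"
            and node.observed is not None
            and isinstance(loc, str)
            and by_name[loc].dist == "Normal"
        ):
            yield by_name[loc], node


def collapse(prior, likelihood):
    """Closed-form posterior of the prior node (assumes numeric prior loc/scale)."""
    y = jnp.asarray(likelihood.observed)
    prior_var = prior.params["scale"] ** 2
    noise_var = likelihood.params["scale"] ** 2
    precision = 1.0 / prior_var + y.size / noise_var
    mean = (prior.params["loc"] / prior_var + y.sum() / noise_var) / precision
    return mean, jnp.sqrt(1.0 / precision)


graph = [
    Node("mu", "Normal", {"loc": 0.0, "scale": 1.0}),
    Node("y", "Normal", {"loc": "mu", "scale": 1.0}, observed=[1.0, 1.0, 1.0]),
]
pairs = list(find_normal_normal(graph))
```

Once collapsed, `mu` can be sampled exactly from Normal(mean, sd), so no gradient-based sampler would need to be assigned to it.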