
Distrax + Flax bijector error and best practices #263

@JamesAllingham


I've encountered a small error when implementing Distrax bijectors with Flax conditioner NNs, and I also have a question about best practices for using Distrax with Flax.

The error can be reproduced with the following setup (also available in this Colab notebook: https://colab.research.google.com/drive/1RLRZul_pHnglcT_-YZ7mcuKLU1qd3w5O?usp=sharing).

import functools
from typing import Any, Dict, Optional, Sequence

import distrax
import flax.linen as nn
import jax.numpy as jnp
import numpy as np
from jax import Array, random

# (assumed) alias for a keyword-argument dict, as used in the notebook
KwArgs = Dict[str, Any]

class Conditioner(nn.Module):
    event_shape: Sequence[int]
    num_bijector_params: int
    hidden_dims: Sequence[int]

    @nn.compact
    def __call__(self, z: Array, h: Array) -> Array:
        h = jnp.concatenate((z.flatten(), h.flatten()), axis=0)

        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)

        y = nn.Dense(np.prod(self.event_shape) * self.num_bijector_params)(h)
        y = y.reshape(tuple(self.event_shape) + (self.num_bijector_params,))

        return y

class MyModel(nn.Module):
    hidden_dims: Sequence[int]
    num_flows: int
    num_bins: int
    event_shape: Sequence[int]
    conditioner: Optional[KwArgs] = None

    @nn.compact
    def __call__(self, x, y: Optional[Array] = None):
        # base distribution
        output_dim = np.prod(self.event_shape)
        base = distrax.Independent(
            distrax.Normal(loc=jnp.zeros(output_dim,), scale=jnp.ones(output_dim,)), len(self.event_shape)
        )

        # bijector
        # Number of parameters for the rational-quadratic spline:
        # - `num_bins` bin widths
        # - `num_bins` bin heights
        # - `num_bins + 1` knot slopes
        # for a total of `3 * num_bins + 1` parameters.
        num_bijector_params = 3 * self.num_bins + 1

        layers = []
        mask = jnp.arange(0, np.prod(self.event_shape)) % 2
        mask = jnp.reshape(mask, self.event_shape)
        mask = mask.astype(bool)

        def bijector_fn(params: Array):
            return distrax.RationalQuadraticSpline(
                params, range_min=-3.0, range_max=3.0
            )

        h = x.flatten()

        # shared feature extractor
        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)

        for i in range(self.num_flows):
            conditioner = Conditioner(
                event_shape=self.event_shape,
                num_bijector_params=num_bijector_params,
                **(self.conditioner or {}),
            )

            layer = distrax.MaskedCoupling(
                mask=mask,
                bijector=bijector_fn,
                conditioner=functools.partial(conditioner, h=h),
            )

            layers.append(layer)
            mask = ~mask

        bijector = distrax.Inverse(distrax.Chain(layers))
        transformed = distrax.Transformed(base, bijector)

        if y is not None:
            return transformed, transformed.log_prob(y)
        else:
            return transformed

model = MyModel(
    hidden_dims = [64, 32],
    num_flows = 3,
    num_bins = 8,
    event_shape = (6,),
    conditioner = {'hidden_dims': [64, 32]}
)

variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)), y=jnp.empty((6,)))

dist = model.apply(variables, jnp.ones((28, 28, 1)))

dist.event_shape

This raises the following error:

JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)

Thankfully, evaluating log probs, e.g. dist.log_prob(jnp.zeros((6,))), runs without any error.
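
To summarise the behaviour:

dist = model.apply(variables, jnp.ones((28, 28, 1)))
dist.log_prob(jnp.zeros((6,)))  # works
dist.event_shape                # raises JaxTransformError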

Any idea why this is happening? Am I doing something wrong when constructing the model?
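
FWIW, the conditioner does seem fine in isolation; a quick standalone shape check (dummy input shapes assumed):

cond = Conditioner(event_shape=(6,), num_bijector_params=3 * 8 + 1, hidden_dims=[64, 32])
cond_params = cond.init(random.PRNGKey(0), jnp.zeros((6,)), jnp.zeros((32,)))
out = cond.apply(cond_params, jnp.zeros((6,)), jnp.zeros((32,)))
assert out.shape == (6, 25)  # == tuple(event_shape) + (num_bijector_params,)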

On that note, I've also found that if I initialize the parameters like this:

variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)))

The parameters for the conditioners are never created. This makes some sense to me, since the conditioner modules are only called when the flow is actually evaluated (e.g. inside log_prob), which never happens during init without y. To work around this, I evaluate the log prob of some dummy data when initializing the model:

variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)), y=jnp.empty((6,)))

But this feels a little hacky to me and suggests that perhaps I am doing something wrong in my model definition. Do you have a set of best practices for using Flax with Distrax (now that Haiku is deprecated)?
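
For what it's worth, one alternative I've considered is triggering parameter creation from inside __call__ itself. A minimal sketch, assuming nn.Module.is_initializing() (which I believe is available in recent Flax versions):

# inside MyModel.__call__, just before the return:
if self.is_initializing() and y is None:
    # evaluate log_prob on dummy data so the conditioner params are created
    _ = transformed.log_prob(jnp.zeros(self.event_shape))

But I'm not sure whether that is any less hacky.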

Thanks for the help!
