From d2585a318a49affbf265501ad902bd1b1242ad2f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 17 Mar 2026 19:18:51 +0000 Subject: [PATCH 1/3] Fixed transformations --- HISTORY.md | 87 ++++- docs/src/accs/existing.md | 2 + docs/src/accs/values.md | 34 +- docs/src/api.md | 25 +- docs/src/evaluation.md | 2 +- docs/src/init.md | 18 +- docs/src/migration.md | 6 +- docs/src/tilde.md | 14 +- docs/src/transforms.md | 4 +- docs/src/vnt/arraylikeblocks.md | 2 +- ext/DynamicPPLEnzymeCoreExt.jl | 2 +- ext/DynamicPPLMarginalLogDensitiesExt.jl | 4 +- ext/DynamicPPLMooncakeExt.jl | 4 +- src/DynamicPPL.jl | 25 +- src/abstract_varinfo.jl | 218 +++-------- src/accumulators.jl | 8 +- src/accumulators/bijector.jl | 67 ---- src/accumulators/linked_vec_transforms.jl | 32 ++ src/accumulators/vector_params.jl | 29 +- src/accumulators/vector_values.jl | 83 ++-- src/contexts/default.jl | 10 +- src/contexts/init.jl | 80 ++-- src/logdensityfunction.jl | 63 +-- src/onlyaccs.jl | 2 +- src/submodel.jl | 79 ++-- src/test_utils/models.jl | 45 --- src/test_utils/varinfo.jl | 10 +- src/threadsafe.jl | 8 +- src/transformed_values.jl | 359 +++++++++--------- src/varinfo.jl | 347 ++++++++--------- test/accumulators.jl | 2 +- test/bijector.jl | 42 -- test/chains.jl | 55 +-- test/compiler.jl | 20 +- test/conditionfix.jl | 7 +- test/contexts/init.jl | 59 +-- test/ext/DynamicPPLMarginalLogDensitiesExt.jl | 28 +- test/linking.jl | 17 +- test/logdensityfunction.jl | 4 +- test/model.jl | 69 ++-- test/runtests.jl | 1 - test/submodels.jl | 94 +++-- test/varinfo.jl | 105 +++-- 43 files changed, 984 insertions(+), 1188 deletions(-) delete mode 100644 src/accumulators/bijector.jl create mode 100644 src/accumulators/linked_vec_transforms.jl delete mode 100644 test/bijector.jl diff --git a/HISTORY.md b/HISTORY.md index 130b4facb..94244092d 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,91 @@ # 0.41 -Removed the `varinfo` keyword argument from `DynamicPPL.TestUtils.AD.run_ad` and replaced the `varinfo` 
field in `ADResult` with `ldf::LogDensityFunction`. +## Breaking changes + +### Unification of transformed values + +Previously, there were separate types `UntransformedValue`, `VectorValue`, and `LinkedVectorValue`, which were all subtypes of `AbstractTransformedValue`. +The abstract type has been removed, and all of these have been unified in a single `TransformedValue` struct, which wraps the (maybe transformed) value, plus an `AbstractTransform` that describes the inverse transformation (to get back to the raw value). + +Concretely, + + - `UntransformedValue(val)` is now `TransformedValue(val, NoTransform())` + - `VectorValue(vec, tfm)` is now `TransformedValue(vec, Unlink())` + - `LinkedVectorValue(vec, tfm)` is now `TransformedValue(vec, DynamicLink())` + +**Note that this means for `VectorValue` and `LinkedVectorValue`, the transform is no longer stored on the value itself.** +This means that given one of these values, you *cannot* access the raw value without running the model. + +The reason why this is done is that the transform may in principle change between model executions. +This can happen if the prior distribution of a variable depends on the value of another variable. +Previously, in DynamicPPL, we *always* made sure to recompute the transform during model evaluation; however, this was not enforced by the data structure. +The current implementation makes it impossible to accidentally use an outdated transform, and is therefore more robust. + +### Addition of `FixedTransform` + +The above unification allows us to introduce a new transform subtype, `FixedTransform{F}`, which wraps a known function `F` that is assumed to always be static, allowing the transform to be cached and reused across model executions. +**This should only be used when it is known ahead of time that the transform will never change between model executions.** +It is the user's responsibility to ensure that this is the case. 
+Using `FixedTransform` when the transform does change between model executions can lead to incorrect results. + +For many simple distributions, this in fact saves absolutely no time, because deriving the transform from the distribution takes almost negligible time (~ 1 ns!). +However, there are some edge cases for which this is not the case: for example, `product_distribution([Beta(2, 2), Normal()])` is quite slow (~ 3 µs). +In such cases, using `FixedTransform` can lead to substantial performance improvements. + +To use `FixedTransform` with `LogDensityFunction`, you need to: + + 1. Create a `VarNamedTuple` mapping `VarName`s to `FixedTransform`s for the variables in your model. + This can be done using `get_linked_vec_transforms(model)`, which automatically calculates `Bijectors.VectorBijectors.from_linked_vec(dist)` for each variable in the model. + TODO: Control whether it's linked or not???? + + 2. Wrap the `VarNamedTuple` inside `WithTransforms(vnt, UnlinkAll())`. + `WithTransforms` is a subtype of `AbstractTransformStrategy`, much like `LinkAll()`. + However, `WithTransforms` specifies that *these exact transforms are to be used*, whereas `LinkAll` says 'derive the transforms again at model runtime'. + 3. Construct a `LogDensityFunction(model, getlogjoint_internal, WithTransforms(...); adtype=adtype)`. + +### Removal of `getindex(vi::VarInfo, vn::VarName)` + +The main role of `VarInfo` was to store vectorised transformed values. +Previously, these were stored as `VectorValue`s or `LinkedVectorValue`s: these used to carry the inverse transform with them, which allowed you to access the raw value via `vi[vn]`, or equivalently `getindex(vi, vn)`. + +The problem with this is that (as described above) the correct transform may depend on the values of the variables themselves. +That means that if we update the vectorised value without changing the transform, we could end up with an inconsistent state and incorrect results. 
+In particular, this is *exactly* what the function `unflatten!!` does: it updates the vectorised values but does not touch the transform. + +In the current version, we have removed this method to prevent the possibility of obtaining incorrect results. +(Our hand is also forced by the fact that the new `TransformedValue`s do not store the actual transform with them.) + +*In place of using `VarInfo`, we strongly recommend that you migrate to using `OnlyAccsVarInfo`.* +In particular, to access raw (untransformed) values, you should use an `OnlyAccsVarInfo` with a `RawValueAccumulator`. +There is [a migration guide available in the DynamicPPL documentation](https://turinglang.org/DynamicPPL.jl/stable/migration/) and we are very happy to add more examples to this if you run into something that is not covered. + +### LinkedVecTransformAccumulator + +TODO, this part is still being worked on. + + - `BijectorAccumulator` → `LinkedVecTransformAccumulator` + - `get_linked_vec_transforms(::VarInfo)` + - `get_linked_vec_transforms(::Model)` + +## Miscellaneous breaking changes + + - Removed the `varinfo` keyword argument from `DynamicPPL.TestUtils.AD.run_ad`, and replaced the `varinfo` field in the returned `ADResult` with `ldf::LogDensityFunction`. + +## Internal changes + +The following functions were not exported; we document changes in them for completeness. + + - Given the above changes, the old framework of `AbstractTransformation`, `StaticTransformation`, and `DynamicTransformation` is no longer needed, and has been removed. + The (rarely used) methods of `link`, `link!!`, `invlink`, and `invlink!!` that took an `AbstractTransformation` as the first argument have been removed. + The same is true of the functions `default_transformation` and `transformation`. + + - `RangeAndLinked` has been expanded to `RangeAndTransform`: instead of just carrying a Boolean indicating whether the transform is `DynamicLink` or `Unlink`, it now stores the full transform. 
+ This is done in order to accommodate `FixedTransform`. + - Consequently, `get_ranges_and_linked` has been renamed to `get_rangeandtransforms`. + Its function is still to return a `VarNamedTuple` of `RangeAndTransform`s. + - `update_link_status!!` has been renamed to `update_transform_status!!`. + - `get_transform_strategy` has been renamed to `infer_transform_strategy`. + - `from_internal_transform` and `from_linked_internal_transform` have been removed, since the new `TransformedValue`s do not store the transform with them. Removed `getargnames`, `getmissings`, and `Base.nameof(::Model)` from the public API (export and documentation) as they are considered internal implementation details. diff --git a/docs/src/accs/existing.md b/docs/src/accs/existing.md index 8e6eeab99..f2e7f51ea 100644 --- a/docs/src/accs/existing.md +++ b/docs/src/accs/existing.md @@ -42,4 +42,6 @@ get_vector_params ```@docs PriorDistributionAccumulator get_priors +LinkedVecTransformAccumulator +get_linked_vec_transforms ``` diff --git a/docs/src/accs/values.md b/docs/src/accs/values.md index c2b1ec6db..07a214226 100644 --- a/docs/src/accs/values.md +++ b/docs/src/accs/values.md @@ -29,8 +29,10 @@ struct VarInfo{Tfm<:AbstractTransformStrategy,V<:VarNamedTuple,A<:AccumulatorTup end ``` -The `values` field stores either [`LinkedVectorValue`](@ref)s or [`VectorValue`](@ref)s. -The `transform_strategy` field stores an `AbstractTransformStrategy` which is (as far as possible) consistent with the type of values stored in `values`. +The `values` field stores `DynamicPPL.TransformedValue`s, but it is mandatory that these transformed values are vectorised. +That is, it is permissible to store (for example) `TransformedValue([1.0], Unlink())`, but not `TransformedValue(1.0, NoTransform())`. + +Furthermore, the `transform_strategy` field stores an `AbstractTransformStrategy` which is (as far as possible) consistent with the type of values stored in `values`. 
Here is an example: @@ -46,7 +48,7 @@ vi = VarInfo(dirichlet_model) vi ``` -In `VarInfo`, it is mandatory to store `LinkedVectorValue`s or `VectorValue`s as `ArrayLikeBlock`s (see the [Array-like blocks](@ref array-like-blocks) documentation for information on this). +In `vi.values`, it is mandatory to store `TransformedValue`s as `ArrayLikeBlock`s (see the [Array-like blocks](@ref array-like-blocks) documentation for information on this). The reason is because, if the value is linked, it may have a different size than the number of indices in the `VarName`. This means that when retrieving the keys, we obtain each block as a single key: @@ -60,21 +62,21 @@ In a `VarInfo`, the `accs` field is responsible for the accumulation step, just However, `values` serves three purposes in one: - - it is sometimes used for initialisation (when the model's leaf context is `DefaultContext`, the `AbstractTransformedValue` to be used in the transformation step is read from it) + - it is sometimes used for initialisation (when the model's leaf context is `DefaultContext`, the `TransformedValue` to be used in the transformation step is read from it) - it also determines whether the log-Jacobian term should be included or not (if the value is a `LinkedVectorValue`, the log-Jacobian is included) - - it is sometimes also used for accumulation (when evaluating a model with a VarInfo, we will potentially store a new `AbstractTransformedValue` in it!). + - it is sometimes also used for accumulation (when evaluating a model with a VarInfo, we will potentially store a new `TransformedValue` in it!). The path to removing `VarInfo` is essentially to separate these three roles: 1. The initialisation role of `varinfo.values` can be taken over by an initialisation strategy that wraps it. - Recall that the only role of an initialisation strategy is to provide an `AbstractTransformedValue` via [`DynamicPPL.init`](@ref). 
+ Recall that the only role of an initialisation strategy is to provide a `TransformedValue` via [`DynamicPPL.init`](@ref). This can be trivially done by indexing into the `VarNamedTuple` stored in the strategy. 2. Whether the log-Jacobian term should be included or not can be determined by a transform strategy. Much like how we can have an initialisation strategy that takes values from a `VarInfo`, we can also have a transform strategy that is defined by the existing status of a `VarInfo`. This is implemented in the `DynamicPPL.get_link_strategy(::AbstractVarInfo)` function. 3. The accumulation role of `varinfo.values` can be taken over by a new accumulator, which we call `VectorValueAccumulator`. - This name is chosen because it does not store generic `AbstractTransformedValue`s, but only two subtypes of it, `LinkedVectorValue` and `VectorValue`. + This name is chosen because it does not store generic `TransformedValue`s, but only those that wrap vectorised values. `VectorValueAccumulator` is implemented inside `src/accs/vector_value.jl`. !!! note @@ -86,11 +88,11 @@ The path to removing `VarInfo` is essentially to separate these three roles: ## `RawValueAccumulator` -Earlier we said that `VectorValueAccumulator` stores only two subtypes of `AbstractTransformedValue`: `LinkedVectorValue` and `VectorValue`. -One might therefore ask about the third subtype, namely, `UntransformedValue`. +Earlier we said that `VectorValueAccumulator` stores only values that have been vectorised. +One might therefore ask about unvectorised values — and in particular, values that have *not* been transformed at all, i.e., `TransformedValue(val, NoTransform())`. -It turns out that it is very often useful to store [`UntransformedValue`](@ref)s. -Additionally, since `UntransformedValue`s must always correspond exactly to the indices they are assigned to, we can unwrap them and do not need to store them as array-like blocks! 
+It turns out that it is very often useful to store such untransformed values. +Additionally, since the values must always correspond exactly to the indices they are assigned to, we can unwrap them and do not need to store them as array-like blocks! This is the role of `RawValueAccumulator`. @@ -100,7 +102,7 @@ _, oavi = DynamicPPL.init!!(dirichlet_model, oavi, InitFromPrior(), UnlinkAll()) raw_vals = get_raw_values(oavi) ``` -Note that when we unwrap `UntransformedValue`s, we also lose the block structure that was present in the model. +Note that when we unwrap `TransformedValue`s, we also lose the block structure that was present in the model. That means that in `RawValueAccumulator`, there is no longer any notion that `x[1:3]` was set together, so the keys correspond to the individual indices. ```@example 1 @@ -116,18 +118,18 @@ This is why indices of keys like `x[1:3] ~ dist` end up being split up in chains ## Why do we still need to store `TransformedValue`s? -Given that `RawValueAccumulator` exists, one may wonder why we still need to store the other `AbstractTransformedValue`s at all, i.e. what the purpose of `VectorValueAccumulator` is. +Given that `RawValueAccumulator` exists, one may wonder why we still need to store the other `TransformedValue`s at all, i.e. what the purpose of `VectorValueAccumulator` is. Currently, the only remaining reason for transformed values is the fact that we may sometimes need to perform [`DynamicPPL.unflatten!!`](@ref) on a `VarInfo`, to insert new values into it from a vector. ```@example 1 vi = VarInfo(dirichlet_model) -vi[@varname(x[1:3])] +DynamicPPL.getindex_internal(vi, @varname(x[1:3])) ``` ```@example 1 vi = DynamicPPL.unflatten!!(vi, [0.2, 0.5, 0.3]) -vi[@varname(x[1:3])] +DynamicPPL.getindex_internal(vi, @varname(x[1:3])) ``` If we do not store the vectorised form of the values, we will not know how many values to read from the input vector for each key. 
@@ -135,5 +137,5 @@ If we do not store the vectorised form of the values, we will not know how many Removing upstream usage of `unflatten!!` would allow us to completely get rid of `TransformedValueAccumulator` and only ever use `RawValueAccumulator`. See [this DynamicPPL issue](https://github.com/TuringLang/DynamicPPL.jl/issues/836) for more information. -One possibility for removing `unflatten!!` is to turn it into a function that, instead of generating a new VarInfo, instead generates a tuple of new initialisation and link strategies which returns `LinkedVectorValue`s or `VectorValue`s containing views into the input vector. +One possibility for removing `unflatten!!` is to turn it into a function that, rather than generating a new VarInfo, generates a tuple of new initialisation and transform strategies similar to `InitFromVector`. This would be conceptually very similar to how `LogDensityFunction` currently works. diff --git a/docs/src/api.md b/docs/src/api.md index 289febe0d..fabdee5a0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -440,19 +440,12 @@ DynamicPPL.setindex_internal!! #### Transformations -```@docs -DynamicPPL.AbstractTransformation -DynamicPPL.NoTransformation -DynamicPPL.DynamicTransformation -DynamicPPL.StaticTransformation -``` - ```@docs DynamicPPL.link DynamicPPL.invlink DynamicPPL.link!! DynamicPPL.invlink!! -DynamicPPL.update_link_status!! +DynamicPPL.update_transform_status!! 
``` ```@docs @@ -461,21 +454,19 @@ DynamicPPL.LinkAll DynamicPPL.UnlinkAll DynamicPPL.LinkSome DynamicPPL.UnlinkSome +DynamicPPL.WithTransforms ``` ```@docs DynamicPPL.AbstractTransform DynamicPPL.DynamicLink DynamicPPL.Unlink +DynamicPPL.FixedTransform +DynamicPPL.NoTransform DynamicPPL.target_transform DynamicPPL.apply_transform_strategy ``` -```@docs -DynamicPPL.transformation -DynamicPPL.default_transformation -``` - #### Utils ```@docs @@ -561,14 +552,10 @@ init get_param_eltype ``` -The function [`DynamicPPL.init`](@ref) should return an `AbstractTransformedValue`. -There are three subtypes currently available: +The function [`DynamicPPL.init`](@ref) should return a `TransformedValue`. ```@docs -DynamicPPL.AbstractTransformedValue -DynamicPPL.VectorValue -DynamicPPL.LinkedVectorValue -DynamicPPL.UntransformedValue +DynamicPPL.TransformedValue ``` The interface for working with transformed values consists of: diff --git a/docs/src/evaluation.md b/docs/src/evaluation.md index 8a5910b30..10e1333ec 100644 --- a/docs/src/evaluation.md +++ b/docs/src/evaluation.md @@ -249,7 +249,7 @@ getlogprior(accs) - getlogjac(accs) You might ask: given that we specified parameters in untransformed space, how do we then retrieve the parameters in transformed space? The answer to this is to use an accumulator (no surprises there!) that collects the transformed values. -Specifically, a `VectorValueAccumulator` collects vectorised forms of the parameters, which may either be [`VectorValue`](@ref)s or [`LinkedVectorValue`](@ref)s. +Specifically, a `VectorValueAccumulator` collects vectorised forms of the parameters: that is, `TransformedValue{V,T}` where `V<:AbstractVector`. ```@example 1 accs = OnlyAccsVarInfo(VectorValueAccumulator()) diff --git a/docs/src/init.md b/docs/src/init.md index 306cdbe5d..ed8cd68f7 100644 --- a/docs/src/init.md +++ b/docs/src/init.md @@ -29,12 +29,12 @@ The subsequent sections will demonstrate how this can be done. 
## The required interface -Each initialisation strategy must subtype `AbstractInitStrategy`, and implement `DynamicPPL.init(rng, vn, dist, strategy)`, which must return an `AbstractTransformedValue`. +Each initialisation strategy must subtype `AbstractInitStrategy`, and implement `DynamicPPL.init(rng, vn, dist, strategy)`, which must return a `TransformedValue`. ```@docs; canonical=false AbstractInitStrategy init -DynamicPPL.AbstractTransformedValue +DynamicPPL.TransformedValue ``` ## An example @@ -64,7 +64,7 @@ function DynamicPPL.init(rng, vn::VarName, ::Distribution, strategy::InitRandomW new_x = rand(rng, Normal(strategy.x_prev, strategy.step_size)) # Insert some printing to see when this is called. @info "init() is returning: $new_x" - return DynamicPPL.UntransformedValue(new_x) + return DynamicPPL.TransformedValue(new_x, DynamicPPL.NoTransform()) end ``` @@ -103,19 +103,19 @@ In this case, we have defined an initialisation strategy that is random (and thu However, initialisation strategies can also be fully deterministic, in which case the `rng` argument is not needed. For example, [`DynamicPPL.InitFromParams`](@ref) reads from a set of parameters which are known ahead of time. -## The returned `AbstractTransformedValue` +## The returned `TransformedValue` -As mentioned above, the `init` function must return an `AbstractTransformedValue`. -The subtype of `AbstractTransformedValue` used does not affect the result of the model evaluation, but it may have performance implications. +As mentioned above, the `init` function must return a `TransformedValue`. +The transform stored inside this does not affect the result of the model evaluation, but it may have performance implications. 
**In particular, the returned subtype does not determine whether the log-Jacobian term is accumulated or not: that is determined by a separate [_transform strategy_](@ref transform-strategies).** -What this means is that initialisation strategies should always choose the laziest possible subtype of `AbstractTransformedValue`. +What this means is that initialisation strategies should always choose the laziest possible version of `TransformedValue`, electing to do as few transformations as possible inside `init`. -For example, in the above example, we used `UntransformedValue`, which is the simplest possible choice. +For example, in the above example, we simply wrapped the untransformed value in `TransformedValue(..., NoTransform())`, which is the simplest possible choice. If a linked value is required by a later step inside `tilde_assume!!` (either the transformation or accumulation steps), it is the responsibility of that step to perform the linking. Conversely, [`DynamicPPL.InitFromUniform`](@ref) samples inside linked space. -Instead of performing the inverse link transform and returning an `UntransformedValue`, it directly returns a `LinkedVectorValue`: this means that if a linked value is required by a later step, it is not necessary to link it again. +Instead of performing the inverse link transform eagerly, it directly returns a `TransformedValue(val, DynamicLink())`, where `val` is *already* the linked vector: this means that if a linked value is required by a later step, it is not necessary to link it again. Even if no linked value is required, this lazy approach does not hurt performance, as it just defers the inverse linking to the later step. In both cases, only one linking operation is performed (at most). 
diff --git a/docs/src/migration.md b/docs/src/migration.md index 6b1701d18..5b0f97dd5 100644 --- a/docs/src/migration.md +++ b/docs/src/migration.md @@ -36,7 +36,11 @@ Old: ```@example 1 vi = VarInfo(Xoshiro(468), model) -vi[@varname(x)], vi[@varname(y)] +# This no longer works, but you may have used it. +# vi[@varname(x)], vi[@varname(y)] + +# This still works +DynamicPPL.getindex_internal(vi, @varname(x)) ``` New: diff --git a/docs/src/tilde.md b/docs/src/tilde.md index cd2ac46e7..458b12d0e 100644 --- a/docs/src/tilde.md +++ b/docs/src/tilde.md @@ -47,7 +47,7 @@ Every tilde-statement `vn ~ dist` (where `vn` represents a random variable) is t As described on the [Model evaluation page](./evaluation.md), there are three stages to every tilde-statement: - 1. Initialisation: get an `AbstractTransformedValue` from the initialisation strategy. + 1. Initialisation: get a `TransformedValue` from the initialisation strategy. 2. Transformation: figure out the untransformed (raw) value and the transformed value (where necessary); compute the relevant log-Jacobian. 3. Accumulation: pass all the relevant information to the accumulators, which individually decide what to do with it. @@ -98,20 +98,20 @@ For example, if `ctx.strategy` is `InitFromPrior()`, then `init()` samples a val For `DefaultContext`, initialisation is handled by looking for the value stored inside `vi`. As discussed in the [Initialisation strategies](./init.md) page, this step, in general, does not return just the raw value (like `rand(dist)`). -It returns an [`DynamicPPL.AbstractTransformedValue`](@ref), which represents a value that _may_ have been transformed. -In the case of `InitFromPrior()`, the value is of course not transformed; we return a [`DynamicPPL.UntransformedValue`](@ref) wrapping the sampled value. +It returns a [`DynamicPPL.TransformedValue`](@ref), which represents a value that _may_ have been transformed. 
+In the case of `InitFromPrior()`, the value is of course not transformed; we return a [`DynamicPPL.TransformedValue(..., NoTransform())`](@ref TransformedValue) wrapping the sampled value. However, consider the case where we are using parameters stored inside a `VarInfo`: the value may have been stored either as a vectorised form, or as a linked vectorised form. -In this case, `init()` will return either a [`DynamicPPL.VectorValue`](@ref) or a [`DynamicPPL.LinkedVectorValue`](@ref). +In this case, `init()` will return a `TransformedValue` wrapping the vectorised value, whose transform is either `Unlink()` (i.e., vectorised but not linked), or `DynamicLink()` (i.e., vectorised and linked). The reason why we return this wrapped value is because we want to avoid having to perform transformations multiple times. Each step is responsible for only performing the transformations it needs to. At this stage, there has not yet been any need for the raw value, so we do not perform any transformations yet. -Thus, the `AbstractTransformedValue` is passed straight through and is used by the transformation step. +Thus, the `TransformedValue` is passed straight through and is used by the transformation step. !!! note "The return type of init() doesn't matter" - The exact subtype of `AbstractTransformedValue` returned by `init()` has no impact on whether the value is considered to be transformed or not. + The exact `TransformedValue` returned by `init()` has no impact on whether the value is considered to be transformed or not. That is determined solely by the transform strategy (see below). This separation allows us to perform the minimum amount of transformations necessary inside `init()`. If we were to eagerly transform the value inside `init()`, we could easily end up performing the same transformation multiple times across the different steps. @@ -184,7 +184,7 @@ vi = DynamicPPL.accumulate_assume!!(vi, x, tval, logjac, vn, dist, template) !!! 
note The first line, `setindex_with_dist!!`, is only necessary when using a full `VarInfo`. - It essentially stores the value `tval` inside the `VarInfo`, but makes sure to store a vectorised form (i.e., if `tval` is an `UntransformedValue`, it will be converted to a `VectorValue` before being stored). + It essentially stores the value `tval` inside the `VarInfo`, but makes sure to store a vectorised form (i.e., if `tval` is not vectorised, it will be). This is entirely equivalent to using a `VectorValueAccumulator` to store the values; it's just that when using a full `VarInfo` that accumulator is 'built-in' as `vi.values`. Since conceptually this is the same as an accumulator, we will not discuss it further here. diff --git a/docs/src/transforms.md b/docs/src/transforms.md index 08c7d657c..b805b2019 100644 --- a/docs/src/transforms.md +++ b/docs/src/transforms.md @@ -78,7 +78,7 @@ vi_linked.accs The *transform strategy* is what determines whether the log-Jacobian is applied or not when evaluating the log-probability. One could think of the transform strategy as being a *re-interpretation* of the value provided by the initialisation strategy. - This frees up the initialisation strategy to return whatever kind of `AbstractTransformedValue` is most convenient for it. + This frees up the initialisation strategy to return whatever kind of `TransformedValue` is most convenient for it. ## Making your own transform strategy @@ -197,6 +197,6 @@ Furthermore, this allows us to remove the inverse transform step inside `InitFro Essentially, having a separate transform strategy allows us to: - 1. Free up the initialisation strategy to return whatever kind of `AbstractTransformedValue` is most convenient for it, without worrying about whether it needs to perform some transform. + 1. Free up the initialisation strategy to return whatever kind of `TransformedValue` is most convenient for it, without worrying about whether it needs to perform some transform. 2. 
Consolidate all the actual transformation in a single function (`DynamicPPL.apply_transform_strategy`), which allows us to ensure that each tilde-statement involves at most *one* transformation. diff --git a/docs/src/vnt/arraylikeblocks.md b/docs/src/vnt/arraylikeblocks.md index 95c51e0f3..d5b80a6f7 100644 --- a/docs/src/vnt/arraylikeblocks.md +++ b/docs/src/vnt/arraylikeblocks.md @@ -95,7 +95,7 @@ Some examples follow. ### VarInfo -In `VarInfo`, we need to be able to store either linked or unlinked values (in general, `AbstractTransformedValue`s). +In `VarInfo`, we need to be able to store either linked or unlinked values (in general, `TransformedValue`s). These are always vectorised values, and the linked and unlinked vectors may have different sizes (this is indeed the case for Dirichlet distributions). This means that we have to collectively assign multiple indices in the `VarNamedTuple` to a single vector, which may or may not have the same size as the indices. diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index cdeb6c8e6..52a9af728 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,7 +8,7 @@ using EnzymeCore @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing @inline EnzymeCore.EnzymeRules.inactive( - ::typeof(DynamicPPL._get_range_and_linked), args... + ::typeof(DynamicPPL._get_range_and_transform), args... ) = nothing # Enzyme errors on Gibbs sampling without this one. 
@inline EnzymeCore.EnzymeRules.inactive( diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index 348152d30..75d2d25b8 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -1,6 +1,6 @@ module DynamicPPLMarginalLogDensitiesExt -using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndLinked +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName, RangeAndTransform using MarginalLogDensities: MarginalLogDensities # A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by @@ -107,7 +107,7 @@ function DynamicPPL.marginalize( varindices = mapreduce(vcat, marginalized_varnames) do vn # The type assertion helps in cases where the model is type unstable and thus # `varname_ranges` may have an abstract element type. - (ldf._varname_ranges[vn]::RangeAndLinked).range + (ldf._varname_ranges[vn]::RangeAndTransform).range end mld = MarginalLogDensities.MarginalLogDensity( LogDensityFunctionWrapper(ldf, varinfo), diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index 9760d9f4b..b876df575 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -4,10 +4,10 @@ using DynamicPPL: DynamicPPL, is_transformed using Mooncake: Mooncake # These are purely optimisations (although quite significant ones sometimes, especially for -# _get_range_and_linked). +# _get_range_and_transform). 
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ - typeof(DynamicPPL._get_range_and_linked),Vararg + typeof(DynamicPPL._get_range_and_transform),Vararg } Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{ typeof(Base.haskey),DynamicPPL.VarInfo,DynamicPPL.VarName diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 379437303..8bc540d9d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -102,7 +102,8 @@ export AbstractVarInfo, # Accumulators - miscellany PriorDistributionAccumulator, get_priors, - BijectorAccumulator, + LinkedVecTransformAccumulator, + get_linked_vec_transforms, # Working with internal values as vectors unflatten!!, internal_values_as_vector, @@ -150,28 +151,30 @@ export AbstractVarInfo, get_param_eltype, init, # Transformed values - VectorValue, - LinkedVectorValue, - UntransformedValue, + TransformedValue, get_transform, get_internal_value, set_internal_value, - # Linking - link, - link!!, - invlink, - invlink!!, - update_link_status!!, + # Transform strategies + update_transform_status!!, AbstractTransformStrategy, LinkAll, UnlinkAll, LinkSome, UnlinkSome, + WithTransforms, target_transform, apply_transform_strategy, AbstractTransform, DynamicLink, Unlink, + FixedTransform, + NoTransform, + # Linking + link, + link!!, + invlink, + invlink!!, # Pseudo distributions NamedDist, NoDist, @@ -254,7 +257,7 @@ include("accumulators/vnt.jl") include("accumulators/vector_values.jl") include("accumulators/priors.jl") include("accumulators/raw_values.jl") -include("accumulators/bijector.jl") +include("accumulators/linked_vec_transforms.jl") include("accumulators/pointwise_logdensities.jl") include("abstract_varinfo.jl") include("threadsafe.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ba30404ec..32e68110e 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -1,75 +1,3 @@ -# Transformation related. 
-""" - $(TYPEDEF) - -Represents a transformation to be used in `link!!` and `invlink!!`, amongst others. - -A concrete implementation of this should implement the following methods: -- [`link!!`](@ref): transforms the [`AbstractVarInfo`](@ref) to the unconstrained space. -- [`invlink!!`](@ref): transforms the [`AbstractVarInfo`](@ref) to the constrained space. - -See also: [`link!!`](@ref), [`invlink!!`](@ref) -""" -abstract type AbstractTransformation end - -""" - $(TYPEDEF) - -Transformation which applies the identity function. -""" -struct NoTransformation <: AbstractTransformation end - -""" - $(TYPEDEF) - -Transformation which transforms the variables on a per-need-basis -in the execution of a given `Model`. - -This is in constrast to `StaticTransformation` which transforms all variables -_before_ the execution of a given `Model`. - -Different VarInfo types should implement their own methods for `link!!` and `invlink!!` for -`DynamicTransformation`. - -See also: [`StaticTransformation`](@ref). -""" -struct DynamicTransformation <: AbstractTransformation end - -""" - $(TYPEDEF) - -Transformation which represents a fixed bijector to be applied to the variables, as opposed -to deriving the bijector again at runtime. - -See also: [`DynamicTransformation`](@ref). - -# Fields -$(TYPEDFIELDS) -""" -struct StaticTransformation{F} <: AbstractTransformation - "The function, assumed to implement the `Bijectors` interface, to be applied to the variables" - bijector::F -end - -function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform) - return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs)) -end - -""" - default_transformation(model::Model[, vi::AbstractVarInfo]) - -Return the `AbstractTransformation` currently related to `model` and, potentially, `vi`. 
-""" -default_transformation(model::Model, ::AbstractVarInfo) = default_transformation(model) -default_transformation(::Model) = DynamicTransformation() - -""" - transformation(vi::AbstractVarInfo) - -Return the `AbstractTransformation` related to `vi`. -""" -function transformation end - # Accumulation of log-probabilities. """ getlogjoint(vi::AbstractVarInfo) @@ -518,9 +446,9 @@ function getindex_internal end """ get_transformed_value(vi::AbstractVarInfo, vn::VarName) -Return the actual `AbstractTransformedValue` stored in `vi` for variable `vn`. +Return the actual `TransformedValue` stored in `vi` for variable `vn`. -This differs from `getindex_internal`, which obtains the `AbstractTransformedValue` and then +This differs from `getindex_internal`, which obtains the `TransformedValue` and then directly returns `get_internal_value(tval)`; and `getindex` which returns `get_transform(tval)(get_internal_value(tval))`. """ @@ -574,15 +502,15 @@ Subset a `varinfo` to only contain the variables `vns`. The ordering of variables in the return value will be the same as in `varinfo`. # Examples -```jldoctest varinfo-subset; setup = :(using Distributions, DynamicPPL) +julia> using DynamicPPL, Distributions + julia> @model function demo() s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) x = Vector{Float64}(undef, 2) x[1] ~ Normal(m, sqrt(s)) x[2] ~ Normal(m, sqrt(s)) - end -demo (generic function with 2 methods) + end; julia> model = demo(); @@ -598,7 +526,7 @@ VarNamedTuple julia> vi = last(init!!(model, VarInfo(), InitFromParams(params), UnlinkAll())); -julia> keys(vi) +julia> vi.values 4-element Vector{VarName}: s m @@ -612,8 +540,9 @@ julia> keys(vi_subset1) 1-element Vector{VarName}: m -julia> vi_subset1[@varname(m)] -2.0 +julia> DynamicPPL.getindex_internal(vi_subset1, @varname(m)) +1-element Vector{Float64}: + 2.0 julia> # Extract one with both `s` and `x[2]`. 
vi_subset2 = subset(vi, [@varname(s), @varname(x[2])]); @@ -623,9 +552,12 @@ julia> keys(vi_subset2) s x[2] -julia> vi_subset2[[@varname(s), @varname(x[2])]] -2-element Vector{Float64}: +julia> DynamicPPL.getindex_internal(vi_subset2, @varname(s)) +1-element Vector{Float64}: 1.0 + +julia> DynamicPPL.getindex_internal(vi_subset2, @varname(x[2])) +1-element Vector{Float64}: 4.0 ``` @@ -656,23 +588,7 @@ julia> keys(vi_merged) m x[1] x[2] - -julia> vi_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] -4-element Vector{Float64}: - 1.0 - 2.0 - 3.0 - 4.0 ``` - -# Notes - -## Type-stability - -!!! warning - This function is only type-stable when `vns` contains only varnames - with the same symbol. For example, `[@varname(m[1]), @varname(m[2])]` will - be type-stable, but `[@varname(m[1]), @varname(x)]` will not be. """ function subset end @@ -696,7 +612,6 @@ function Base.merge( return merge(Base.merge(varinfo1, varinfo2), varinfo3, varinfo_others...) end -# Transformations """ is_transformed(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) @@ -731,97 +646,59 @@ If `vn` is not specified, then `is_transformed(vi)` evaluates to `true` for all """ function set_transformed!! end -# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback -# method for the case when no `vns` is provided, that would get all the keys from the -# `VarInfo`. Hence each subtype of `AbstractVarInfo` needs to implement separately the case -# where `vns` is provided and the one where it is not. This is because having separate -# implementations is typically much more performant, and because not all AbstractVarInfo -# types support partial linking. - -""" - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) - -Transform variables in `vi` to their linked space, mutating `vi` if possible. 
- -Either transform all variables, or only ones specified in `vns`. - -Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. - -See also: [`default_transformation`](@ref), [`invlink!!`](@ref). -""" -function link!!(vi::AbstractVarInfo, model::Model) - return link!!(default_transformation(model, vi), vi, model) -end -function link!!(vi::AbstractVarInfo, vns, model::Model) - return link!!(default_transformation(model, vi), vi, vns, model) -end - """ - link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - link([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) - -Transform variables in `vi` to their linked space without mutating `vi`. + link(vi::AbstractVarInfo, model::Model) + link(vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) -Either transform all variables, or only ones specified in `vns`. +Transform all variables in `vi` to their linked space without mutating `vi` (i.e., replace +all the `TransformedValue`s in `vi.values` with the corresponding +`TransformedValue(linked_value, DynamicLink())`). If `vns` is provided, then only transform +the variables in `vns`. -Use the transformation `t`, or `default_transformation(model, vi)` if one is not provided. - -See also: [`default_transformation`](@ref), [`invlink`](@ref). +See also: [`invlink`](@ref). 
""" function link(vi::AbstractVarInfo, model::Model) - return link(default_transformation(model, vi), vi, model) + return link!!(deepcopy(vi), model) end function link(vi::AbstractVarInfo, vns, model::Model) - return link(default_transformation(model, vi), vi, vns, model) -end -function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) + return link!!(deepcopy(vi), vns, model) end """ - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink!!([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) + link!!(vi::AbstractVarInfo, model::Model) + link!!(vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) -Transform variables in `vi` to their constrained space, mutating `vi` if possible. - -Either transform all variables, or only ones specified in `vns`. - -Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is -not provided. - -See also: [`default_transformation`](@ref), [`link!!`](@ref). +Like `link`, but might mutate `vi` in-place if it is possible to do so. """ -function invlink!!(vi::AbstractVarInfo, model::Model) - return invlink!!(default_transformation(model, vi), vi, model) -end -function invlink!!(vi::AbstractVarInfo, vns, model::Model) - return invlink!!(default_transformation(model, vi), vi, vns, model) -end +function link!! end """ - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) - invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) - -Transform variables in `vi` to their constrained space without mutating `vi`. - -Either transform all variables, or only ones specified in `vns`. + invlink(vi::AbstractVarInfo, model::Model) + invlink(vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) -Use the (inverse of) transformation `t`, or `default_transformation(model, vi)` if one is -not provided. 
+Transform all variables in `vi` to the original space without mutating `vi` (i.e., replace +all the `TransformedValue`s in `vi.values` with the corresponding +`TransformedValue(unlinked_value, Unlink())`). Note that the unlinked values are still +vectorised (that is a requirement of `vi.values`). If `vns` is provided, then only transform +the variables in `vns`. -See also: [`default_transformation`](@ref), [`link`](@ref). +See also: [`link`](@ref). """ function invlink(vi::AbstractVarInfo, model::Model) - return invlink(default_transformation(model, vi), vi, model) + return invlink!!(deepcopy(vi), model) end function invlink(vi::AbstractVarInfo, vns, model::Model) - return invlink(default_transformation(model, vi), vi, vns, model) -end -function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) + return invlink!!(deepcopy(vi), vns, model) end +""" + invlink!!(vi::AbstractVarInfo, model::Model) + invlink!!(vi::AbstractVarInfo, vns::NTuple{N,VarName}, model::Model) + +Like `invlink`, but might mutate `vi` in-place if it is possible to do so. +""" +function invlink!! end + """ unflatten!!(vi::AbstractVarInfo, x::AbstractVector) @@ -834,11 +711,6 @@ This is the inverse operation of [`internal_values_as_vector`](@ref). Note that this does not re-evaluate the model (indeed it cannot!) so the contents of any accumulators in the `VarInfo` will almost certainly be inconsistent with the new values. - On top of that, it does not update the *transformations* stored inside the - `LinkedVectorValue`s and `VectorValue`s. If these transformations themselves depend on - the values of the variables, this can lead to incorrect results when trying to access - untransformed values, e.g. using `getindex(vi, vn)`. 
- **Because of these issues, we strongly recommend against using this function, unless absolutely necessary.** In many cases, usage of `unflatten!!(vi, x)` can be replaced with `InitFromVector(x, ldf::LogDensityFunction)`: please see the [DynamicPPL diff --git a/src/accumulators.jl b/src/accumulators.jl index 6b26f0f1e..7f688cc9b 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -20,10 +20,10 @@ In these functions: - `val` is the new value of the random variable sampled from a distribution (always in the original unlinked space), or the value on the left-hand side of an observe statement. -- `tval` is the original `AbstractTransformedValue` that was obtained from the - initialisation strategy. This is passed through unchanged to `accumulate_assume!!` since - it can be reused for some accumulators (e.g. when storing linked values, if the linked - value was already provided, it is faster to reuse it than to re-link `val`). +- `tval` is the original `TransformedValue` that was obtained from the initialisation + strategy. This is passed through unchanged to `accumulate_assume!!` since it can be reused + for some accumulators (e.g. when storing linked values, if the linked value was already + provided, it is faster to reuse it than to re-link `val`). - `dist` is the distribution on the RHS of the tilde statement. - `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`. 
diff --git a/src/accumulators/bijector.jl b/src/accumulators/bijector.jl deleted file mode 100644 index 42655eae5..000000000 --- a/src/accumulators/bijector.jl +++ /dev/null @@ -1,67 +0,0 @@ -struct BijectorAccumulator <: AbstractAccumulator - bijectors::Vector{Any} - sizes::Vector{Int} -end - -BijectorAccumulator() = BijectorAccumulator(Bijectors.Bijector[], UnitRange{Int}[]) - -function Base.:(==)(acc1::BijectorAccumulator, acc2::BijectorAccumulator) - return (acc1.bijectors == acc2.bijectors && acc1.sizes == acc2.sizes) -end - -function Base.copy(acc::BijectorAccumulator) - return BijectorAccumulator(copy(acc.bijectors), copy(acc.sizes)) -end - -accumulator_name(::Type{<:BijectorAccumulator}) = :Bijector - -function _zero(acc::BijectorAccumulator) - return BijectorAccumulator(empty(acc.bijectors), empty(acc.sizes)) -end -reset(acc::BijectorAccumulator) = _zero(acc) -split(acc::BijectorAccumulator) = _zero(acc) -function combine(acc1::BijectorAccumulator, acc2::BijectorAccumulator) - return BijectorAccumulator( - vcat(acc1.bijectors, acc2.bijectors), vcat(acc1.sizes, acc2.sizes) - ) -end - -function accumulate_assume!!( - acc::BijectorAccumulator, val, tval, logjac, vn, right, template -) - bijector = - Bijectors.VectorBijectors.to_linked_vec(right) ∘ - Bijectors.VectorBijectors.from_vec(right) - push!(acc.bijectors, bijector) - push!(acc.sizes, Bijectors.VectorBijectors.vec_length(right)) - return acc -end - -accumulate_observe!!(acc::BijectorAccumulator, right, left, vn, template) = acc - -""" - bijector(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - -Returns a `Stacked <: Bijector` which maps from constrained to unconstrained space. - -The input to the bijector is a vector of values for the whole model, like the input to -`unflatten!!`. These are in constrained space, i.e., respecting variable constraints. -The output is a vector of unconstrained values. 
- -`init_strategy` is passed to `DynamicPPL.init!!` to determine what values the model is -evaluated with. This may affect the results if the prior distributions or constraints of -variables are dependent on other variables. -""" -function Bijectors.bijector( - model::DynamicPPL.Model, init_strategy::AbstractInitStrategy=InitFromPrior() -) - vi = OnlyAccsVarInfo((BijectorAccumulator(),)) - vi = last(DynamicPPL.init!!(model, vi, init_strategy, UnlinkAll())) - acc = getacc(vi, Val(:Bijector)) - ranges = foldl(acc.sizes; init=UnitRange{Int}[]) do cumulant, sz - last_index = length(cumulant) > 0 ? last(cumulant).stop : 0 - push!(cumulant, (last_index + 1):(last_index + sz)) - return cumulant - end - return Bijectors.Stacked(acc.bijectors, ranges) -end diff --git a/src/accumulators/linked_vec_transforms.jl b/src/accumulators/linked_vec_transforms.jl new file mode 100644 index 000000000..50ff897d0 --- /dev/null +++ b/src/accumulators/linked_vec_transforms.jl @@ -0,0 +1,32 @@ +const LINKEDVECTRANSFORM_ACCNAME = :LinkedVecTransformAccumulator +function _get_linked_vec_transform(val, tv, logjac, vn, dist) + return FixedTransform(Bijectors.VectorBijectors.from_linked_vec(dist)) +end + +""" + LinkedVecTransformAccumulator() + +An accumulator that stores the transform required to convert a linked vector into the +original, untransformed value. +""" +LinkedVecTransformAccumulator() = + VNTAccumulator{LINKEDVECTRANSFORM_ACCNAME}(_get_linked_vec_transform) + +""" + get_linked_vec_transforms(vi::DynamicPPL.AbstractVarInfo) + +Extract the transforms stored in the `LinkedVecTransformAccumulator` of an AbstractVarInfo. +Errors if the AbstractVarInfo does not have a `LinkedVecTransformAccumulator`. 
+""" +function get_linked_vec_transforms(vi::DynamicPPL.AbstractVarInfo) + return DynamicPPL.getacc(vi, Val(LINKEDVECTRANSFORM_ACCNAME)).values +end + +function get_linked_vec_transforms(rng::Random.AbstractRNG, model::DynamicPPL.Model) + accs = OnlyAccsVarInfo(LinkedVecTransformAccumulator()) + _, accs = init!!(rng, model, accs, InitFromPrior(), UnlinkAll()) + return get_linked_vec_transforms(accs) +end +function get_linked_vec_transforms(model::DynamicPPL.Model) + return get_linked_vec_transforms(Random.default_rng(), model) +end diff --git a/src/accumulators/vector_params.jl b/src/accumulators/vector_params.jl index 2ab48816e..306ec8820 100644 --- a/src/accumulators/vector_params.jl +++ b/src/accumulators/vector_params.jl @@ -37,39 +37,34 @@ end function DynamicPPL.accumulate_assume!!( acc::VectorParamAccumulator, val, - tval::AbstractTransformedValue, + tval::TransformedValue, logjac, vn::VarName, dist::Distribution, ::Any, ) - ral = acc.vn_ranges[vn] - # sometimes you might get UntransformedValue... _get_vector_tval is in - # src/accumulators/vector_values.jl. + rat = acc.vn_ranges[vn] vectorised_tval = _get_vector_tval(val, tval, logjac, vn, dist) - return _update_acc(acc, vectorised_tval, ral, vn) + return _update_acc(acc, vectorised_tval, rat, vn) end function _update_acc( acc::VectorParamAccumulator, - tval::Union{LinkedVectorValue,VectorValue}, - ral::RangeAndLinked, + tval::TransformedValue{V,T}, + rat::RangeAndTransform, vn::VarName, -) - if ( - (ral.is_linked && tval isa VectorValue) || - (!ral.is_linked && tval isa LinkedVectorValue) - ) +) where {V<:AbstractVector,T} + if rat.transform != tval.transform throw( ArgumentError( - "The LogDensityFunction specifies that `$vn` should be $(ral.is_linked ? "linked" : "unlinked"), but the vector values contain a $(tval isa LinkedVectorValue ? 
"linked" : "unlinked") value for that variable.", + "The transform associated with the VarName `$vn` in the LogDensityFunction is not the same as the transform of the TransformedValue provided for that VarName. This likely means that the vector values provided are not consistent with the LogDensityFunction (e.g. if they were obtained from a different model).", ), ) end vec_val = DynamicPPL.get_internal_value(tval) len = length(vec_val) - expected_len = length(ral.range) + expected_len = length(rat.range) if len != expected_len throw( ArgumentError( @@ -78,15 +73,15 @@ function _update_acc( ) end - if any(acc.set_indices[ral.range]) + if any(acc.set_indices[rat.range]) throw( ArgumentError( "Setting to the same indices in the output vector more than once. This likely means that the vector values provided are not consistent with the LogDensityFunction (e.g. if they were obtained from a different model).", ), ) end - Accessors.@set acc.vals = BangBang.setindex!!(acc.vals, vec_val, ral.range) - acc.set_indices[ral.range] .= true + Accessors.@set acc.vals = BangBang.setindex!!(acc.vals, vec_val, rat.range) + acc.set_indices[rat.range] .= true return acc end diff --git a/src/accumulators/vector_values.jl b/src/accumulators/vector_values.jl index 970c5893f..1973b03c4 100644 --- a/src/accumulators/vector_values.jl +++ b/src/accumulators/vector_values.jl @@ -1,44 +1,78 @@ const VECTORVAL_ACCNAME = :VectorValue -_get_vector_tval(val, tval::Union{VectorValue,LinkedVectorValue}, logjac, vn, dist) = tval -function _get_vector_tval(val, ::UntransformedValue, logjac, vn, dist) + +""" + _get_vector_tval(val, tval, logjac, vn, dist) + +Generate a `TransformedValue` that always has a vector as its stored value. +""" +function _get_vector_tval( + val, tval::TransformedValue{V,T}, logjac, vn, dist +) where {V<:AbstractVector,T} + # If it's already an AbstractVector transformed value, then we are done. 
+ # `tval.transform` could be a DynamicLink(), Unlink(), or some fixed transform that + # vectorises; it doesn't matter. + return tval +end +function _get_vector_tval(val, tval::TransformedValue{V,T}, logjac, vn, dist) where {V,T} + # If it's *not* an AbstractVector transformed value, then in principle, we need to + # vectorise it before storing. We *could* do this by reversing the transformation, and + # then applying a vectorisation transform; but the truth is that this is most likely to + # be a user error where they tried to use a FixedTransform that does not vectorise. So + # we just error here. + return error( + "Expected a vectorised or untransformed value for variable $vn, but got a TransformedValue with a value of $(tval.value).", + ) +end +function _get_vector_tval( + val, ::TransformedValue{V,NoTransform}, logjac, vn, dist +) where {V} + # This is an untransformed value, so we need to vectorise it. We can do this by applying + # to_vec(dist). f = Bijectors.VectorBijectors.to_vec(dist) new_val, logjac = with_logabsdet_jacobian(f, val) @assert iszero(logjac) # otherwise we're in trouble... - return VectorValue(new_val, inverse(f)) + return TransformedValue(new_val, Unlink()) +end +function _get_vector_tval( + val, ::TransformedValue{V,NoTransform}, logjac, vn, dist +) where {V<:AbstractVector} + # This is the same as above but just shortcircuited because `to_vec` should always + # return TypedIdentity. Note that this method needs to be preserved to avoid method + # ambiguities. + return TransformedValue(val, Unlink()) end # This is equivalent to `varinfo.values` where `varinfo isa VarInfo` """ VectorValueAccumulator() -An accumulator that collects `VectorValue`s and `LinkedVectorValue`s seen during model -execution. +An accumulator that collects vectorised values, i.e. `TransformedValue{<:AbstractVector}`. -Whether a `VectorValue` or `LinkedVectorValue` is collected depends on the transform -strategy used when evaluating the model. 
For variables that are specified as being linked -(i.e., `DynamicLink()`), a `LinkedVectorValue` will be collected. Conversely, for variables -that are not specified as being linked, a `VectorValue` will be collected. +The exact type of the vectorised value (i.e., `tval.transform`) will depend on the transform +strategy that the model was evaluated with, and specifically, is equal to +`target_transform(transform_strategy, vn)`; *except* for the case where `target_transform` +is `NoTransform()`, i.e., no transformation is to be applied. In this case, the +`VectorValueAccumulator` will apply a vectorisation transform to the untransformed value, +i.e., generate a `TransformedValue` with `Unlink()` as the transform. """ VectorValueAccumulator() = VNTAccumulator{VECTORVAL_ACCNAME}(_get_vector_tval) """ internal_values_as_vector(vnt::VarNamedTuple) -Concatenate all the `VectorValue`s and `LinkedVectorValue`s in `vnt` into a single vector. -This will error if any of the values in `vnt` are not `VectorValue`s or -`LinkedVectorValue`s. +Concatenate all the values in `vnt` into a single vector. This will error if any of the +values in `vnt` contain non-vector values. ```jldoctest julia> using DynamicPPL -julia> # In a real setting the other fields would be filled in with meaningful values. 
- vnt = @vnt begin - x := VectorValue([1.0, 2.0], nothing) - y := LinkedVectorValue([3.0], nothing) +julia> vnt = @vnt begin + x := TransformedValue([1.0, 2.0], Unlink()) + y := TransformedValue([3.0], DynamicLink()) end VarNamedTuple -├─ x => VectorValue{Vector{Float64}, Nothing}([1.0, 2.0], nothing) -└─ y => LinkedVectorValue{Vector{Float64}, Nothing}([3.0], nothing) +├─ x => TransformedValue{Vector{Float64}, Unlink}([1.0, 2.0], Unlink()) +└─ y => TransformedValue{Vector{Float64}, DynamicLink}([3.0], DynamicLink()) julia> internal_values_as_vector(vnt) 3-element Vector{Float64}: @@ -69,8 +103,8 @@ julia> # note InitFromParams provides parameters in untransformed space julia> # but because we specified LinkAll(), the vectorised values are transformed vector_vals = get_vector_values(accs) VarNamedTuple -├─ x => LinkedVectorValue{Vector{Float64}, Bijectors.VectorBijectors.TypedIdentity}([1.0, 2.0], Bijectors.VectorBijectors.TypedIdentity()) -└─ y => LinkedVectorValue{Vector{Float64}, Bijectors.VectorBijectors.OnlyWrap{Bijectors.VectorBijectors.Truncate{Float64, Float64}}}([0.0], Bijectors.VectorBijectors.OnlyWrap{Bijectors.VectorBijectors.Truncate{Float64, Float64}}(Bijectors.VectorBijectors.Truncate{Float64, Float64}(0.0, 1.0))) +├─ x => TransformedValue{Vector{Float64}, DynamicLink}([1.0, 2.0], DynamicLink()) +└─ y => TransformedValue{Vector{Float64}, DynamicLink}([0.0], DynamicLink()) julia> # we can extract the internal values as a single vector internal_values_as_vector(vector_vals) @@ -83,6 +117,9 @@ julia> # we can extract the internal values as a single vector function internal_values_as_vector(vnt::VarNamedTuple) return mapfoldl(pair -> _as_vector(pair.second), vcat, vnt; init=Union{}[]) end -_as_vector(val::VectorValue) = DynamicPPL.get_internal_value(val) -_as_vector(val::LinkedVectorValue) = DynamicPPL.get_internal_value(val) -_as_vector(val) = error("don't know how to convert $(typeof(val)) to a vector value") +_as_vector(val::TransformedValue{T}) where 
{T<:AbstractVector} = val.value +function _as_vector(val::TransformedValue{T}) where {T} + return error( + "Expected a TransformedValue with a vector as its value, but got a TransformedValue with a value of $val.", + ) +end diff --git a/src/contexts/default.jl b/src/contexts/default.jl index 266419a52..e6dea8322 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -39,14 +39,14 @@ function tilde_assume!!( # value is supposed to be linked or not. # This can definitely be unified in the future. tval = get_transformed_value(vi, vn) - trf = if tval isa LinkedVectorValue - # Note that we can't rely on the stored transform being correct (e.g. if new values - # were placed in `vi` via `unflatten!!`, so we regenerate the transforms. + trf = if tval.transform isa DynamicLink Bijectors.VectorBijectors.from_linked_vec(right) - elseif tval isa VectorValue + elseif tval.transform isa Unlink Bijectors.VectorBijectors.from_vec(right) + elseif tval.transform isa FixedTransform + tval.transform.transform else - error("Expected transformed value to be a VectorValue or LinkedVectorValue") + error("Expected transformed value to be a vectorised value") end x, inv_logjac = with_logabsdet_jacobian(trf, get_internal_value(tval)) vi = accumulate_assume!!(vi, x, tval, -inv_logjac, vn, right, template) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 6e36ad868..92bcea094 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -14,13 +14,13 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -This function must return an `AbstractTransformedValue`. +This function must return a `TransformedValue`. If `strategy` provides values that are already untransformed (e.g., a Float64 within (0, 1) -for `dist::Beta`, then you should return an `UntransformedValue`. +for `dist::Beta`), then you should return a `TransformedValue` with a `NoTransform()`. 
-Otherwise, often there are cases where this will return either a `VectorValue` or a -`LinkedVectorValue`, for example, if the strategy is reading from an existing `VarInfo`. +Otherwise, often there are cases where this will return a transformed value; for example, if +the strategy is reading from an existing `VarInfo`. """ function init end @@ -76,7 +76,7 @@ Obtain new values by sampling from the prior distribution. """ struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return UntransformedValue(rand(rng, dist)) + return TransformedValue(rand(rng, dist), NoTransform()) end """ @@ -106,7 +106,7 @@ end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFromUniform) sz = Bijectors.VectorBijectors.linked_vec_length(dist) y = u.lower .+ ((u.upper - u.lower) .* rand(rng, sz)) - return LinkedVectorValue(y, Bijectors.VectorBijectors.from_linked_vec(dist)) + return TransformedValue(y, DynamicLink()) end """ @@ -178,18 +178,10 @@ function init( p.fallback === nothing && error("A `missing` value was provided for the variable `$(vn)`.") init(rng, vn, dist, p.fallback) - elseif x isa VectorValue - # In this case, we can't trust the transform stored in x because the _value_ - # in x may have been changed via unflatten!! without the transform being - # updated. Therefore, we always recompute the transform here. - VectorValue(x.val, Bijectors.VectorBijectors.from_vec(dist)) - elseif x isa LinkedVectorValue - # Same as above. 
- LinkedVectorValue(x.val, Bijectors.VectorBijectors.from_linked_vec(dist)) - elseif x isa UntransformedValue + elseif x isa TransformedValue x else - UntransformedValue(x) + TransformedValue(x, NoTransform()) end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") @@ -220,18 +212,10 @@ function init( ) return if haskey(p.params, vn) x = p.params[vn] - if x isa VectorValue - # In this case, we can't trust the transform stored in x because the _value_ - # in x may have been changed via unflatten!! without the transform being - # updated. Therefore, we always recompute the transform here. - VectorValue(x.val, Bijectors.VectorBijectors.from_vec(dist)) - elseif x isa LinkedVectorValue - # Same as above. - LinkedVectorValue(x.val, Bijectors.VectorBijectors.from_linked_vec(dist)) - elseif x isa UntransformedValue + if x isa TransformedValue x else - UntransformedValue(x) + TransformedValue(x, NoTransform()) end else error("No value was provided for the variable `$(vn)`.") @@ -247,7 +231,7 @@ function DynamicPPL.get_param_eltype(p::InitFromParamsUnsafe) end """ - RangeAndLinked + RangeAndTransform Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable in the model will in general correspond to a sub-vector of `params`. This struct stores @@ -256,12 +240,12 @@ an unlinked value. $(TYPEDFIELDS) """ -struct RangeAndLinked +struct RangeAndTransform{T} # indices that the variable corresponds to in the vectorised parameter range::UnitRange{Int} - # whether the variable is linked or unlinked - is_linked::Bool + transform::T end +DynamicPPL.get_transform(rat::RangeAndTransform) = rat.transform """ InitFromVector( @@ -278,13 +262,13 @@ end A struct that wraps a vector of parameter values, plus information about how random variables map to ranges in that vector. -The `transform_strategy` argument in fact duplicates information stored inside `varname_ranges`. 
-For example, if every `RangeAndLinked` in `varname_ranges` has `is_linked == true`, then -`transform_strategy` will be `LinkAll()`. +The `transform_strategy` argument in fact duplicates information stored inside +`varname_ranges`. For example, if every `RangeAndTransform` in `varname_ranges` has +`transform == DynamicLink()`, then `transform_strategy` will be `LinkAll()`. -However, storing `transform_strategy` here is a way to communicate at the type level whether all -variables are linked or unlinked, which provides much better performance in the case where -all variables are linked or unlinked, due to improved type stability. +However, storing `transform_strategy` here is a way to communicate at the type level whether +all variables are linked or unlinked, which provides much better performance in the case +where all variables are linked or unlinked, due to improved type stability. """ struct InitFromVector{ T<:AbstractVector{<:Real},V<:VarNamedTuple,L<:AbstractTransformStrategy @@ -310,28 +294,28 @@ faster. """ @inline maybe_view_ad(vect::AbstractArray, range) = view(vect, range) -function _get_range_and_linked(ifv::InitFromVector, vn::VarName) +function _get_range_and_transform(ifv::InitFromVector, vn::VarName) # The type assertion does nothing if `varname_ranges` has concrete element types, as is # the case for all type stable models. However, if the model is not type stable, # vr.varname_ranges[vn] may infer to have type `Any`. In this case it is helpful to - # assert that it is a RangeAndLinked, because even though it remains non-concrete, it'll - # allow the compiler to infer the types of `range` and `is_linked`. - return ifv.varname_ranges[vn]::RangeAndLinked + # assert that it is a RangeAndTransform, because even though it remains non-concrete, + # it'll allow the compiler to infer the type of `range`.
+ # TODO(penelopeysm): Investigate if this is still necessary + return ifv.varname_ranges[vn]::RangeAndTransform end function init(::Random.AbstractRNG, vn::VarName, dist::Distribution, ifv::InitFromVector) - range_and_linked = _get_range_and_linked(ifv, vn) - vect = maybe_view_ad(ifv.vect, range_and_linked.range) + rat = _get_range_and_transform(ifv, vn) + vect = maybe_view_ad(ifv.vect, rat.range) # This block here is why we store transform_strategy inside the InitFromVector, as it # allows for type stability. - return if ifv.transform_strategy isa LinkAll - LinkedVectorValue(vect, Bijectors.VectorBijectors.from_linked_vec(dist)) + tfm = if ifv.transform_strategy isa LinkAll + DynamicLink() elseif ifv.transform_strategy isa UnlinkAll - VectorValue(vect, Bijectors.VectorBijectors.from_vec(dist)) - elseif range_and_linked.is_linked - LinkedVectorValue(vect, Bijectors.VectorBijectors.from_linked_vec(dist)) + Unlink() else - VectorValue(vect, Bijectors.VectorBijectors.from_vec(dist)) + rat.transform end + return TransformedValue(vect, tfm) end function get_param_eltype(strategy::InitFromVector) return eltype(strategy.vect) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index b1ae82f2b..7fb89141c 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -11,7 +11,7 @@ using DynamicPPL: ThreadSafeVarInfo, VarInfo, OnlyAccsVarInfo, - RangeAndLinked, + RangeAndTransform, default_accumulators, float_type_with_fallback, getlogjoint, @@ -74,8 +74,8 @@ are several functions in DynamicPPL that are 'supported' out of the box: parameters to be used for constructing the single vectorised representation of the model. The parameters stored in this argument determine whether the resulting `LogDensityFunction` will be linked, unlinked, or mixed. For example, if you pass a `VarNamedTuple` consisting -entirely of `LinkedVectorValue`s, then the resulting `LogDensityFunction` will be fully -linked. 
+entirely of `TransformedValue{T,DynamicLink}`s, then the resulting `LogDensityFunction` will +be fully linked. You can pass either: @@ -201,34 +201,13 @@ struct LogDensityFunction{ ); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) - all_ranges = get_ranges_and_linked(vnt) - # Figure out if all variables are linked, unlinked, or mixed - linked_vns = Set{VarName}() - unlinked_vns = Set{VarName}() - for vn in keys(all_ranges) - if all_ranges[vn].is_linked - push!(linked_vns, vn) - else - push!(unlinked_vns, vn) - end - end - transform_strategy = if isempty(unlinked_vns) - LinkAll() - elseif isempty(linked_vns) - UnlinkAll() - else - # We could have a marginal performance optimisation here by checking whether - # linked_vns or unlinked_vns is smaller, and then using LinkSome or UnlinkSome - # accordingly, so that there are fewer `subsumes` checks. However, in practice, - # the mixed linking case performance is going to be a lot worse than in the - # fully linked or fully unlinked cases anyway, so this would be a bit of a - # premature optimisation. - LinkSome(linked_vns, UnlinkAll()) - end + all_ranges = get_rangeandtransforms(vnt) + transform_strategy = infer_transform_strategy_from_values(vnt) + # Get vectorised parameters. Note that `internal_values_as_vector` just concatenates # all the vectors inside in iteration order of the VNT's keys. *In principle*, the # result of that should always be consistent with the ranges extracted above via - # `get_ranges_and_linked`, since both are based on the same underlying VNT, and both + # `get_rangeandtransforms`, since both are based on the same underlying VNT, and both # iterate over the keys in the same order. However, this is an implementation # detail, and so we should probably not rely on it! 
# Therefore, we use `to_vector_params_inner` to also perform some checks that the @@ -524,25 +503,25 @@ _use_closure(::ADTypes.AbstractADType) = false ###################################################### """ - get_ranges_and_linked(vnt::VarNamedTuple) + get_rangeandtransforms(vnt::VarNamedTuple) -Given a `VarNamedTuple` that contains `VectorValue`s and `LinkedVectorValue`s, extract the -ranges of each variable in the vectorised parameter representation, along with whether each -variable is linked or unlinked. +Given a `VarNamedTuple` that contains vectorised values (i.e., +`TransformedValue{<:AbstractVector}`), extract the ranges of each variable in the vectorised +parameter representation, along with the transform status of each variable. This function returns a VarNamedTuple mapping all VarNames to their corresponding -`RangeAndLinked`. +`RangeAndTransform`. """ -function get_ranges_and_linked(vnt::VarNamedTuple) +function get_rangeandtransforms(vnt::VarNamedTuple) # Note: can't use map_values!! here as that might mutate the VNT itself! ranges_vnt, _ = mapreduce( identity, function ((ranges_vnt, offset), pair) vn, tv = pair - val = tv.val + val = get_internal_value(tv) range = offset:(offset + length(val) - 1) offset += length(val) - ral = RangeAndLinked(range, tv isa LinkedVectorValue) + ral = RangeAndTransform(range, tv.transform) template = vnt.data[AbstractPPL.getsym(vn)] ranges_vnt = templated_setindex!!(ranges_vnt, ral, vn, template) return ranges_vnt, offset @@ -583,9 +562,8 @@ end ldf::LogDensityFunction ) -Extract vectorised values from a `VarNamedTuple` that contains `VectorValue`s and -`LinkedVectorValue`s, and concatenate them into a single vector that is consistent with the -ranges specified in the `LogDensityFunction`. +Extract vectorised values from a `VarNamedTuple`, and concatenate them into a single vector +that is consistent with the ranges specified in the `LogDensityFunction`. 
This is useful when you want to regenerate new vectorised parameters but using a different initialisation strategy. @@ -619,13 +597,10 @@ function to_vector_params_inner( ral = ranges[vn] # check transform lines up - if ( - (ral.is_linked && tval isa VectorValue) || - (!ral.is_linked && tval isa LinkedVectorValue) - ) + if ral.transform != tval.transform throw( ArgumentError( - "The LogDensityFunction specifies that `$vn` should be $(ral.is_linked ? "linked" : "unlinked"), but the vector values contain a $(tval isa LinkedVectorValue ? "linked" : "unlinked") value for that variable.", + "The variable `$vn` has transform status $(ral.transform) in the LogDensityFunction, but the provided VarNamedTuple has transform status $(tval.transform) for this variable. This likely means that the vector values provided are not consistent with the LogDensityFunction (e.g. if they were obtained from a different model).", ), ) end diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl index 91faaf7a9..77ccd77d2 100644 --- a/src/onlyaccs.jl +++ b/src/onlyaccs.jl @@ -57,7 +57,7 @@ end # This allows us to make use of the main tilde_assume!!(::InitContext) method without # having to duplicate the code here @inline function DynamicPPL.setindex_with_dist!!( - vi::OnlyAccsVarInfo, ::AbstractTransformedValue, ::Distribution, ::VarName, ::Any + vi::OnlyAccsVarInfo, ::TransformedValue, ::Distribution, ::VarName, ::Any ) return vi end diff --git a/src/submodel.jl b/src/submodel.jl index 31b6712ac..97618db66 100644 --- a/src/submodel.jl +++ b/src/submodel.jl @@ -35,8 +35,9 @@ the model can be sampled from but not necessarily evaluated for its log density. 
# Examples -## Simple example -```jldoctest submodel-to_submodel; setup=:(using Distributions) +```jldoctest submodel-to_submodel +julia> using DynamicPPL, Distributions + julia> @model function demo1(x) x ~ Normal() return 1 + abs(x) @@ -48,27 +49,29 @@ julia> @model function demo2(x, y) end; ``` -When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: +When we sample from the model `demo2(missing, 0.4)` the random variable `x` will be sampled, but +it will be prefixed with `a` (the left-hand side of the tilde): + ```jldoctest submodel-to_submodel -julia> vi = VarInfo(demo2(missing, 0.4)); +julia> model = demo2(missing, 0.4); -julia> @varname(a.x) in keys(vi) +julia> haskey(rand(model), @varname(a.x)) true ``` -The variable `a` is not tracked. However, it will be assigned the return value of `demo1`, -and can be used in subsequent lines of the model, as shown above. -```jldoctest submodel-to_submodel -julia> @varname(a) in keys(vi) -false -``` +The variable `a` will be assigned the return value of `demo1`, and can be used in subsequent +lines of the model, e.g. in the definition of `y` above. 
-We can check that the log joint probability of the model accumulated in `vi` is correct: +We can verify that the log joint probability of the model accumulated in `accs` is correct: ```jldoctest submodel-to_submodel -julia> x = vi[@varname(a.x)]; +julia> accs = setacc!!(OnlyAccsVarInfo(), RawValueAccumulator(false)); + +julia> _, accs = init!!(model, accs, InitFromPrior(), UnlinkAll()); -julia> getlogjoint(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +julia> x = get_raw_values(accs)[@varname(a.x)]; + +julia> getlogjoint(accs) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) true ``` @@ -87,52 +90,30 @@ julia> @model function demo2_no_prefix(x, z) return z ~ Uniform(-a, 1) end; -julia> vi = VarInfo(demo2_no_prefix(missing, 0.4)); +julia> model = demo2_no_prefix(missing, 0.4); -julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x` +julia> haskey(rand(model), @varname(x)) # here we just use `x` instead of `a.x` true ``` -However, not using prefixing is generally not recommended as it can lead to variable name clashes -unless one is careful. For example, if we're re-using the same model twice in a model, not using prefixing -will lead to variable name clashes: However, one can manually prefix using the [`prefix(::Model, input)`](@ref): +However, not using prefixing is generally not recommended as it can lead to variable name +clashes unless one is careful. For example, if the same submodel is used multiple times in a +model, not using prefixing will lead to variable name clashes.
+ +One can manually specify a prefix using [`prefix(::Model, prefix_varname)`](@ref): + ```jldoctest submodel-to_submodel-prefix julia> @model function demo2(x, y, z) - a ~ to_submodel(prefix(demo1(x), :sub1), false) - b ~ to_submodel(prefix(demo1(y), :sub2), false) + a ~ to_submodel(prefix(demo1(x), @varname(sub1)), false) + b ~ to_submodel(prefix(demo1(y), @varname(sub2)), false) return z ~ Uniform(-a, b) end; -julia> vi = VarInfo(demo2(missing, missing, 0.4)); - -julia> @varname(sub1.x) in keys(vi) -true +julia> model = demo2(missing, missing, 0.4); -julia> @varname(sub2.x) in keys(vi) +julia> haskey(rand(model), @varname(sub1.x)) true -``` - -Variables `a` and `b` are not tracked, but are assigned the return values of the respective -calls to `demo1`: -```jldoctest submodel-to_submodel-prefix -julia> @varname(a) in keys(vi) -false - -julia> @varname(b) in keys(vi) -false -``` - -We can check that the log joint probability of the model accumulated in `vi` is correct: - -```jldoctest submodel-to_submodel-prefix -julia> sub1_x = vi[@varname(sub1.x)]; - -julia> sub2_x = vi[@varname(sub2.x)]; - -julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); - -julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); -julia> getlogjoint(vi) ≈ logprior + loglikelihood +julia> haskey(rand(model), @varname(sub2.x)) true ``` """ diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 8efbc5a64..d54ac188c 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -164,50 +164,6 @@ function rand_prior_true(rng::Random.AbstractRNG, model::Model{typeof(demo_lkjch return (x=x,) end -# Model to test `StaticTransformation` with. -""" - demo_static_transformation() - -Simple model for which [`default_transformation`](@ref) returns a [`StaticTransformation`](@ref). 
-""" -@model function demo_static_transformation() - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - 1.5 ~ Normal(m, sqrt(s)) - 2.0 ~ Normal(m, sqrt(s)) - - return (; s, m, x=[1.5, 2.0]) -end - -function DynamicPPL.default_transformation(::Model{typeof(demo_static_transformation)}) - b = Bijectors.Stacked(Bijectors.elementwise(exp), identity) - return DynamicPPL.StaticTransformation(b) -end - -posterior_mean(::Model{typeof(demo_static_transformation)}) = (s=49 / 24, m=7 / 6) -function logprior_true(::Model{typeof(demo_static_transformation)}, s, m) - return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) -end -function loglikelihood_true(::Model{typeof(demo_static_transformation)}, s, m) - return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0) -end -function varnames(::Model{typeof(demo_static_transformation)}) - return [@varname(s), @varname(m)] -end -function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_static_transformation)}, s, m -) - return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) -end - -function rand_prior_true( - rng::Random.AbstractRNG, model::Model{typeof(demo_static_transformation)} -) - s = rand(rng, InverseGamma(2, 3)) - m = rand(rng, Normal(0, sqrt(s))) - return (s=s, m=m) -end - # A collection of models for which the posterior should be "similar". # Some utility methods for these. function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) @@ -889,5 +845,4 @@ const ALL_MODELS = ( demo_dynamic_constraint(), demo_one_variable_multiple_constraints(), demo_lkjchol(), - demo_static_transformation(), ) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 880446d3d..8f27492a7 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -4,19 +4,19 @@ # Utilities for testing varinfos. 
""" - test_values(vi::AbstractVarInfo, vals::NamedTuple, vns) + test_values(vnt::VarNamedTuple, vals::NamedTuple, vns) -Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`. +Test that `vnt[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`. """ -function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns; compare=isequal, kwargs...) +function test_values(vnt::VarNamedTuple, vals::NamedTuple, vns; compare=isequal, kwargs...) for vn in vns val = AbstractPPL.getvalue(vals, vn) # TODO(mhauru) Workaround for https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404 # Remove once the fix is all Julia versions we support. if val isa Cholesky - @test compare(vi[vn].L, val.L; kwargs...) + @test compare(vnt[vn].L, val.L; kwargs...) else - @test compare(vi[vn], val; kwargs...) + @test compare(vnt[vn], val; kwargs...) end end end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 9faa71ebc..29cbc703a 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -114,12 +114,12 @@ haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) is_transformed(vi::ThreadSafeVarInfo) = is_transformed(vi.varinfo) -function link!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, args...) +function link!!(vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = link!!(vi.varinfo, args...) end -function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) +function invlink!!(vi::ThreadSafeVarInfo, args...) + return Accessors.@set vi.varinfo = invlink!!(vi.varinfo, args...) 
end get_transform_strategy(vi::ThreadSafeVarInfo) = get_transform_strategy(vi.varinfo) diff --git a/src/transformed_values.jl b/src/transformed_values.jl index 10d474de1..c451bc558 100644 --- a/src/transformed_values.jl +++ b/src/transformed_values.jl @@ -8,202 +8,97 @@ # executing the model, but `unflatten!!` does not have that information. As long as we # depend on the behaviour of `unflatten!!` somewhere, we cannot get rid of it. -# TODO(mhauru) Related to the above, I think we should reconsider whether we should store -# transformations at all. We rarely use them, since they may be dynamic in a model. -# tilde_assume!! rather gets the transformation from the current distribution encountered -# during model execution. However, this would change the interface quite a lot, so I want to -# finish implementing VarInfo using VNT (mostly) respecting the old interface first. -# -# NOTE(penelopeysm) The above is in principle doable right now. the main issue with removing -# the transform is that we cannot get `varinfo[vn]` any more. It is arguable whether this -# method is really needed. On one hand, it is a pretty useful way of seeing the current -# value of a variable in the VarInfo. On the other hand, it is not guaranteed to be correct -# (because `unflatten!` might change the required transform); so one could argue that the -# question of "what is the true value" is generally unanswerable, and we should not expose a -# method that pretends to know the answer.. I would lean towards removing it, but doing so -# would require a fair amount of changes in the test suite, so it will have to wait for a -# time when fewer big PRs are ongoing. - """ - AbstractTransformedValue - -An abstract type for values that enter the DynamicPPL tilde-pipeline. - -These values are generated by an [`AbstractInitStrategy`](@ref): the function -[`DynamicPPL.init`](@ref) should return an `AbstractTransformedValue`. 
- -Each `AbstractTransformedValue` contains some version of the actual variable's value, -together with a transformation that can be used to convert the internal value back to the -original space. - -Current subtypes are [`VectorValue`](@ref), [`LinkedVectorValue`](@ref), and -[`UntransformedValue`](@ref). DynamicPPL's [`VarInfo`](@ref) type stores either -`VectorValue`s or `LinkedVectorValue`s internally, depending on the link status of the -`VarInfo`. - -!!! warning - Even though the subtypes listed above are public, this abstract type is not itself part - of the public API and should not be subtyped by end users. Much of DynamicPPL's model - evaluation methods depends on these subtypes having predictable behaviour, i.e., their - transforms should always be `Bijectors.VectorBijectors.from_linked_vec(dist)`, - `Bijectors.VectorBijectors.from_vec(dist)`, or their inverse. If you create a new - subtype of `AbstractTransformedValue` and use it, DynamicPPL will not know how to handle - it and may either error or silently give incorrect results. - - In principle, it should be possible to subtype this and allow for custom transformations - to be used (not just the 'default' ones). However, this is not currently implemented. - -Subtypes of this should implement the following functions: - -- `DynamicPPL.get_transform(tv::AbstractTransformedValue)`: Get the transformation that - converts the internal value back to the original space. - -- `DynamicPPL.get_internal_value(tv::AbstractTransformedValue)`: Get the internal value - stored in `tv`. + abstract type AbstractTransform end -- `DynamicPPL.set_internal_value(tv::AbstractTransformedValue, new_val)`: Create a new - `AbstractTransformedValue` with the same transformation as `tv`, but with internal value - `new_val`. +An abstract type to represent the intended transformation for a variable. 
""" -abstract type AbstractTransformedValue end +abstract type AbstractTransform end """ - get_transform(tv::AbstractTransformedValue) - -Get the transformation that converts the internal value back to the raw value. - -!!! warning - If the distribution associated with the variable has changed since this - `AbstractTransformedValue` was created, this transform may be inaccurate. This can - happen e.g. if `unflatten!!` has been called on a VarInfo containing this. + DynamicLink <: AbstractTransform - Consequently, when the distribution on the right-hand side of a tilde-statement is - available, you should always prefer regenerating the transform from that distribution - rather than using this function. +A type indicating that a target transformation should be derived by recomputing +`Bijectors.VectorBijectors.from_linked_vec(dist)`, where `dist` is the distribution on the +right-hand side of the tilde. """ -function get_transform end +struct DynamicLink <: AbstractTransform end """ - get_internal_value(tv::AbstractTransformedValue) + Unlink <: AbstractTransform -Get the internal value stored in `tv`. +A type indicating that a target transformation should be derived by recomputing +`Bijectors.VectorBijectors.from_vec(dist)`, where `dist` is the distribution on the +right-hand side of the tilde. """ -function get_internal_value end +struct Unlink <: AbstractTransform end """ - set_internal_value(tv::AbstractTransformedValue, new_val) + NoTransform <: AbstractTransform -Create a new `AbstractTransformedValue` with the same transformation as `tv`, but with -internal value `new_val`. +A type indicating that the value is not transformed. """ -function set_internal_value end +struct NoTransform <: AbstractTransform end """ - VectorValue{V<:AbstractVector,T} - -A transformed value that stores its internal value as a vectorised form. This is what -VarInfo sees as an "unlinked value". 
+ FixedTransform{F} <: AbstractTransform -These values can be generated when using `InitFromParams` with a VarInfo's internal values. +A type to represent a fixed (static) transformation of type `F`. """ -struct VectorValue{V<:AbstractVector,T} <: AbstractTransformedValue - "The internal (vectorised) value." - val::V - """The unvectorisation transform required to convert `val` back to the original space. - - Note that this transform is cached and thus may be inaccurate if `unflatten!!` is called - on the VarInfo containing this `VectorValue`. This transform is only ever used when - calling `varinfo[vn]` to get the original value back; in all other cases, where model - evaluation occurs, the correct transform is determined from the distribution associated - with the variable.""" - transform::T - function VectorValue(val::V, tfm::T) where {V<:AbstractVector,T} - return new{V,T}(val, tfm) - end +struct FixedTransform{F} <: AbstractTransform + transform::F +end +Base.:(==)(ft1::FixedTransform, ft2::FixedTransform) = ft1.transform == ft2.transform +function Base.isequal(ft1::FixedTransform, ft2::FixedTransform) + return isequal(ft1.transform, ft2.transform) end """ - LinkedVectorValue{V<:AbstractVector,T} + TransformedValue{V,T<:AbstractTransform} -A transformed value that stores its internal value as a linked and vectorised form. This is -what VarInfo sees as a "linked value". +A struct to represent a value that has undergone some transformation. -These values can be generated when using `InitFromParams` with a VarInfo's internal values. -""" -struct LinkedVectorValue{V<:AbstractVector,T} <: AbstractTransformedValue - "The internal (linked + vectorised) value." - val::V - """The unlinking transform required to convert `val` back to the original space. +The *transformed* value is stored in the `value` field, and the *inverse* transformation is +stored in the `transform` field. 
- Note that this transform is cached and thus may be inaccurate if `unflatten!!` is called - on the VarInfo containing this `VectorValue`. This transform is only ever used when - calling `varinfo[vn]` to get the original value back; in all other cases, where model - evaluation occurs, the correct transform is determined from the distribution associated - with the variable.""" +That means that `get_transform(tv)(get_internal_value(tv))` should return the raw, +untransformed, value associated with `tv`. +""" +struct TransformedValue{V,T<:AbstractTransform} + value::V transform::T - function LinkedVectorValue(val::V, tfm::T) where {V<:AbstractVector,T} - return new{V,T}(val, tfm) - end end - -for T in (:VectorValue, :LinkedVectorValue) - @eval begin - function Base.:(==)(tv1::$T, tv2::$T) - return (tv1.val == tv2.val) & (tv1.transform == tv2.transform) - end - function Base.isequal(tv1::$T, tv2::$T) - return isequal(tv1.val, tv2.val) && isequal(tv1.transform, tv2.transform) - end - - get_transform(tv::$T) = tv.transform - get_internal_value(tv::$T) = tv.val - - function set_internal_value(tv::$T, new_val) - return $T(new_val, tv.transform) - end - end +function Base.:(==)(tv1::TransformedValue, tv2::TransformedValue) + return (get_internal_value(tv1) == get_internal_value(tv2)) & + (get_transform(tv1) == get_transform(tv2)) end - -""" - UntransformedValue{V} - -A raw, untransformed, value. - -These values can be generated from initialisation strategies such as `InitFromPrior`, -`InitFromUniform`, and `InitFromParams` on a standard container type. -""" -struct UntransformedValue{V} <: AbstractTransformedValue - "The value." 
- val::V - UntransformedValue(val::V) where {V} = new{V}(val) +function Base.isequal(tv1::TransformedValue, tv2::TransformedValue) + return isequal(get_internal_value(tv1), get_internal_value(tv2)) && + isequal(get_transform(tv1), get_transform(tv2)) end -Base.:(==)(tv1::UntransformedValue, tv2::UntransformedValue) = tv1.val == tv2.val -Base.isequal(tv1::UntransformedValue, tv2::UntransformedValue) = isequal(tv1.val, tv2.val) -get_transform(::UntransformedValue) = Bijectors.VectorBijectors.TypedIdentity() -get_internal_value(tv::UntransformedValue) = tv.val -set_internal_value(::UntransformedValue, new_val) = UntransformedValue(new_val) """ - abstract type AbstractTransform end + get_transform(tv::TransformedValue) -An abstract type to represent the intended transformation for a variable. +Get the `AbstractTransform` describing how the transformed value maps back to the raw value. """ -abstract type AbstractTransform end +get_transform(tv::TransformedValue) = tv.transform """ - DynamicLink <: AbstractTransform + get_internal_value(tv::TransformedValue) -A type indicating that a target transformation should be derived by recomputing the invlink -transform from the distribution on the right-hand side of the tilde. +Get the internal value stored in `tv`. """ -struct DynamicLink <: AbstractTransform end +get_internal_value(tv::TransformedValue) = tv.value """ - Unlink <: AbstractTransform + set_internal_value(tv::TransformedValue, new_val) -A type indicating that the target transformation should be nothing. +Create a new `TransformedValue` with the same transformation as `tv`, but with +internal value `new_val`.
""" -struct Unlink <: AbstractTransform end +set_internal_value(tv::TransformedValue, new_val) = + TransformedValue(new_val, get_transform(tv)) """ abstract type AbstractTransformStrategy end @@ -220,9 +115,9 @@ Regardless of what initialisation strategy is used (and what kind of transformed `init()` returns, the log-Jacobian that is accumulated is always the log-Jacobian for the forward transform specified by `target_transform(strategy, vn)`. -That is, even if `init()` returns an `UntransformedValue`, if the transform strategy is -`LinkAll()` (which returns `DynamicLink` for all variables), then the log-Jacobian for -linking will be accumulated during model evaluation. +That is, even if `init()` returns an unlinked or untransformed value, if the transform +strategy is `LinkAll()` (which returns `DynamicLink` for all variables), then the +log-Jacobian for linking will be accumulated during model evaluation. Subtypes in DynamicPPL are [`LinkAll`](@ref), [`UnlinkAll`](@ref), [`LinkSome`](@ref), and [`UnlinkSome`](@ref). @@ -230,13 +125,11 @@ Subtypes in DynamicPPL are [`LinkAll`](@ref), [`UnlinkAll`](@ref), [`LinkSome`]( abstract type AbstractTransformStrategy end """ - target_transform(linker::AbstractTransformStrategy, vn::VarName) + target_transform(linker::AbstractTransformStrategy, vn::VarName)::AbstractTransform Determine whether a variable with name `vn` should be linked according to the `linker` -strategy. Returns `DynamicLink()` if the variable should be linked, or `Unlink()` if it -should not. - -This function can in the future be extended to support fixed transformations. +strategy. Returns a subtype of `AbstractTransform` that indicates the intended +transformation for the variable. """ function target_transform end @@ -256,6 +149,48 @@ Indicate that all variables should be unlinked. 
struct UnlinkAll <: AbstractTransformStrategy end target_transform(::UnlinkAll, ::VarName) = Unlink() +""" + UnsafePassThrough() <: AbstractTransformStrategy + +Indicate that the transform strategy should not be applied, and that the transform status of +each variable should be determined only by the initialisation strategy. + +!!! warning + This is unsafe because it conflates the initialisation strategy with the transform + strategy: for example, the log-Jacobian accumulated is determined by the initialisation + strategy, which is not ideal. + + As a result of this, DynamicPPL does *not* export `UnsafePassThrough`. +""" +struct UnsafePassThrough <: AbstractTransformStrategy end +target_transform(::UnsafePassThrough, ::VarName) = error("should not be called") + +""" + WithTransforms(tfms::VarNamedTuple, fallback) <: AbstractTransformStrategy + +Indicate that the variables in `tfms` should be transformed according to their corresponding +values in `tfms`, which should be instances of `AbstractTransform` subtypes. The +link statuses of other variables are determined by the `fallback` strategy.
+""" +struct WithTransforms{V<:VarNamedTuple,L<:AbstractTransformStrategy} <: + AbstractTransformStrategy + tfms::V + fallback::L +end +function Base.:(==)(wt1::WithTransforms, wt2::WithTransforms) + return (wt1.tfms == wt2.tfms) & (wt1.fallback == wt2.fallback) +end +function Base.isequal(wt1::WithTransforms, wt2::WithTransforms) + return isequal(wt1.tfms, wt2.tfms) && isequal(wt1.fallback, wt2.fallback) +end +function target_transform(linker::WithTransforms, vn::VarName) + return if haskey(linker.tfms, vn) + linker.tfms[vn] + else + target_transform(linker.fallback, vn) + end +end + """ LinkSome(vns::Set{<:VarName}, fallback) <: AbstractTransformStrategy @@ -311,7 +246,7 @@ end """ DynamicPPL.apply_transform_strategy( strategy::AbstractTransformStrategy, - tv::AbstractTransformedValue, + tv::TransformedValue, vn::VarName, dist::Distribution, ) @@ -329,11 +264,15 @@ Specifically, this function does a number of things: A table summarising the possible transformations is as follows: - | tv isa ... | `target_transform(...) isa DynamicLink` | `target_transform(...) isa Unlink` | + | tv.transform isa ...| `target_transform(...) isa DynamicLink` | `target_transform(...) isa Unlink` | |---------------------|---------------------------------|------------------------------------| - | `LinkedVectorValue` | -> `LinkedVectorValue` | -> `UntransformedValue` | - | `VectorValue` | -> `LinkedVectorValue` | -> `VectorValue` | - | `UntransformedValue`| -> `LinkedVectorValue` | -> `UntransformedValue` | + | `DynamicLink` | -> `DynamicLink` | -> `NoTransform` | + | `Unlink` | -> `DynamicLink` | -> `Unlink` | + | `NoTransform` | -> `DynamicLink` | -> `NoTransform` | + | `FixedTransform` | errors | errors | + + Note that, for the last row, when using `FixedTransform` we require that `target_transform` + exactly matches the fixed transform, otherwise an error is thrown. 
- If `vn` is supposed to be linked, calculates the associated log-Jacobian adjustment for the **forward** linking transformation (i.e., from unlinked to linked). @@ -346,10 +285,10 @@ This function returns a tuple of `(raw_value, new_tv, logjac)`. """ function apply_transform_strategy( strategy::AbstractTransformStrategy, - tv::LinkedVectorValue, + tv::TransformedValue{T,DynamicLink}, vn::VarName, dist::Distribution, -) +) where {T<:AbstractVector{<:Real}} # tval is already linked. We need to get the raw value plus logjac finvlink = Bijectors.VectorBijectors.from_linked_vec(dist) raw_value, inv_logjac = with_logabsdet_jacobian(finvlink, get_internal_value(tv)) @@ -358,19 +297,23 @@ function apply_transform_strategy( # No need to transform further (raw_value, tv, -inv_logjac) elseif target isa Unlink - # Need to return an unlinked value. We _could_ generate a VectorValue here, with the - # vectorisation transform. However, sometimes that's not needed (e.g. when - # evaluating with an OnlyAccsVarInfo). So we just return an UntransformedValue. If a - # downstream function requires a VectorValue, it's on them to generate it. - (raw_value, UntransformedValue(raw_value), zero(LogProbType)) + # Need to return an unlinked value. We _could_ vectorise and generate a Unlink() + # here, with the vectorisation transform. However, sometimes that's not needed (e.g. + # when evaluating with an OnlyAccsVarInfo). So we just return an untransformed + # value. If a downstream function requires a vectorised value, it's on them to + # generate it. 
+ (raw_value, TransformedValue(raw_value, NoTransform()), zero(LogProbType)) else error("unknown target transform $target") end end function apply_transform_strategy( - strategy::AbstractTransformStrategy, tv::VectorValue, vn::VarName, dist::Distribution -) + strategy::AbstractTransformStrategy, + tv::TransformedValue{T,Unlink}, + vn::VarName, + dist::Distribution, +) where {T<:AbstractVector{<:Real}} invlink = Bijectors.VectorBijectors.from_vec(dist) raw_value = invlink(get_internal_value(tv)) target = target_transform(strategy, vn) @@ -378,8 +321,7 @@ function apply_transform_strategy( # Need to link the value. We calculate the logjac flink = Bijectors.VectorBijectors.to_linked_vec(dist) linked_value, logjac = with_logabsdet_jacobian(flink, raw_value) - finvlink = Bijectors.inverse(flink) - linked_tv = LinkedVectorValue(linked_value, finvlink) + linked_tv = TransformedValue(linked_value, DynamicLink()) (raw_value, linked_tv, logjac) elseif target isa Unlink # No need to transform further @@ -391,23 +333,74 @@ end function apply_transform_strategy( strategy::AbstractTransformStrategy, - tv::UntransformedValue, + tv::TransformedValue{T,NoTransform}, vn::VarName, dist::Distribution, -) +) where {T} raw_value = get_internal_value(tv) target = target_transform(strategy, vn) return if target isa DynamicLink # Need to link the value. We calculate the logjac flink = Bijectors.VectorBijectors.to_linked_vec(dist) linked_value, logjac = with_logabsdet_jacobian(flink, raw_value) - finvlink = Bijectors.inverse(flink) - linked_tv = LinkedVectorValue(linked_value, finvlink) + linked_tv = TransformedValue(linked_value, DynamicLink()) (raw_value, linked_tv, logjac) elseif target isa Unlink # No need to transform further (raw_value, tv, zero(LogProbType)) + elseif target isa FixedTransform + # TODO!(penelopeysm): This relies on inverse() being defined. Is that bad? 
+ fwd_transform = inverse(target.transform) + transformed_value, logjac = with_logabsdet_jacobian(fwd_transform, raw_value) + transformed_tv = TransformedValue(transformed_value, target) + (raw_value, transformed_tv, logjac) else error("unknown target transform $target") end end + +function apply_transform_strategy( + strategy::AbstractTransformStrategy, + tv::TransformedValue{T,FixedTransform{F}}, + vn::VarName, + ::Distribution, +) where {T,F} + # target = tv.transform + # target = target_transform(strategy, vn) + # # TODO(penelopeysm): Note that in principle we could probably allow different target + # # transforms. However, for now let's keep it simple and error if it doesn't match. + # if target != tv.transform + # error( + # "Variable $vn has a fixed transform, but the transform strategy expects it to be transformed differently.", + # ) + # end + raw_value, logjac = with_logabsdet_jacobian( + tv.transform.transform, get_internal_value(tv) + ) + return (raw_value, tv, logjac) +end + +""" + infer_transform_strategy_from_values(vnt::VarNamedTuple) + +Takes a VNT of things with transforms, and infers a transform strategy that is consistent +with the transforms specified in the VNT. For all values `v` in the VNT, `get_transform(v)` +should return an `AbstractTransform`. +""" +function infer_transform_strategy_from_values(vnt::VarNamedTuple) + # map_values!! might mutate the VNT, so deepcopy to avoid this + transforms_vnt = map_values!!(get_transform, deepcopy(vnt)) + tfms = values(transforms_vnt) + # TODO(penelopeysm): In an ideal world, could we reliably use eltype(tfms) to infer + # this? I'm just worried about the possibility of tfms having an overly abstract type, + # hence the check on every element individually. + return if all(x -> x isa DynamicLink, tfms) + LinkAll() + elseif all(x -> x isa Unlink, tfms) + UnlinkAll() + else + # Bundle all the transforms into a single one, with a default fallback of + # unlinked. 
+ WithTransforms(transforms_vnt, UnlinkAll()) + end +end diff --git a/src/varinfo.jl b/src/varinfo.jl index 13ca10bb7..79b20aa42 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -7,17 +7,28 @@ The default implementation of `AbstractVarInfo`, storing variable values and accumulators. -`VarInfo` is quite a thin wrapper around a `VarNamedTuple` storing the variable values, and -a tuple of accumulators. The only really noteworthy thing about it is that it stores the -values of variables vectorised as instances of [`AbstractTransformedValue`](@ref). That is, -it stores each value as a special vector with a flag indicating whether it is just a -vectorised value ([`VectorValue`](@ref)), or whether it is also linked -([`LinkedVectorValue`](@ref)). It also stores the size of the actual post-transformation -value. These are all accessible via [`AbstractTransformedValue`](@ref). - -`VarInfo` additionally stores a transform strategy, which reflects the linked status of -variables inside the `VarInfo`. For example, a `VarInfo{LinkAll}` should contain only -`LinkedVectorValue`s in its `values` field. +A `VarInfo`, `vi`, is quite a thin wrapper around + + - `vi.values`: a `VarNamedTuple` storing the variable values, and + - `vi.accs`: a tuple of accumulators. + +The only really noteworthy thing about it is that `vi.values` specifically stores the values +of variables as [`TransformedValue{<:AbstractVector}`](@ref TransformedValue). + +That is, regardless of what the value of a variable is in the original distribution, the +VarInfo stores a *vectorised* version of the value. It is not particularly concerned about +whether the variable is linked or not: you can mix unlinked variables with linked variables +in a `VarInfo`. + +!!! note + This functionality is identical to that in [`VectorValueAccumulator`](@ref), and going + forward we recommend using that instead of `VarInfo`.
+ +On top of that, `VarInfo` also stores a transform strategy, which reflects the linked status +of variables inside the `VarInfo`. For example, a `VarInfo{LinkAll}` should contain only +`TransformedValue{T,DynamicLink}`s in its `values` field. This unfortunately leads to redundancy +of information, but is necessary for type stability, since that allows us to have +compile-time knowledge of what transformations are applied. Because the job of `VarInfo` is to store transformed values, there is no generic `setindex!!` implementation on `VarInfo` itself. Instead, all storage must go via @@ -25,12 +36,14 @@ Because the job of `VarInfo` is to store transformed values, there is no generic transformed form. This in turn means that the distribution on the right-hand side of a tilde-statement must be available when modifying a VarInfo. -You can use `getindex` on `VarInfo` to obtain values in the support of the original -distribution. To directly get access to the internal vectorised values, use -[`getindex_internal`](@ref), [`setindex_internal!!`](@ref), and [`unflatten!!`](@ref). +Furthermore, since no untransformed (raw) values are stored in `VarInfo`, there is no +generic `getindex` implementation that returns raw values. If you need this functionality, +you should make sure that `vi.accs` contains a `RawValueAccumulator` and use that to get the +raw values. To directly get access to the internal vectorised values in `vi.values`, you can +use [`getindex_internal`](@ref), [`setindex_internal!!`](@ref), and [`unflatten!!`](@ref). -For more details on the internal storage, see documentation of -[`AbstractTransformedValue`](@ref) and [`VarNamedTuple`](@ref). +For more details on the internal storage, see documentation of [`TransformedValue`](@ref) +and [`VarNamedTuple`](@ref).
# Fields $(TYPEDFIELDS) @@ -157,8 +170,6 @@ function setaccs!!(vi::VarInfo, accs::AccumulatorTuple) return VarInfo(vi.transform_strategy, vi.values, accs) end -transformation(::VarInfo) = DynamicTransformation() - function Base.copy(vi::VarInfo) return VarInfo(vi.transform_strategy, copy(vi.values), copy(getaccs(vi))) end @@ -171,12 +182,7 @@ Base.keys(vi::VarInfo) = keys(vi.values) # Union{Vector{Union{}}, Vector{Float64}} (I suppose this is because it can't tell whether # the result will be empty or not...? Not sure). function Base.values(vi::VarInfo) - return mapreduce( - p -> DynamicPPL.get_transform(p.second)(DynamicPPL.get_internal_value(p.second)), - push!, - vi.values; - init=Any[], - ) + return mapreduce(p -> p.second, push!, vi.values; init=Any[]) end function Base.show(io::IO, ::MIME"text/plain", vi::VarInfo) @@ -202,14 +208,6 @@ function Base.show(io::IO, ::MIME"text/plain", vi::VarInfo) return nothing end -function Base.getindex(vi::VarInfo, vn::VarName) - tv = getindex(vi.values, vn) - return DynamicPPL.get_transform(tv)(DynamicPPL.get_internal_value(tv)) -end -function Base.getindex(vi::VarInfo, vns::AbstractVector{<:VarName}) - return [getindex(vi, vn) for vn in vns] -end - Base.isempty(vi::VarInfo) = isempty(vi.values) Base.empty(vi::VarInfo) = VarInfo(UnlinkAll(), empty(vi.values), map(reset, vi.accs)) function BangBang.empty!!(vi::VarInfo) @@ -225,7 +223,7 @@ This does not change the transformation or linked status of the variable. 
""" function setindex_internal!!(vi::VarInfo, val, vn::VarName) old_tv = getindex(vi.values, vn) - new_tv = set_internal_value(old_tv, val) + new_tv = TransformedValue(val, old_tv.transform) new_values = setindex!!(vi.values, new_tv, vn) return VarInfo(vi.transform_strategy, new_values, vi.accs) end @@ -235,62 +233,82 @@ end tfm_strategy::AbstractTransformStrategy, vi_is_empty::Bool, new_vn::VarName, - new_vn_is_linked::Bool + new_vn_transform::AbstractTransform ) -Given an old transform strategy `tfm_strategy`, and the linked status of a new variable -`new_vn` to be added to a `VarInfo` with that transform strategy, return an updated +Given an old transform strategy `tfm_strategy`, and the transformation of a new variable +`new_vn` to be added to a `VarInfo` which has that transform strategy, return an updated transform strategy that accounts for the addition of `new_vn`. """ +function update_transform_strategy(::LinkAll, ::Bool, ::VarName, ::DynamicLink) + return LinkAll() +end +function update_transform_strategy(::LinkAll, vi_is_empty::Bool, new_vn::VarName, ::Unlink) + return vi_is_empty ? 
UnlinkAll() : UnlinkSome(Set([new_vn]), LinkAll()) +end + +function update_transform_strategy(ls::LinkSome, ::Bool, vn::VarName, ::DynamicLink) + return if vn in ls.vns + ls + else + LinkSome(Set([vn]) ∪ ls.vns, ls.fallback) + end +end +function update_transform_strategy(ls::LinkSome, ::Bool, vn::VarName, ::Unlink) + return if vn in ls.vns + LinkSome(setdiff(ls.vns, Set([vn])), ls.fallback) + else + ls + end +end + +function update_transform_strategy(::UnlinkAll, ::Bool, ::VarName, ::Unlink) + return UnlinkAll() +end function update_transform_strategy( - tfm_strategy::AbstractTransformStrategy, - vi_is_empty::Bool, - new_vn::VarName, - new_vn_is_linked::Bool, + ::UnlinkAll, vi_is_empty::Bool, new_vn::VarName, ::DynamicLink ) - if new_vn_is_linked - if tfm_strategy isa LinkAll || vi_is_empty - LinkAll() - elseif target_transform(tfm_strategy, new_vn) isa DynamicLink - # can reuse - tfm_strategy - else - # have to wrap - LinkSome(Set([new_vn]), tfm_strategy) - end + return vi_is_empty ? LinkAll() : LinkSome(Set([new_vn]), UnlinkAll()) +end + +function update_transform_strategy(ls::UnlinkSome, ::Bool, vn::VarName, ::Unlink) + return if vn in ls.vns + ls else - if tfm_strategy isa UnlinkAll || vi_is_empty - UnlinkAll() - elseif target_transform(tfm_strategy, new_vn) isa Unlink - tfm_strategy - else - UnlinkSome(Set([new_vn]), tfm_strategy) - end + UnlinkSome(Set([vn]) ∪ ls.vns, ls.fallback) + end +end +function update_transform_strategy(ls::UnlinkSome, ::Bool, vn::VarName, ::DynamicLink) + return if vn in ls.vns + UnlinkSome(setdiff(ls.vns, Set([vn])), ls.fallback) + else + ls end end """ setindex_with_dist!!( vi::VarInfo, - tval::Union{VectorValue,LinkedVectorValue}, + tval::TransformedValue{<:AbstractVector{<:Real},<:Any}, dist::Distribution, vn::VarName, template::Any, ) -Set the value of `vn` in `vi` to `tval`. Note that this will cause the linked status of `vi` -to update according to what `tval` is. 
That means that whether or not a variable is -considered to be 'linked' is determined by `tval` rather than the previous status of `vi`. +Store a transformed value that has already been vectorised. This might include dynamically +transformed variables (which have `tval.transform` as a `DynamicLink` or `Unlink`), or +statically transformed variables (which have `tval.transform` as a `FixedTransform`). +However, in either case, it is mandatory that `tval.value` is a vector. + +Note that this will cause the linked status of `vi` to update according to what `tval` is. +That means that whether or not a variable is considered to be 'linked' is determined by +`tval` rather than the previous status of `vi`. """ function setindex_with_dist!!( - vi::VarInfo, - tval::Union{VectorValue,LinkedVectorValue}, - ::Distribution, - vn::VarName, - template::Any, -) + vi::VarInfo, tval::TransformedValue{T,V}, ::Distribution, vn::VarName, template::Any +) where {T<:AbstractVector{<:Real},V} new_transform_strategy = update_transform_strategy( - vi.transform_strategy, isempty(vi), vn, tval isa LinkedVectorValue + vi.transform_strategy, isempty(vi), vn, tval.transform ) return VarInfo( new_transform_strategy, templated_setindex!!(vi.values, tval, vn, template), vi.accs @@ -300,67 +318,69 @@ end """ setindex_with_dist!!( vi::VarInfo, - tval::UntransformedValue, + utval::TransformedValue{<:Any,NoTransform}, dist::Distribution, vn::VarName, template::Any ) -Vectorise `tval` (into a `VectorValue`) and store it. (Note that if `setindex_with_dist!!` -receives an `UntransformedValue`, the variable is always considered unlinked, since if it -were to be linked, `apply_transform_strategy` will already have done so.) +Vectorise `utval` and store it. (Note that if `setindex_with_dist!!` receives an +untransformed value, the variable is always considered unlinked, since if it were to be +linked, `apply_transform_strategy` will already have done so.) 
""" function setindex_with_dist!!( - vi::VarInfo, tval::UntransformedValue, dist::Distribution, vn::VarName, template -) + vi::VarInfo, + tval::TransformedValue{V,NoTransform}, + dist::Distribution, + vn::VarName, + template, +) where {V} raw_value = DynamicPPL.get_internal_value(tval) - tval = VectorValue( - Bijectors.VectorBijectors.to_vec(dist)(raw_value), - Bijectors.VectorBijectors.from_vec(dist), - ) + vectorised_value = Bijectors.VectorBijectors.to_vec(dist)(raw_value) + tval = TransformedValue(vectorised_value, Unlink()) + return setindex_with_dist!!(vi, tval, dist, vn, template) +end +function setindex_with_dist!!( + vi::VarInfo, + tval::TransformedValue{V,NoTransform}, + dist::Distribution, + vn::VarName, + template, +) where {V<:AbstractVector{<:Real}} + # This method is needed for resolving ambiguities. It does the same thing as + # above, but skipping the vectorisation step, since to_vec(dist) for a vector + # is always identity. + tval = TransformedValue(DynamicPPL.get_internal_value(tval), Unlink()) return setindex_with_dist!!(vi, tval, dist, vn, template) end """ set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) -Set the linked status of variable `vn` in `vi` to `linked`. +If `linked`, set the variable `vn` in `vi` to be linked (i.e., change its stored transform +to be `DynamicLink()`). Otherwise, set it to be unlinked (i.e., change its stored transform +to be `Unlink()`). This will also update the transform strategy of `vi` accordingly. -Note that this function is potentially unsafe as it does not change the value or -transformation of the variable! +!!! warning + Note that this function is potentially unsafe as it does not change the value of the + variable! """ function set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) + # TODO!(penelopeysm): Why do we still need this? 
old_tv = getindex(vi.values, vn) - new_tv = if linked - LinkedVectorValue(old_tv.val, old_tv.transform) - else - VectorValue(old_tv.val, old_tv.transform) - end + new_transform = linked ? DynamicLink() : Unlink() + new_tv = TransformedValue(old_tv.value, new_transform) new_values = setindex!!(vi.values, new_tv, vn) new_transform_strategy = update_transform_strategy( - vi.transform_strategy, isempty(vi), vn, linked + vi.transform_strategy, isempty(vi), vn, new_transform ) return VarInfo(new_transform_strategy, new_values, vi.accs) end -# VarInfo does not care whether the transformation was Static or Dynamic, it just tracks -# whether one was applied at all. -function set_transformed!!(vi::VarInfo, ::AbstractTransformation, vn::VarName) - return set_transformed!!(vi, true, vn) -end - -set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) - -function set_transformed!!(vi::VarInfo, ::NoTransformation, vn::VarName) - return set_transformed!!(vi, false, vn) -end - -set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false) - function set_transformed!!(vi::VarInfo, linked::Bool) - ctor = linked ? LinkedVectorValue : VectorValue + tfm = linked ? DynamicLink() : Unlink() new_values = map_values!!(vi.values) do tv - ctor(tv.val, tv.transform) + TransformedValue(tv.value, tfm) end new_transform_strategy = linked ? LinkAll() : UnlinkAll() return VarInfo(new_transform_strategy, new_values, vi.accs) @@ -371,14 +391,12 @@ end Get the internal (vectorised) value of variable `vn` in `vi`. """ -getindex_internal(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).val -# TODO(mhauru) The below should be removed together with unflatten!!. -getindex_internal(vi::VarInfo, ::Colon) = internal_values_as_vector(vi) +getindex_internal(vi::VarInfo, vn::VarName) = get_internal_value(getindex(vi.values, vn)) """ get_transformed_value(vi::VarInfo, vn::VarName) -Get the entire `AbstractTransformedValue` for variable `vn` in `vi`. 
+Get the entire `TransformedValue` for variable `vn` in `vi`. """ get_transformed_value(vi::VarInfo, vn::VarName) = getindex(vi.values, vn) @@ -388,27 +406,8 @@ function is_transformed(vi::VarInfo, vn::VarName) elseif vi.transform_strategy isa UnlinkAll false else - getindex(vi.values, vn) isa LinkedVectorValue - end -end - -function from_internal_transform(::VarInfo, ::VarName, dist::Distribution) - return Bijectors.VectorBijectors.from_vec(dist) -end - -function from_linked_internal_transform(::VarInfo, ::VarName, dist::Distribution) - return Bijectors.VectorBijectors.from_linked_vec(dist) -end - -function from_internal_transform(vi::VarInfo, vn::VarName) - return DynamicPPL.get_transform(getindex(vi.values, vn)) -end - -function from_linked_internal_transform(vi::VarInfo, vn::VarName) - if !is_transformed(vi, vn) - error("Variable $vn is not linked; cannot get linked transformation.") + get_transformed_value(vi, vn).transform isa DynamicLink end - return DynamicPPL.get_transform(getindex(vi.values, vn)) end """ @@ -426,14 +425,14 @@ This is the inverse of [`unflatten!!`](@ref). internal_values_as_vector(vi::VarInfo) = internal_values_as_vector(vi.values) """ - DynamicPPL.update_link_status!!( + DynamicPPL.update_transform_status!!( orig_vi::VarInfo, transform_strategy::AbstractTransformStrategy, model::Model ) Given an original `VarInfo` `orig_vi`, update the link status of its variables according to the new `transform_strategy`. 
""" -function update_link_status!!( +function update_transform_status!!( orig_vi::VarInfo, transform_strategy::AbstractTransformStrategy, model::Model ) # We'll just recalculate logjac from the start, rather than trying to adjust the old @@ -449,50 +448,22 @@ function update_link_status!!( return VarInfo(transform_strategy, new_vector_vals.values, orig_vi.accs) end -function link!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) - return update_link_status!!(vi, LinkSome(Set(vns), get_transform_strategy(vi)), model) -end -function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) - return update_link_status!!(vi, UnlinkSome(Set(vns), get_transform_strategy(vi)), model) -end -function link!!(::DynamicTransformation, vi::VarInfo, model::Model) - return update_link_status!!(vi, LinkAll(), model) +# These are mostly convenience functions +function link!!(vi::VarInfo, vns, model::Model) + return update_transform_status!!( + vi, LinkSome(Set(vns), get_transform_strategy(vi)), model + ) end -function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) - return update_link_status!!(vi, UnlinkAll(), model) +function invlink!!(vi::VarInfo, vns, model::Model) + return update_transform_status!!( + vi, UnlinkSome(Set(vns), get_transform_strategy(vi)), model + ) end - -function link!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::Model) - # TODO(mhauru) This assumes that the user has defined the bijector using the same - # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. This is a bad user - # interface. - b = inverse(t.bijector) - x = vi[:] - y, logjac = with_logabsdet_jacobian(b, x) - # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant - # that getindex(vi, vn) would apply the default link transform of the distribution. With - # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any - # link transform. 
Neither is correct, rather the transform should be the inverse of b. - vi = unflatten!!(vi, y) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, logjac) - end - return set_transformed!!(vi, t) +function link!!(vi::VarInfo, model::Model) + return update_transform_status!!(vi, LinkAll(), model) end - -function invlink!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::Model) - b = t.bijector - y = vi[:] - x, inv_logjac = with_logabsdet_jacobian(b, y) - - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - vi = unflatten!!(vi, x) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, inv_logjac) - end - return set_transformed!!(vi, NoTransformation()) +function invlink!!(vi::VarInfo, model::Model) + return update_transform_status!!(vi, UnlinkAll(), model) end """ @@ -523,16 +494,14 @@ mutable struct VectorChunkIterator!{T<:AbstractVector} vec::T index::Int end -for T in (:VectorValue, :LinkedVectorValue) - @eval begin - function (vci::VectorChunkIterator!)(tv::$T) - old_val = tv.val - len = length(old_val) - new_val = @view vci.vec[(vci.index):(vci.index + len - 1)] - vci.index += len - return $T(new_val, tv.transform) - end - end +function (vci::VectorChunkIterator!)( + tv::TransformedValue{V,T} +) where {V<:AbstractVector{<:Real},T} + old_val = tv.value + len = length(old_val) + new_val = @view vci.vec[(vci.index):(vci.index + len - 1)] + vci.index += len + return TransformedValue(new_val, tv.transform) end function unflatten!!(vi::VarInfo, vec::AbstractVector) vci = VectorChunkIterator!(vec, 1) @@ -568,6 +537,7 @@ function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) new_values = merge(varinfo_left.values, varinfo_right.values) new_accs = map(copy, getaccs(varinfo_right)) new_transform_strategy = + # Some shortcircuits to maximise type stability if varinfo_left.transform_strategy isa 
LinkAll && varinfo_right.transform_strategy isa LinkAll LinkAll() @@ -575,22 +545,7 @@ function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) varinfo_right.transform_strategy isa UnlinkAll UnlinkAll() else - linked_vns = Set{VarName}() - unlinked_vns = Set{VarName}() - for (vn, tval) in pairs(new_values) - if tval isa LinkedVectorValue - push!(linked_vns, vn) - else - push!(unlinked_vns, vn) - end - end - if isempty(linked_vns) - UnlinkAll() - elseif isempty(unlinked_vns) - LinkAll() - else - LinkSome(linked_vns, UnlinkSome(unlinked_vns, LinkAll())) - end + infer_transform_strategy_from_values(new_values) end return VarInfo(new_transform_strategy, new_values, new_accs) end diff --git a/test/accumulators.jl b/test/accumulators.jl index ba2585f86..581ea64dc 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -116,7 +116,7 @@ TEST_ACCUMULATORS = ( @testset "accumulate_assume" begin val = 2.0 - tval = DynamicPPL.UntransformedValue(nothing) + tval = DynamicPPL.TransformedValue(nothing, NoTransform()) logjac = pi vn = @varname(x) dist = Normal() diff --git a/test/bijector.jl b/test/bijector.jl deleted file mode 100644 index ef1bc14ff..000000000 --- a/test/bijector.jl +++ /dev/null @@ -1,42 +0,0 @@ -module DynamicPPLBijectorTests - -using Dates: now -@info "Testing $(@__FILE__)..." 
-__now__ = now() - -using Distributions -using DynamicPPL -using Bijectors: bijector, inverse -using Test - -@testset "bijector.jl" begin - @testset "bijector" begin - @model function test() - m ~ Normal() - s ~ InverseGamma(3, 3) - return c ~ Dirichlet([1.0, 1.0]) - end - - m = test() - b = bijector(m) - - # m ∈ ℝ, s ∈ ℝ+, c ∈ 2-simplex - # check dimensionalities and ranges - @test b.length_in == 4 - @test b.length_out == 3 - @test b.ranges_in == [1:1, 2:2, 3:4] - @test b.ranges_out == [1:1, 2:2, 3:3] - @test b.ranges_out == [1:1, 2:2, 3:3] - - # check support of mapped variables - binv = inverse(b) - zs = mapslices(binv, randn(b.length_out, 10000); dims=1) - - @test all(zs[2, :] .≥ 0) - @test all(sum(zs[3:4, :]; dims=1) .≈ 1.0) - end -end - -@info "Completed $(@__FILE__) in $(now() - __now__)." - -end # module diff --git a/test/chains.jl b/test/chains.jl index 9bdee234d..4863fc7af 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -69,54 +69,25 @@ using Test end end -_safe_length(x) = length(x) -# This actually gives N^2 elements, although there are only really N(N+1)/2 parameters in -# the Cholesky factor. It doesn't really matter because we are comparing like for like i.e. -# both sides of the sum will have the same overcounting. -_safe_length(c::LinearAlgebra.Cholesky) = length(c.UL) - @testset "ParamsWithStats from LogDensityFunction" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS - if m.f === DynamicPPL.TestUtils.demo_static_transformation - # TODO(mhauru) These tests are broken for demo_static_transformation because - # vi[vn] doesn't know which transform it should apply to the internally stored - # value. This requires a rethink, either of StaticTransformation or of what the - # comparison in this test should be. 
- @test false broken = true - continue - end @testset "$transform_strategy" for transform_strategy in (UnlinkAll(), LinkAll()) - vi = VarInfo(m, InitFromPrior(), transform_strategy) - params = [x for x in vi[:]] - # Get the ParamsWithStats using LogDensityFunction - ldf = DynamicPPL.LogDensityFunction(m, getlogjoint, vi) - ps = ParamsWithStats(params, ldf) - - # The keys are not necessarily going to be the same, because `ps.params` was - # obtained via RawValueAccumulator, which only stores raw values. However, `vi` - # stores TransformedValue objects. So, if you have something like - # x[4:5] ~ MvNormal(zeros(2), I) - # then `ps.params` will have keys `x[4]` and `x[5]` (since it just contains a - # PartialArray with those two elements unmasked), whereas `vi` will have the key - # `x[4:5]` which stores an ArrayLikeBlock with two elements. - # - # On top of that, the ParamsWithStats VNT will also have been `densify!!`-ed, so - # it may well just store a single vector `x` rather than individual `x[i]`'s. - # - # What we CAN do, though, is to check the size of the thing obtained by - # indexing into the keys. For `ps.params`, indexing into `x[4]` and `x[5]` will - # give two floats, each of "length" 1. For `vi`, indexing into `x[4:5]` will - # give a single object that has length 2. So we can check that the total number - # of _things_ contained inside is the same. - # - # Unfortunately, we need _safe_length to handle Cholesky. - @test sum(_safe_length(ps.params[vn]) for vn in keys(ps.params)) == - sum(_safe_length(vi[vn]) for vn in keys(vi)) + ldf = DynamicPPL.LogDensityFunction(m, getlogjoint, transform_strategy) + param_vector = rand(ldf) + # This will give us a VNT of values (the `params` field). + actual_vnt = ParamsWithStats(param_vector, ldf).params + # We should make sure that those values line up with the values inside the vector.
+ accs = OnlyAccsVarInfo(RawValueAccumulator(true)) + _, accs = DynamicPPL.init!!( + m, accs, InitFromVector(param_vector, ldf), transform_strategy + ) + expected_vnt = DynamicPPL.densify!!(get_raw_values(accs)) # Iterate over all variables to check that their values match - for vn in keys(vi) - @test ps.params[vn] == vi[vn] + @test Set(keys(actual_vnt)) == Set(keys(expected_vnt)) + for vn in keys(actual_vnt) + @test actual_vnt[vn] == expected_vnt[vn] end end end diff --git a/test/compiler.jl b/test/compiler.jl index 8b3f184ae..55f5adb35 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -489,7 +489,7 @@ module Issue537 end end m = demo3(1000.0, missing) # Mean of `y` should be close to 1000. - @test abs(mean([VarInfo(m)[@varname(y)] for i in 1:10]) - 1000) ≤ 10 + @test abs(mean([rand(m)[@varname(y)] for i in 1:10]) - 1000) ≤ 10 # Prefixed submodels and usage of submodel return values. @model function demo_return(x) @@ -507,7 +507,7 @@ module Issue537 end @test @varname(sub1.x) ∈ ks @test @varname(sub2.x) ∈ ks @test @varname(z) ∈ ks - @test abs(mean([VarInfo(m)[@varname(z)] for i in 1:10]) - 100) ≤ 10 + @test abs(mean([rand(m)[@varname(z)] for i in 1:10]) - 100) ≤ 10 # AR1 model. Dynamic prefixing. @model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV} @@ -758,7 +758,7 @@ module Issue537 end @test model() isa NamedTuple{(:x, :y)} # `VarInfo` should only contain `x`. - varinfo = VarInfo(model) + varinfo = rand(model) @test haskey(varinfo, @varname(x)) @test !haskey(varinfo, @varname(y)) @@ -792,21 +792,15 @@ module Issue537 end end # As above, but the variables should now have their names prefixed with `b.a`. 
model = demo_tracked_subsubmodel_prefix() - varinfo = VarInfo(model) - @test haskey(varinfo, @varname(b.a.x)) - @test length(keys(varinfo)) == 1 + vnt = rand(model) + @test haskey(vnt, @varname(b.a.x)) + @test length(keys(vnt)) == 1 vi = OnlyAccsVarInfo((RawValueAccumulator(true),)) _, vi = init!!(model, vi, InitFromPrior(), UnlinkAll()) values = get_raw_values(vi) @test haskey(values, @varname(b.a.x)) @test haskey(values, @varname(b.a.y)) - - vi = OnlyAccsVarInfo((RawValueAccumulator(false),)) - _, vi = init!!(model, vi, InitFromPrior(), UnlinkAll()) - values = get_raw_values(vi) - @test haskey(values, @varname(b.a.x)) - @test length(keys(varinfo)) == 1 end @testset "signature parsing + TypeWrap" begin @@ -866,7 +860,7 @@ module Issue537 end retval, vi = DynamicPPL.init!!(nt(data), VarInfo()) @test retval == 5.0 @test vi isa VarInfo - @test vi[@varname(m)] isa Real + @test only(DynamicPPL.getindex_internal(vi, @varname(m))) isa Real end @testset "convert_model_argument" begin diff --git a/test/conditionfix.jl b/test/conditionfix.jl index b5de9df31..ef91a67b6 100644 --- a/test/conditionfix.jl +++ b/test/conditionfix.jl @@ -372,11 +372,12 @@ DynamicPPL.setchildcontext(::MyParentContext, child) = MyParentContext(child) return data.x end fixm = DynamicPPL.fix(ntfix(), (; data=(; x=5.0))) - retval, vi = DynamicPPL.init!!(fixm, VarInfo()) + accs = OnlyAccsVarInfo(RawValueAccumulator(false)) + retval, accs = DynamicPPL.init!!(fixm, accs, InitFromPrior(), UnlinkAll()) # The fixed data should overwrite the NamedTuple that came before it @test retval == 5.0 - @test vi isa VarInfo - @test vi[@varname(m)] isa Real + # `m` should still be sampled + @test get_raw_values(accs)[@varname(m)] isa Real end @testset "can condition/fix on each individual part of a multivariate" begin diff --git a/test/contexts/init.jl b/test/contexts/init.jl index e775bbd0e..5cb33c16c 100644 --- a/test/contexts/init.jl +++ b/test/contexts/init.jl @@ -28,7 +28,8 @@ using Test this_vi = 
deepcopy(empty_vi) _, vi = DynamicPPL.init!!(model, this_vi, strategy, UnlinkAll()) @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) - x, y = vi[@varname(x)], vi[@varname(y)] + x = only(DynamicPPL.getindex_internal(vi, @varname(x))) + y = DynamicPPL.getindex_internal(vi, @varname(y)) @test x isa Real @test y isa AbstractVector{<:Real} @test length(y) == 2 @@ -48,21 +49,22 @@ using Test old_x, old_y = 100000.00, [300000.00, 500000.00] vi = DynamicPPL.setindex_with_dist!!( vi, - UntransformedValue(old_x), + TransformedValue(old_x, NoTransform()), Normal(), @varname(x), DynamicPPL.NoTemplate(), ) vi = DynamicPPL.setindex_with_dist!!( vi, - UntransformedValue(old_y), + TransformedValue(old_y, NoTransform()), MvNormal(fill(old_x, 2), I), @varname(y), DynamicPPL.NoTemplate(), ) # then overwrite it _, new_vi = DynamicPPL.init!!(model, vi, strategy, UnlinkAll()) - new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + new_x = only(DynamicPPL.getindex_internal(new_vi, @varname(x))) + new_y = DynamicPPL.getindex_internal(new_vi, @varname(y)) # check that the values are (presumably) different @test old_x != new_x @test old_y != new_y @@ -72,14 +74,23 @@ using Test function test_rng_respected(strategy::AbstractInitStrategy) @testset "check that RNG is respected: $(typeof(strategy))" begin model = test_init_model() - empty_vi = VarInfo() - _, vi1 = DynamicPPL.init!!(Xoshiro(468), model, deepcopy(empty_vi), strategy) - _, vi2 = DynamicPPL.init!!(Xoshiro(468), model, deepcopy(empty_vi), strategy) - _, vi3 = DynamicPPL.init!!(Xoshiro(469), model, deepcopy(empty_vi), strategy) - @test vi1[@varname(x)] == vi2[@varname(x)] - @test vi1[@varname(y)] == vi2[@varname(y)] - @test vi1[@varname(x)] != vi3[@varname(x)] - @test vi1[@varname(y)] != vi3[@varname(y)] + accs = OnlyAccsVarInfo((RawValueAccumulator(false),)) + _, accs1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(accs), strategy, UnlinkAll() + ) + _, accs2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(accs), 
strategy, UnlinkAll() + ) + _, accs3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(accs), strategy, UnlinkAll() + ) + vnt1 = get_raw_values(accs1) + vnt2 = get_raw_values(accs2) + vnt3 = get_raw_values(accs3) + @test vnt1[@varname(x)] == vnt2[@varname(x)] + @test vnt1[@varname(y)] == vnt2[@varname(y)] + @test vnt1[@varname(x)] != vnt3[@varname(x)] + @test vnt1[@varname(y)] != vnt3[@varname(y)] end end @@ -116,10 +127,10 @@ using Test for vn in (@varname(a), @varname(b)) if DynamicPPL.target_transform(transform_strategy, vn) isa DynamicLink @test DynamicPPL.is_transformed(vi, vn) - # The VarInfo should hold a LinkedVectorValue - lvv = vi.values[vn] - @test lvv isa LinkedVectorValue - linked_vec = DynamicPPL.get_internal_value(lvv) + # The VarInfo should hold a linked value + tv = vi.values[vn] + @test tv.transform isa DynamicLink + linked_vec = DynamicPPL.get_internal_value(tv) val, inv_logjac = Bijectors.with_logabsdet_jacobian( from_linked_vec, linked_vec ) @@ -127,11 +138,11 @@ using Test expected_logjac -= inv_logjac else @test !DynamicPPL.is_transformed(vi, vn) - # The VarInfo should hold a VectorValue - vv = vi.values[vn] - @test vv isa VectorValue + # The VarInfo should hold a non-linked value + tv = vi.values[vn] + @test tv.transform isa Unlink # it should wrap a single value - val = only(DynamicPPL.get_internal_value(vv)) + val = only(DynamicPPL.get_internal_value(tv)) expected_logprior += logpdf(dist, val) end end @@ -150,9 +161,11 @@ using Test @testset "check that values are within support" begin @model just_unif() = x ~ Uniform(0.0, 1e-7) for _ in 1:100 - _, vi = DynamicPPL.init!!(just_unif(), VarInfo(), InitFromPrior()) - @test vi[@varname(x)] isa Real - @test 0.0 <= vi[@varname(x)] <= 1e-7 + accs = OnlyAccsVarInfo((RawValueAccumulator(false),)) + _, accs = DynamicPPL.init!!(just_unif(), accs, InitFromPrior(), UnlinkAll()) + x = get_raw_values(accs)[@varname(x)] + @test x isa Real + @test 0.0 <= x <= 1e-7 end end diff --git 
a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl index 32c4bb479..7d0a2e43d 100644 --- a/test/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/test/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -77,26 +77,40 @@ using ADTypes: AutoForwardDiff vi_unlinked = VarInfo(model) vi_linked = DynamicPPL.link(vi_unlinked, model) + function get_raw_values_from_tvals(vi::VarInfo) + # TODO(penelopeysm) Fix this in the source code itself. + init_strat = InitFromParams(vi.values) + accs = OnlyAccsVarInfo(RawValueAccumulator(false)) + _, accs = init!!(model, accs, init_strat, UnlinkAll()) + return get_raw_values(accs) + end + @testset "unlinked VarInfo" begin mx = marginalize(model, [@varname(x)]; varinfo=vi_unlinked) mx([0.5]) # evaluate at some point to force calculation of Laplace approx vi = VarInfo(mx) - @test vi[@varname(x)] ≈ mode(Normal()) + vnt = get_raw_values_from_tvals(vi) + @test vnt[@varname(x)] ≈ mode(Normal()) + vi = VarInfo(mx, [0.5]) # this 0.5 is unlinked - @test vi[@varname(x)] ≈ mode(Normal()) - @test vi[@varname(y)] ≈ 0.5 + vnt = get_raw_values_from_tvals(vi) + @test vnt[@varname(x)] ≈ mode(Normal()) + @test vnt[@varname(y)] ≈ 0.5 end @testset "linked VarInfo" begin mx = marginalize(model, [@varname(x)]; varinfo=vi_linked) mx([0.5]) # evaluate at some point to force calculation of Laplace approx vi = VarInfo(mx) - @test vi[@varname(x)] ≈ mode(Normal()) + vnt = get_raw_values_from_tvals(vi) + @test vnt[@varname(x)] ≈ mode(Normal()) + vi = VarInfo(mx, [0.5]) # this 0.5 is linked - binv = Bijectors.inverse(Bijectors.bijector(Beta(2, 2))) - @test vi[@varname(x)] ≈ mode(Normal()) + vnt = get_raw_values_from_tvals(vi) + binv = Bijectors.VectorBijectors.from_linked_vec(Beta(2, 2)) + @test vnt[@varname(x)] ≈ mode(Normal()) # when using getindex it always returns unlinked values - @test vi[@varname(y)] ≈ binv(0.5) + @test vnt[@varname(y)] ≈ binv([0.5]) end end end diff --git a/test/linking.jl b/test/linking.jl index 
ba22ddc0e..35424c58b 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -95,7 +95,12 @@ end DynamicPPL.getlogjoint_internal(vi_linked) ≈ log(2) # The non-internal logjoint should be the same since it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint(vi_linked) - @test vi_linked[@varname(m)] == LowerTriangular(vi[@varname(m)]) + # The internal values should be convertible to the same thing. + @test Bijectors.VectorBijectors.from_linked_vec(dist)( + vi_linked.values[@varname(m)].value + ) ≈ LowerTriangular( + Bijectors.VectorBijectors.from_vec(dist)(vi.values[@varname(m)].value) + ) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @test length(vi_linked[:]) == length(y) @@ -106,7 +111,12 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == length(vi[:]) - @test vi_invlinked[@varname(m)] ≈ LowerTriangular(vi[@varname(m)]) + # The internal values should be convertible to the same thing. + @test Bijectors.VectorBijectors.from_vec(dist)( + vi_invlinked.values[@varname(m)].value + ) ≈ LowerTriangular( + Bijectors.VectorBijectors.from_vec(dist)(vi.values[@varname(m)].value) + ) # The non-internal logjoint should still be the same, again since # it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) @@ -122,11 +132,12 @@ end @testset "d=$d" for d in [2, 3, 5] model = demo_lkj(d) dist = LKJCholesky(d, 1.0, uplo) + fromvec_transform = Bijectors.VectorBijectors.from_vec(dist) values_original = NamedTuple(rand(model)) values_original_x_only = (x=values_original.x,) vis = DynamicPPL.TestUtils.setup_varinfos(model, values_original_x_only) @testset "$(short_varinfo_name(vi))" for vi in vis - val = vi[@varname(x)] + val = fromvec_transform(vi.values[@varname(x)].value) # Ensure that `reconstruct` works as intended. 
@test val isa Cholesky @test val.uplo == uplo diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 7ccf4efcb..d1c4cb89b 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -21,10 +21,10 @@ using Mooncake: Mooncake dist = Beta(2, 2) @model f() = x ~ dist expected_ral_unlinked = @vnt begin - x := DynamicPPL.RangeAndLinked(1:1, false) + x := DynamicPPL.RangeAndTransform(1:1, Unlink()) end expected_ral_linked = @vnt begin - x := DynamicPPL.RangeAndLinked(1:1, true) + x := DynamicPPL.RangeAndTransform(1:1, DynamicLink()) end oavi_unlinked = begin accs = OnlyAccsVarInfo(VectorValueAccumulator()) diff --git a/test/model.jl b/test/model.jl index f5ad0e00b..78186d97e 100644 --- a/test/model.jl +++ b/test/model.jl @@ -65,23 +65,27 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() model = GDEMO_DEFAULT # sample from model and extract variables - vi = VarInfo(model) - s = vi[@varname(s)] - m = vi[@varname(m)] + accs = OnlyAccsVarInfo() + accs = setacc!!(accs, RawValueAccumulator(false)) + _, accs = init!!(model, accs, InitFromPrior(), UnlinkAll()) + raw_values = get_raw_values(accs) + + s = raw_values[@varname(s)] + m = raw_values[@varname(m)] # extract log pdf of variable object - lp = getlogjoint(vi) + lp = getlogjoint(accs) # log prior probability - lprior = logprior(model, vi) + lprior = logprior(model, raw_values) @test lprior ≈ logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) # log likelihood - llikelihood = loglikelihood(model, vi) + llikelihood = loglikelihood(model, raw_values) @test llikelihood ≈ loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0]) # log joint probability - ljoint = logjoint(model, vi) + ljoint = logjoint(model, raw_values) @test ljoint ≈ lprior + llikelihood @test ljoint ≈ lp @@ -106,11 +110,11 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() # Construct mapping of varname symbols to varname-parent symbols. 
# Here, varname_leaves is used to ensure compatibility with the # variables stored in the chain - var_info = VarInfo(model) + vnt = rand(model) chain_sym_map = Dict{Symbol,Symbol}() - for vn_parent in keys(var_info) + for vn_parent in keys(vnt) sym = DynamicPPL.getsym(vn_parent) - vn_children = AbstractPPL.varname_leaves(vn_parent, var_info[vn_parent]) + vn_children = AbstractPPL.varname_leaves(vn_parent, vnt[vn_parent]) for vn_child in vn_children chain_sym_map[Symbol(vn_child)] = sym end @@ -283,21 +287,6 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) end - @testset "Dynamic constraints" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - vi = VarInfo(model) - vi = link!!(vi, model) - - for i in 1:10 - # Sample with large variations. - r_raw = randn(length(vi[:])) * 10 - vi = DynamicPPL.unflatten!!(vi, r_raw) - @test vi[@varname(m)] == r_raw[1] - @test vi[@varname(x)] != r_raw[2] - model(vi) - end - end - @testset "rand" begin model = GDEMO_DEFAULT @@ -441,21 +430,21 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() vns = DynamicPPL.TestUtils.varnames(model) vns_split = DynamicPPL.TestUtils.varnames_split(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # We can set the include_colon_eq arg to false because none of - # the demo models contain :=. The behaviour when - # include_colon_eq is true is tested in test/compiler.jl - varinfo = DynamicPPL.setacc!!(varinfo, RawValueAccumulator(false)) - _, varinfo = init!!(model, varinfo, InitFromPrior(), UnlinkAll()) - realizations = get_raw_values(varinfo) - # Ensure that all variables are found. 
- vns_found = collect(keys(realizations)) - @test vns_split ∩ vns_found == vns_split ∪ vns_found - # Ensure that the values are the same. - for vn in vns - test_is_equal(realizations[vn], varinfo[vn]) - end + + # We can set the include_colon_eq arg to false because none of + # the demo models contain :=. The behaviour when + # include_colon_eq is true is tested in test/compiler.jl + accs = OnlyAccsVarInfo(RawValueAccumulator(false)) + _, accs = init!!( + model, accs, InitFromParams(example_values, nothing), UnlinkAll() + ) + raw_vals = get_raw_values(accs) + # Ensure that all variables are found. + vns_found = collect(keys(raw_vals)) + @test vns_split ∩ vns_found == vns_split ∪ vns_found + # Ensure that the values are the same. + for vn in vns + test_is_equal(raw_vals[vn], AbstractPPL.getvalue(example_values, vn)) end end diff --git a/test/runtests.jl b/test/runtests.jl index c9a6ad295..1c6644ca1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,7 +45,6 @@ Random.seed!(100) end if GROUP in [TEST_GROUP_ALL, TEST_GROUP_GROUP2] - include("bijector.jl") include("logdensityfunction.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") diff --git a/test/submodels.jl b/test/submodels.jl index c17c044e1..867cd97e9 100644 --- a/test/submodels.jl +++ b/test/submodels.jl @@ -10,6 +10,13 @@ mutable struct P b::Float64 end +function get_logp_and_rawval_accs(model::Model) + accs = OnlyAccsVarInfo() + accs = setacc!!(accs, RawValueAccumulator(false)) + _, accs = init!!(model, accs, InitFromPrior(), UnlinkAll()) + return accs +end + @testset "submodels.jl" begin @testset "$op with AbstractPPL API" for op in [condition, fix] x_val = 1.0 @@ -40,10 +47,12 @@ end # Test that the value was correctly set @test model()[1] == x_val # Test that the logp was correctly set - vi = VarInfo(model) - @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.y)]) + accs = get_logp_and_rawval_accs(model) + raw_vals = get_raw_values(accs) + @test 
getlogjoint(accs) == + x_logp + logpdf(Normal(), raw_vals[@varname(a.y)]) # Check the keys - @test Set(keys(VarInfo(model))) == Set([@varname(a.y)]) + @test Set(keys(raw_vals)) == Set([@varname(a.y)]) end end @@ -72,10 +81,11 @@ end # Test that the value was correctly set @test model()[1] == x_val # Test that the logp was correctly set - vi = VarInfo(model) - @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(y)]) + accs = get_logp_and_rawval_accs(model) + raw_vals = get_raw_values(accs) + @test getlogjoint(accs) == x_logp + logpdf(Normal(), raw_vals[@varname(y)]) # Check the keys - @test Set(keys(VarInfo(model))) == Set([@varname(y)]) + @test Set(keys(raw_vals)) == Set([@varname(y)]) end end @@ -104,10 +114,12 @@ end # Test that the value was correctly set @test model()[1] == x_val # Test that the logp was correctly set - vi = VarInfo(model) - @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(b.y)]) + accs = get_logp_and_rawval_accs(model) + raw_vals = get_raw_values(accs) + @test getlogjoint(accs) == + x_logp + logpdf(Normal(), raw_vals[@varname(b.y)]) # Check the keys - @test Set(keys(VarInfo(model))) == Set([@varname(b.y)]) + @test Set(keys(raw_vals)) == Set([@varname(b.y)]) end end @@ -125,13 +137,13 @@ end return (p.a, p.b) end expected_vns = Set([@varname(p.a.x[1]), @varname(p.a.y), @varname(p.b)]) - @test Set(keys(VarInfo(g()))) == expected_vns + @test Set(keys(rand(g()))) == expected_vns # Check that we can condition/fix on any of them from the outside for vn in expected_vns op_g = op(g(), (vn => 1.0)) - vi = VarInfo(op_g) - @test Set(keys(vi)) == symdiff(expected_vns, Set([vn])) + vnt = rand(op_g) + @test Set(keys(vnt)) == symdiff(expected_vns, Set([vn])) end end @@ -148,11 +160,12 @@ end end # No conditioning - vi = VarInfo(h()) - @test Set(keys(vi)) == Set([@varname(a.b.x), @varname(a.b.y)]) - @test getlogjoint(vi) == - logpdf(Normal(), vi[@varname(a.b.x)]) + - logpdf(Normal(), vi[@varname(a.b.y)]) + accs = 
get_logp_and_rawval_accs(h()) + raw_vals = get_raw_values(accs) + @test Set(keys(raw_vals)) == Set([@varname(a.b.x), @varname(a.b.y)]) + @test getlogjoint(accs) == + logpdf(Normal(), raw_vals[@varname(a.b.x)]) + + logpdf(Normal(), raw_vals[@varname(a.b.y)]) # Conditioning/fixing at the top level op_h = op(h(), (@varname(a.b.x) => x_val)) @@ -174,9 +187,11 @@ end models = [("top", op_h), ("middle", h2()), ("bottom", h3())] @testset "$name" for (name, model) in models - vi = VarInfo(model) - @test Set(keys(vi)) == Set([@varname(a.b.y)]) - @test getlogjoint(vi) == x_logp + logpdf(Normal(), vi[@varname(a.b.y)]) + accs = get_logp_and_rawval_accs(model) + raw_vals = get_raw_values(accs) + @test Set(keys(raw_vals)) == Set([@varname(a.b.y)]) + @test getlogjoint(accs) == + x_logp + logpdf(Normal(), raw_vals[@varname(a.b.y)]) end end end @@ -190,11 +205,11 @@ end return a ~ to_submodel(f(inner_x)) end - vi = VarInfo(g(1.0)) - @test Set(keys(vi)) == Set([@varname(a.y)]) + vnt = rand(g(1.0)) + @test Set(keys(vnt)) == Set([@varname(a.y)]) - vi = VarInfo(g(missing)) - @test Set(keys(vi)) == Set([@varname(a.x), @varname(a.y)]) + vnt = rand(g(missing)) + @test Set(keys(vnt)) == Set([@varname(a.x), @varname(a.y)]) end @testset ":= in submodels" begin @@ -210,12 +225,12 @@ end end model = outer1() - a, vi = DynamicPPL.init!!(model, VarInfo()) - @test only(keys(vi)) == @varname(x.a) + vnt = rand(model) + @test only(keys(vnt)) == @varname(x.a) - vi = OnlyAccsVarInfo((RawValueAccumulator(true),)) - _, vi = init!!(model, vi, InitFromPrior(), UnlinkAll()) - vnt = get_raw_values(vi) + accs = OnlyAccsVarInfo((RawValueAccumulator(true),)) + a, accs = init!!(model, accs, InitFromPrior(), UnlinkAll()) + vnt = get_raw_values(accs) @test vnt[@varname(x.a)] == a @test vnt[@varname(x.b)] == vnt[@varname(x.a)] + 1.0 end @@ -235,12 +250,12 @@ end end model = outer2() - a, vi = DynamicPPL.init!!(model, VarInfo()) - @test only(keys(vi)) == @varname(x.a) + vnt = rand(model) + @test only(keys(vnt)) == 
@varname(x.a) - vi = OnlyAccsVarInfo((RawValueAccumulator(true),)) - _, vi = init!!(model, vi, InitFromPrior(), UnlinkAll()) - vnt = get_raw_values(vi) + accs = OnlyAccsVarInfo((RawValueAccumulator(true),)) + a, accs = init!!(model, accs, InitFromPrior(), UnlinkAll()) + vnt = get_raw_values(accs) @test vnt[@varname(x.a)] == a @test vnt[@varname(x.b[1])] == vnt[@varname(x.a)] + 1.0 # If the templating fails, then x.b will be stored as a GrowableArray, and @@ -288,14 +303,11 @@ end end end model = outer() - vi = VarInfo(model) - @test Set(keys(vi)) == Set([@varname(x[i].a) for i in 1:4]) + vnt = rand(model) + @test Set(keys(vnt)) == Set([@varname(x[i].a) for i in 1:4]) for i in 1:4 - # Need to be careful about what we're testing here. If we do vi[vn], then - # it expects that vi.values[vn] isa AbstractTransformedValue. That is true - # of the inner keys (x[i].a), but x[i] is not itself a key. - @test vi.values[@varname(x[i])] isa VarNamedTuple - @test vi[@varname(x[i].a)] isa Float64 + @test vnt[@varname(x[i])] isa VarNamedTuple + @test vnt[@varname(x[i].a)] isa Float64 end end end diff --git a/test/varinfo.jl b/test/varinfo.jl index bde2ab67b..e588e5940 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -18,6 +18,12 @@ function check_varinfo_keys(varinfo, vns) vns_varinfo = keys(varinfo) @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) end +function check_varinfo_values(varinfo1, varinfo2, vns) + for vn in vns + @test DynamicPPL.get_transformed_value(varinfo1, vn) == + DynamicPPL.get_transformed_value(varinfo2, vn) + end +end function check_metadata_type_equal(v1::VarInfo, v2::VarInfo) @test typeof(v1.values) == typeof(v2.values) @@ -60,21 +66,26 @@ end @test !haskey(vi, vn) @test !(vn in keys(vi)) - vi = DynamicPPL.setindex_with_dist!!(vi, UntransformedValue(x), Normal(), vn, x) + vi = DynamicPPL.setindex_with_dist!!( + vi, TransformedValue(x, NoTransform()), Normal(), vn, x + ) @test !isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) - @test 
length(vi[vn]) == 1 - @test vi[vn] == x + @test DynamicPPL.getindex_internal(vi, vn) == [x] @test vi[:] == [x] - vi = DynamicPPL.setindex_with_dist!!(vi, UntransformedValue(2 * x), Normal(), vn, x) - @test vi[vn] == 2 * x + vi = DynamicPPL.setindex_with_dist!!( + vi, TransformedValue(2 * x, NoTransform()), Normal(), vn, x + ) + @test DynamicPPL.getindex_internal(vi, vn) == [2 * x] @test vi[:] == [2 * x] vi = empty!!(vi) @test isempty(vi) - vi = DynamicPPL.setindex_with_dist!!(vi, UntransformedValue(x), Normal(), vn, x) + vi = DynamicPPL.setindex_with_dist!!( + vi, TransformedValue(x, NoTransform()), Normal(), vn, x + ) @test !isempty(vi) end @@ -250,7 +261,9 @@ end vn_x = @varname x x = rand() - vi = DynamicPPL.setindex_with_dist!!(vi, UntransformedValue(x), Normal(), vn_x, x) + vi = DynamicPPL.setindex_with_dist!!( + vi, TransformedValue(x, NoTransform()), Normal(), vn_x, x + ) # is_transformed is unset by default @test !is_transformed(vi, vn_x) @@ -291,20 +304,20 @@ end model = gdemo([1.0, 1.5], [2.0, 2.5]) all_transformed(vi) = mapreduce( - p -> p.second isa DynamicPPL.LinkedVectorValue, &, vi.values; init=true + p -> p.second.transform isa DynamicPPL.DynamicLink, &, vi.values; init=true ) any_transformed(vi) = mapreduce( - p -> p.second isa DynamicPPL.LinkedVectorValue, |, vi.values; init=false + p -> p.second.transform isa DynamicPPL.DynamicLink, |, vi.values; init=false ) # Check that linking and invlinking set the `is_transformed` flag accordingly vi = VarInfo(model) - vals = values(vi) + vals = vi[:] vi = link!!(vi, model) @test all_transformed(vi) vi = invlink!!(vi, model) @test !any_transformed(vi) - @test values(vi) ≈ vals atol = 1e-10 + @test vi[:] ≈ vals atol = 1e-10 # Transform only one variable all_vns = keys(vi) @@ -325,7 +338,7 @@ end @test !any_transformed(subset(vi, other_vns)) vi = invlink!!(vi, (vn,), model) @test !any_transformed(vi) - @test values(vi) ≈ vals atol = 1e-10 + @test vi[:] ≈ vals atol = 1e-10 end end @@ -350,11 +363,11 @@ end vi = 
VarInfo(Xoshiro(468), model, InitFromPrior(), transform_strategy) for vn in keys(vi) if vn in expected_linked_vns - @test DynamicPPL.get_transformed_value(vi, vn) isa - DynamicPPL.LinkedVectorValue + @test DynamicPPL.get_transformed_value(vi, vn).transform isa + DynamicPPL.DynamicLink else - @test DynamicPPL.get_transformed_value(vi, vn) isa - DynamicPPL.VectorValue + @test DynamicPPL.get_transformed_value(vi, vn).transform isa + DynamicPPL.Unlink end end # Test that initialising directly is the same as linking later (if rng is the @@ -568,7 +581,7 @@ end # Should now only contain the variables in `vns_subset`. check_varinfo_keys(varinfo_subset, vns_subset) # Values should be the same. - @test [varinfo_subset[vn] for vn in vns_subset] == [varinfo[vn] for vn in vns_subset] + check_varinfo_values(varinfo_subset, varinfo, vns_subset) # `merge` with the original. varinfo_merged = merge(varinfo, varinfo_subset) @@ -576,7 +589,7 @@ end # Should be equivalent. check_varinfo_keys(varinfo_merged, vns) # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + check_varinfo_values(varinfo_merged, varinfo, vns) end @testset "$(convert(Vector{VarName}, vns_subset))" for ( @@ -586,7 +599,7 @@ end # Should now only contain the variables in `vns_subset`. check_varinfo_keys(varinfo_subset, vns_target) # Values should be the same. - @test [varinfo_subset[vn] for vn in vns_target] == [varinfo[vn] for vn in vns_target] + check_varinfo_values(varinfo_subset, varinfo, vns_target) # `merge` with the original. varinfo_merged = merge(varinfo, varinfo_subset) @@ -594,7 +607,7 @@ end # Should be equivalent. check_varinfo_keys(varinfo_merged, vns) # Values should be the same. 
- @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + check_varinfo_values(varinfo_merged, varinfo, vns) end @testset "$(convert(Vector{VarName}, vns_subset)) order" for vns_subset in @@ -603,7 +616,9 @@ end vns_subset_reversed = reverse(vns_subset) varinfo_subset_reversed = subset(varinfo, vns_subset_reversed) @test varinfo_subset[:] == varinfo_subset_reversed[:] - ground_truth = [varinfo[vn] for vn in vns_subset] + ground_truth = [ + only(DynamicPPL.getindex_internal(varinfo, vn)) for vn in vns_subset + ] @test varinfo_subset[:] == ground_truth end end @@ -621,8 +636,12 @@ end varinfo_merged = merge(varinfo, varinfo) # Varnames should be unchanged. check_varinfo_keys(varinfo_merged, vns) - # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + # Values should be the same. (Have to use `get_values` since varinfo + # might be a TSVI) + for vn in keys(varinfo_merged) + @test DynamicPPL.get_values(varinfo_merged)[vn] == + DynamicPPL.get_values(varinfo)[vn] + end # Metadata types should be exactly the same. check_metadata_type_equal(varinfo_merged, varinfo) end @@ -633,7 +652,10 @@ end # Varnames should be unchanged. check_varinfo_keys(varinfo_merged, vns) # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + for vn in keys(varinfo_merged) + @test DynamicPPL.get_values(varinfo_merged)[vn] == + DynamicPPL.get_values(varinfo)[vn] + end # Metadata types should be exactly the same. check_metadata_type_equal(varinfo_merged, varinfo) end @@ -645,7 +667,10 @@ end # Varnames should be unchanged. check_varinfo_keys(varinfo_merged, vns) # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + for vn in keys(varinfo_merged) + @test DynamicPPL.get_values(varinfo_merged)[vn] == + DynamicPPL.get_values(varinfo)[vn] + end # Metadata types should be exactly the same.
check_metadata_type_equal(varinfo_merged, varinfo) @@ -655,7 +680,10 @@ end # Varnames should be unchanged. check_varinfo_keys(varinfo_merged, vns) # Values should be the same. - @test [varinfo_merged[vn] for vn in vns] == [varinfo[vn] for vn in vns] + for vn in keys(varinfo_merged) + @test DynamicPPL.get_values(varinfo_merged)[vn] == + DynamicPPL.get_values(varinfo)[vn] + end # Metadata types should be exactly the same. check_metadata_type_equal(varinfo_merged, varinfo) end @@ -665,9 +693,16 @@ end varinfo_changed = last( init!!(model, deepcopy(varinfo), InitFromParams(x, nothing)) ) - # After `merge`, we should have the same values as `x`. + # After `merge`, we should have the same values as `x` (or, to be + # precise, we have things that will give `x` after we reevaluate with + # those parameters). varinfo_merged = merge(varinfo, varinfo_changed) - DynamicPPL.TestUtils.test_values(varinfo_merged, x, vns) + init_strat = InitFromParams( + DynamicPPL.get_values(varinfo_merged), nothing + ) + accs = OnlyAccsVarInfo(RawValueAccumulator(false)) + _, accs = init!!(model, accs, init_strat, UnlinkAll()) + DynamicPPL.TestUtils.test_values(get_raw_values(accs), x, vns) end end end @@ -693,7 +728,7 @@ end check_varinfo_keys(varinfo_merged, vns) # Right has precedence. 
- @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] + @test varinfo_merged.values[@varname(x)] == varinfo_right.values[@varname(x)] @test DynamicPPL.is_transformed(varinfo_merged, @varname(x)) end end @@ -702,13 +737,17 @@ end @testset "merge different dimensions" begin vn = @varname(x) vi_single = DynamicPPL.setindex_with_dist!!( - VarInfo(), UntransformedValue(1.0), Normal(), vn, 1.0 + VarInfo(), TransformedValue(1.0, NoTransform()), Normal(), vn, 1.0 ) vi_double = DynamicPPL.setindex_with_dist!!( - VarInfo(), UntransformedValue([0.5, 0.6]), MvNormal(zeros(2), I), vn, [0.5, 0.6] + VarInfo(), + TransformedValue([0.5, 0.6], NoTransform()), + MvNormal(zeros(2), I), + vn, + [0.5, 0.6], ) - @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] - @test merge(vi_double, vi_single)[vn] == 1.0 + @test DynamicPPL.getindex_internal(merge(vi_single, vi_double), vn) == [0.5, 0.6] + @test DynamicPPL.getindex_internal(merge(vi_double, vi_single), vn) == [1.0] end @testset "issue #842" begin From 40ee83d887dd3bf8ff2013af5b0f685462d35f4a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Mar 2026 11:25:26 +0100 Subject: [PATCH 2/3] Generalise LinkedVecTransformAcc -> FixedTransformAcc --- HISTORY.md | 11 ++-- docs/src/accs/existing.md | 4 +- src/DynamicPPL.jl | 6 +- src/accumulators/fixed_transforms.jl | 76 +++++++++++++++++++++++ src/accumulators/linked_vec_transforms.jl | 32 ---------- 5 files changed, 86 insertions(+), 43 deletions(-) create mode 100644 src/accumulators/fixed_transforms.jl delete mode 100644 src/accumulators/linked_vec_transforms.jl diff --git a/HISTORY.md b/HISTORY.md index 94244092d..b665bdf81 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -35,8 +35,7 @@ In such cases, using `FixedTransform` can lead to substantial performance improv To use `FixedTransform` with `LogDensityFunction`, you need to: 1. Create a `VarNamedTuple` mapping `VarName`s to `FixedTransform`s for the variables in your model. 
- This can be done using `get_linked_vec_transforms(model)`, which automatically calculates `Bijectors.VectorBijectors.from_linked_vec(dist)` for each variable in the model. - TODO: Control whether it's linked or not???? + This can be done using `DynamicPPL.FixedTransformAccumulator` (see the DynamicPPL docs for more info), but is most easily done by calling `get_fixed_transforms(model, transform_strategy)`, where `transform_strategy` says whether you want linked or unlinked transforms. 2. Wrap the `VarNamedTuple` inside `WithTransforms(vnt, UnlinkAll())`. `WithTransforms` is a subtype of `AbstractTransformStrategy`, much like `LinkAll()`. @@ -59,13 +58,13 @@ In the current version, we have removed this method to prevent the possibility o In particular, to access raw (untransformed) values, you should use an `OnlyAccsVarInfo` with a `RawValueAccumulator`. There is [a migration guide available on the DynamicPPL documentation](https://turinglang.org/DynamicPPL.jl/stable/migration/) and we are very happy to add more examples to this if you run into something that is not covered. -### LinkedVecTransformAccumulator +### FixedTransformAccumulator TODO, this part is still being worked on. 
- `BijectorAccumulator` → `LinkedVecTransformAccumulator` - `get_linked_vec_transforms(::VarInfo)` - `get_linked_vec_transforms(::Model)` + - `BijectorAccumulator` → `FixedTransformAccumulator` + - `get_fixed_transforms(::AbstractVarInfo)` + - `get_fixed_transforms(::Model, ::AbstractTransformStrategy)` ## Miscellaneous breaking changes diff --git a/docs/src/accs/existing.md b/docs/src/accs/existing.md index f2e7f51ea..a11d6cb1c 100644 --- a/docs/src/accs/existing.md +++ b/docs/src/accs/existing.md @@ -42,6 +42,6 @@ get_vector_params ```@docs PriorDistributionAccumulator get_priors -LinkedVecTransformAccumulator -get_linked_vec_transforms +FixedTransformAccumulator +get_fixed_transforms ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 8bc540d9d..8d6e6ba74 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -102,8 +102,8 @@ export AbstractVarInfo, # Accumulators - miscellany PriorDistributionAccumulator, get_priors, - LinkedVecTransformAccumulator, - get_linked_vec_transforms, + FixedTransformAccumulator, + get_fixed_transforms, # Working with internal values as vectors unflatten!!, internal_values_as_vector, @@ -257,7 +257,7 @@ include("accumulators/vnt.jl") include("accumulators/vector_values.jl") include("accumulators/priors.jl") include("accumulators/raw_values.jl") -include("accumulators/linked_vec_transforms.jl") +include("accumulators/fixed_transforms.jl") include("accumulators/pointwise_logdensities.jl") include("abstract_varinfo.jl") include("threadsafe.jl") diff --git a/src/accumulators/fixed_transforms.jl b/src/accumulators/fixed_transforms.jl new file mode 100644 index 000000000..e1db053af --- /dev/null +++ b/src/accumulators/fixed_transforms.jl @@ -0,0 +1,76 @@ +const FIXED_TRANSFORM_ACCNAME = :FixedTransformAccumulator + +function _get_fixed_transform( + val, tv::TransformedValue{V,DynamicLink}, logjac, vn, dist +) where {V} + return FixedTransform(Bijectors.VectorBijectors.from_linked_vec(dist)) +end +function _get_fixed_transform( + val, 
tv::TransformedValue{V,Unlink}, logjac, vn, dist +) where {V} + return FixedTransform(Bijectors.VectorBijectors.from_vec(dist)) +end +function _get_fixed_transform( + val, tv::TransformedValue{V,<:FixedTransform}, logjac, vn, dist +) where {V} + return tv.transform +end + +""" + FixedTransformAccumulator() + +An accumulator that calculates and stores the 'fixed' transforms for all variables in a model. + +Normally, when running a model with a transform strategy such as `LinkAll`, the transforms are +calculated *during* model execution and not cached. This ensures that the transforms are up-to-date +with the current variable values, which can matter in cases such as + +```julia +x ~ Normal() +y ~ truncated(Normal(); lower=x) +``` + +or + +```julia +x ~ Normal() +y ~ (x > 0 ? Normal() : Exponential()) +``` + +where the transforms for `y` depend on the value of `x`. +The transforms stored by this accumulator are computed from the distributions seen in a single evaluation, so they should only be reused when they cannot change between evaluations. +""" +FixedTransformAccumulator() = VNTAccumulator{FIXED_TRANSFORM_ACCNAME}(_get_fixed_transform) + +""" + get_fixed_transforms(vi::DynamicPPL.AbstractVarInfo) + +Extract the transforms stored in the [`FixedTransformAccumulator`](@ref) of an +AbstractVarInfo. Errors if the AbstractVarInfo does not have a `FixedTransformAccumulator`. +""" +function get_fixed_transforms(vi::DynamicPPL.AbstractVarInfo) + return DynamicPPL.getacc(vi, Val(FIXED_TRANSFORM_ACCNAME)).values +end + +""" + get_fixed_transforms( + model::DynamicPPL.Model, + transform_strategy::AbstractTransformStrategy + ) + +Extract the fixed transforms for all variables in a model by running the model with the +given transform strategy. + +Note that, even though this method evaluates the model once, this method does *not* accept +an RNG argument to control that evaluation. This is because the fixed transforms are +supposed to be *fixed*, i.e., they should not depend on random choices made during model +execution! 
+""" +function get_fixed_transforms( + model::DynamicPPL.Model, transform_strategy::AbstractTransformStrategy +) + rng = Random.default_rng() + accs = OnlyAccsVarInfo(FixedTransformAccumulator()) + _, accs = init!!(rng, model, accs, InitFromPrior(), transform_strategy) + return get_fixed_transforms(accs) +end diff --git a/src/accumulators/linked_vec_transforms.jl b/src/accumulators/linked_vec_transforms.jl deleted file mode 100644 index 50ff897d0..000000000 --- a/src/accumulators/linked_vec_transforms.jl +++ /dev/null @@ -1,32 +0,0 @@ -const LINKEDVECTRANSFORM_ACCNAME = :LinkedVecTransformAccumulator -function _get_linked_vec_transform(val, tv, logjac, vn, dist) - return FixedTransform(Bijectors.VectorBijectors.from_linked_vec(dist)) -end - -""" - LinkedVecTransformAccumulator() - -An accumulator that stores the transform required to convert a linked vector into the -original, untransformed value. -""" -LinkedVecTransformAccumulator() = - VNTAccumulator{LINKEDVECTRANSFORM_ACCNAME}(_get_linked_vec_transform) - -""" - get_linked_vec_transforms(vi::DynamicPPL.AbstractVarInfo) - -Extract the transforms stored in the `LinkedVecTransformAccumulator` of an AbstractVarInfo. -Errors if the AbstractVarInfo does not have a `LinkedVecTransformAccumulator`. 
-""" -function get_linked_vec_transforms(vi::DynamicPPL.AbstractVarInfo) - return DynamicPPL.getacc(vi, Val(LINKEDVECTRANSFORM_ACCNAME)).values -end - -function get_linked_vec_transforms(rng::Random.AbstractRNG, model::DynamicPPL.Model) - accs = OnlyAccsVarInfo(LinkedVecTransformAccumulator()) - _, accs = init!!(rng, model, accs, InitFromPrior(), UnlinkAll()) - return get_linked_vec_transforms(accs) -end -function get_linked_vec_transforms(model::DynamicPPL.Model) - return get_linked_vec_transforms(Random.default_rng(), model) -end From 042b2cb64ee874fca6977bd8384d956ea646daee Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 30 Mar 2026 11:26:33 +0100 Subject: [PATCH 3/3] Fix test --- test/accumulators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/accumulators.jl b/test/accumulators.jl index 581ea64dc..1e209e4c5 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -31,7 +31,7 @@ TEST_ACCUMULATORS = ( LogJacobianAccumulator(1.0), RawValueAccumulator(false), DynamicPPL.DebugRawValueAccumulator(), - DynamicPPL.BijectorAccumulator(), + DynamicPPL.FixedTransformAccumulator(), DynamicPPL.VNTAccumulator{DynamicPPL.POINTWISE_ACCNAME}( DynamicPPL.PointwiseLogProb{true,true}() ),