Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ jobs:
- Molly
- MPI
- Comrade
- Turing
exclude:
- version: '1.10'
os: linux-x86-n2-32
Expand Down
2 changes: 1 addition & 1 deletion test/integration/Bijectors/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ Enzyme = { path = "../../.." }
EnzymeCore = { path = "../../../lib/EnzymeCore" }

[compat]
Bijectors = "=0.13.16"
Bijectors = "=0.15.14"
FiniteDifferences = "0.12.32"
StableRNGs = "1.0.2"
131 changes: 98 additions & 33 deletions test/integration/Bijectors/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@ module BijectorsIntegrationTests
using Bijectors: Bijectors
using Enzyme: Enzyme
using FiniteDifferences: FiniteDifferences
using LinearAlgebra: LinearAlgebra
using LinearAlgebra: I, cholesky, Hermitian
using Random: randn
using StableRNGs: StableRNG
using Test: @test, @test_broken, @testset
using Test: @test, @test_broken, @testset, @test_throws

rng = StableRNG(23)

"""
Enum type for choosing Enzyme autodiff modes.
"""
@enum ModeSelector Neither Forward Reverse Both
includes_forward(mode::ModeSelector) = mode === Forward || mode === Both
includes_reverse(mode::ModeSelector) = mode === Reverse || mode === Both

"""
Type for specifying a test case for `Enzyme.gradient`.
Expand Down Expand Up @@ -42,41 +44,49 @@ end

# Default values for most arguments.
function TestCase(
f, value;
name=nothing, runtime_activity=Neither, broken=Neither, skip=Neither, splat=false
)
f, value;
name = nothing, runtime_activity = Neither, broken = Neither, skip = Neither, splat = false
)
return TestCase(f, value, name, runtime_activity, broken, skip, splat)
end

"""
Test Enzyme.gradient, both Forward and Reverse mode, against FiniteDifferences.grad.
"""
function test_grad(case::TestCase; rtol=1e-6, atol=1e-6)
function test_grad(case::TestCase; rtol = 1.0e-6, atol = 1.0e-6)
@nospecialize
f = case.func
# We'll call the function as f(x...), so wrap in a singleton tuple if need be.
x = case.splat ? case.value : (case.value,)
finitediff = FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), f, x...)[1]

f_mode = if (case.runtime_activity === Both || case.runtime_activity === Forward)
f_mode = if includes_forward(case.runtime_activity)
Enzyme.set_runtime_activity(Enzyme.Forward)
else
Enzyme.Forward
end
r_mode = if (case.runtime_activity === Both || case.runtime_activity === Reverse)
r_mode = if includes_reverse(case.runtime_activity)
Enzyme.set_runtime_activity(Enzyme.Reverse)
else
Enzyme.Reverse
end

if !(case.skip === Forward) && !(case.skip === Both)
if case.broken === Both || case.broken === Forward
if !includes_forward(case.skip)
if includes_forward(case.broken)
@test_broken(
Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1] ≈ finitediff,
rtol = rtol,
atol = atol,
)
else
# If runtime activity was requested, check that it actually was needed.
if includes_forward(case.runtime_activity)
@test_throws Enzyme.Compiler.EnzymeRuntimeActivityError Enzyme.gradient(
Enzyme.Forward,
Enzyme.Const(f),
x...,
)
end
@test(
Enzyme.gradient(f_mode, Enzyme.Const(f), x...)[1] ≈ finitediff,
rtol = rtol,
Expand All @@ -85,14 +95,22 @@ function test_grad(case::TestCase; rtol=1e-6, atol=1e-6)
end
end

if !(case.skip === Reverse) && !(case.skip === Both)
if case.broken === Both || case.broken === Reverse
if !includes_reverse(case.skip)
if includes_reverse(case.broken)
@test_broken(
Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1] ≈ finitediff,
rtol = rtol,
atol = atol,
)
else
# If runtime activity was requested, check that it actually was needed.
if includes_reverse(case.runtime_activity)
@test_throws Enzyme.Compiler.EnzymeRuntimeActivityError Enzyme.gradient(
Enzyme.Reverse,
Enzyme.Const(f),
x...,
)
end
@test(
Enzyme.gradient(r_mode, Enzyme.Const(f), x...)[1] ≈ finitediff,
rtol = rtol,
Expand All @@ -107,22 +125,23 @@ end
A helper function that returns a TestCase that evaluates sum(bijector(inverse(bijector)(x)))
"""
function sum_b_binv_test_case(
bijector, dim; runtime_activity=Neither, name=nothing, broken=Neither, skip=Neither
)
bijector, dim; runtime_activity = Neither, name = nothing, broken = Neither, skip = Neither
)
if name === nothing
name = string(bijector)
end
b_inv = Bijectors.inverse(bijector)
return TestCase(
x -> sum(bijector(b_inv(x))),
randn(rng, dim);
runtime_activity=runtime_activity, name=name, broken=broken, skip=skip
runtime_activity = runtime_activity, name = name, broken = broken, skip = skip
)
end

@testset "Bijectors integration tests" begin
test_cases = TestCase[
sum_b_binv_test_case(Bijectors.VecCorrBijector(), 3),
sum_b_binv_test_case(Bijectors.VecCorrBijector(), (1, 1)),
sum_b_binv_test_case(Bijectors.VecCorrBijector(), 0),
sum_b_binv_test_case(Bijectors.CorrBijector(), (3, 3)),
sum_b_binv_test_case(Bijectors.CorrBijector(), (0, 0)),
Expand All @@ -141,16 +160,23 @@ end
sum_b_binv_test_case(Bijectors.PDBijector(), (3, 3)),
sum_b_binv_test_case(Bijectors.PDVecBijector(), 3),
sum_b_binv_test_case(
Bijectors.Permute([
0 1 0;
1 0 0;
0 0 1
]),
Bijectors.Permute(
[
0 1 0;
1 0 0;
0 0 1
]
),
(3, 3),
),
# TODO(mhauru) Both modes broken because of
# https://github.com/EnzymeAD/Enzyme.jl/issues/2035
sum_b_binv_test_case(Bijectors.PlanarLayer(3), (3, 3); broken=Both),
# NOTE(penelopeysm) This requires runtime activity on 1.11 reverse-mode, and 1.11
# forward-mode fails as this calls gemm! and runtime activity is not yet supported
# for BLAS calls. 1.10 works fine without runtime activity.
sum_b_binv_test_case(
Bijectors.PlanarLayer(3), (3, 3);
runtime_activity = ((v"1.10" <= VERSION < v"1.11" ? Neither : Reverse)),
broken = ((v"1.10" <= VERSION < v"1.11") ? Neither : Forward)
),
sum_b_binv_test_case(Bijectors.RadialLayer(3), 3),
sum_b_binv_test_case(Bijectors.Reshape((2, 3), (3, 2)), (2, 3)),
sum_b_binv_test_case(Bijectors.Scale(0.2), 3),
Expand All @@ -160,47 +186,86 @@ end
sum_b_binv_test_case(Bijectors.TruncatedBijector(-0.2, 0.5), 3),

# Below, some test cases that don't fit the sum_b_binv_test_case mold.

TestCase(
function (x)
return sum(Bijectors.PDVecBijector()(x * x' + I))
end,
randn(rng, 4, 4),
name = "PDVecBijector forward only",
),
TestCase(
function (x)
binv = Bijectors.inverse(Bijectors.PDVecBijector())
return sum(cholesky(Hermitian(binv(x), :L)).L)
end,
Bijectors.PDVecBijector()((x -> x * x' + I)(randn(rng, 4, 4))),
name = "PDVecBijector inverse only + lower Cholesky",
),
TestCase(
function (x)
binv = Bijectors.inverse(Bijectors.PDVecBijector())
return sum(cholesky(Hermitian(binv(x), :U)).U)
end,
Bijectors.PDVecBijector()((x -> x * x' + I)(randn(rng, 4, 4))),
name = "PDVecBijector inverse only + upper Cholesky",
),
TestCase(
function (x)
b = Bijectors.RationalQuadraticSpline([-0.2, 0.1, 0.5], [-0.3, 0.3, 0.9], [1.0, 0.2, 1.0])
binv = Bijectors.inverse(b)
return sum(binv(b(x)))
end,
randn(rng);
name="RationalQuadraticSpline on scalar",
name = "RationalQuadraticSpline on scalar",
),

TestCase(
function (x)
b = Bijectors.OrderedBijector()
binv = Bijectors.inverse(b)
return sum(binv(b(x)))
end,
randn(rng, 7);
name="OrderedBijector",
name = "OrderedBijector",
),

TestCase(
function (x)
layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5])
flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), LinearAlgebra.I), layer)
flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), I), layer)
x = x[6:7]
return Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x)
end,
randn(rng, 7);
name="PlanarLayer7"
name = "PlanarLayer7 forward"
),
TestCase(
function (x)
layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5])
flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), I), layer)
x = reshape(x[6:end], 2, :)
return sum(Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x))
end,
randn(rng, 11);
name = "PlanarLayer11 forward"
),
TestCase(
function (x)
layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5])
flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), I), Bijectors.inverse(layer))
x = x[6:7]
return Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x)
end,
randn(rng, 7);
name = "PlanarLayer7 inverse"
),

TestCase(
function (x)
layer = Bijectors.PlanarLayer(x[1:2], x[3:4], x[5:5])
flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), LinearAlgebra.I), layer)
flow = Bijectors.transformed(Bijectors.MvNormal(zeros(2), I), Bijectors.inverse(layer))
x = reshape(x[6:end], 2, :)
return sum(Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x))
end,
randn(rng, 11);
name="PlanarLayer11"
name = "PlanarLayer11 inverse"
),
]

Expand Down
17 changes: 17 additions & 0 deletions test/integration/Turing/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[sources]
Enzyme = {path = "../../.."}
EnzymeCore = {path = "../../../lib/EnzymeCore"}

[compat]
Turing = "=0.42.1"
DynamicPPL = "=0.39.9"
79 changes: 79 additions & 0 deletions test/integration/Turing/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
module TuringIntegrationTests

using ADTypes: AutoEnzyme
using DynamicPPL
using Enzyme: Enzyme
import ForwardDiff
using StableRNGs: StableRNG
using Test
using Turing

adtypes = (
AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Forward)),
AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse)),
)
Comment on lines +11 to +14
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

once TuringLang/DynamicPPL.jl#1172 is in it shouldn't need Const.


# Some supplements to DynamicPPL.TestUtils.ALL_MODELS.
@model function assume_normal()
a ~ Normal()
end
dppl_lda = begin
v = 100 # words
k = 5 # topics
m = 10 # number of docs
alpha = ones(k)
beta = ones(v)
phi = rand(Dirichlet(beta), k)
theta = rand(Dirichlet(alpha), m)
doc_lengths = rand(Poisson(1_000), m)
n = sum(doc_lengths)
w = Vector{Int}(undef, n)
doc = Vector{Int}(undef, n)
for i in 1:m
local idx = sum(doc_lengths[1:(i - 1)]) # starting index for inner loop
for j in 1:doc_lengths[i]
z = rand(Categorical(theta[:, i]))
w[idx + j] = rand(Categorical(phi[:, z]))
doc[idx + j] = i
end
end
@model function dppl_lda(k, m, w, doc, alpha, beta)
theta ~ product_distribution(fill(Dirichlet(alpha), m))
phi ~ product_distribution(fill(Dirichlet(beta), k))
log_phi_dot_theta = log.(phi * theta)
@addlogprob! sum(log_phi_dot_theta[CartesianIndex.(w, doc)])
end
dppl_lda
end
MODELS = [
DynamicPPL.TestUtils.ALL_MODELS...,
assume_normal(),
dppl_lda(k, m, w, doc, alpha, beta),
]

@testset "AD on logdensity" begin
# This code is essentially what Turing's HMC/NUTS samplers use internally
@testset "$(model.f)" for model in MODELS
@testset "AD type: $(adtype)" for adtype in adtypes
@test DynamicPPL.TestUtils.AD.run_ad(model, adtype; rng = StableRNG(468), test = true, benchmark = false) isa Any
end
end
end

@testset "AD / Gibbs sampling" begin
# The code to differentiate for the Gibbs sampler is slightly different from the
# HMC/NUTS samplers (even though each individual variable is sampled with HMC) so we
# have to test it separately.
@testset "AD type: $(adtype)" for adtype in adtypes
spl = Gibbs(
@varname(s) => HMC(0.1, 10; adtype = adtype),
@varname(m) => HMC(0.1, 10; adtype = adtype),
)
@testset "model=$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@info "Sampling model=$(model.f) with AD type=$(adtype)"
@test sample(StableRNG(468), model, spl, 2; progress = false) isa Any
end
end
end

end # module