diff --git a/Project.toml b/Project.toml index 6510e7ea0..7c987962d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.7" +version = "0.24.8" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl index 10371b3fe..4bc33e217 100644 --- a/ext/DynamicPPLForwardDiffExt.jl +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -26,7 +26,7 @@ function LogDensityProblemsAD.ADgradient( ForwardDiff.Tag(f, eltype(θ)) end chunk_size = getchunksize(ad) - chunk = if chunk_size == 0 + chunk = if chunk_size == 0 || chunk_size === nothing ForwardDiff.Chunk(θ) else ForwardDiff.Chunk(length(θ), chunk_size) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 49605832a..dabcc36ab 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -289,7 +289,7 @@ function dot_tilde_assume(context::AbstractContext, args...) return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...) end function dot_tilde_assume(rng, context::AbstractContext, args...) - return dot_tilde_assume(rng, NodeTrait(dot_tilde_assume, context), context, args...) + return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), rng, context, args...) end function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) @@ -302,7 +302,7 @@ end function dot_tilde_assume(::IsParent, context::AbstractContext, args...) return dot_tilde_assume(childcontext(context), args...) end -function dot_tilde_assume(rng, ::IsParent, context::AbstractContext, args...) +function dot_tilde_assume(::IsParent, rng, context::AbstractContext, args...) return dot_tilde_assume(rng, childcontext(context), args...) end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl b/test/ext/DynamicPPLForwardDiffExt.jl index 1227a8c95..8de28046b 100644 --- a/test/ext/DynamicPPLForwardDiffExt.jl +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -1,5 +1,5 @@ @testset "tag" begin - for chunksize in (0, 1, 10) + for chunksize in (nothing, 0, 1, 10) ad = ADTypes.AutoForwardDiff(; chunksize=chunksize) standardtag = if !isdefined(Base, :get_extension) DynamicPPL.DynamicPPLForwardDiffExt.standardtag