From b72a96368dc247549b9bbe10e13282ccd003517c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 28 Oct 2024 05:30:03 +0000 Subject: [PATCH] use `fix` and fix some errors --- src/abstractprobprog.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index c1b137f..98b37ea 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -131,13 +131,12 @@ end params, ) -> T -Draw a sample from the joint distribution specified by `model` conditioned on the values in -`params`. +Draw a sample from the predictive distribution specified by `model` with its parameters fixed to `params`. The sample will be returned as format specified by `T`. """ -function StatsBase.predict(rand::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) - return rand(rng, T, condition(model, params)) +function StatsBase.predict(rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) + return rand(rng, T, fix(model, params)) end function StatsBase.predict(T::Type, model::AbstractProbabilisticProgram, params) return StatsBase.predict(Random.default_rng(), T, model, params) @@ -145,6 +144,6 @@ end function StatsBase.predict(model::AbstractProbabilisticProgram, params) return StatsBase.predict(NamedTuple, model, params) end -function StatsBase.predict(rng::AbstractRNG, params) +function StatsBase.predict(rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) return StatsBase.predict(rng, NamedTuple, model, params) end