How to extract priors from a Turing model? #2009
-
May I ask why you want the priors? We might already have some functionality for this :)
-
I find it very convenient to plot the priors when fitting or describing a model (to give a sense of the different distribution types). Hence, I would like to implement a function that you can run on a model to extract its priors and plot them :) But even besides plotting, extracting them would still be useful: I could imagine a function that creates a textual "report" of the priors used, with text like "For the parameter x, the prior used was ...". TL;DR: lots of use cases for extracting priors from a fitted model.
-
Oh, I most certainly agree that it's a useful thing to have :) It's just that, in general, it's difficult to get something that's "good". For example, for visualizing the prior, I'd instead suggest sampling from the prior directly, i.e. `chain = sample(model, Prior(), 1000)`, and then inspecting the result. This works even when there are dependencies between the prior variables, e.g. in hierarchical models, whereas explicitly extracting the priors might not be as useful there, since some of them change depending on the particular realization.

But with the following:

```julia
using DynamicPPL: OrderedDict, SamplingContext, AbstractContext, IsParent, VarName, Distribution, Model, evaluate!!, VarInfo
import DynamicPPL: tilde_assume, dot_tilde_assume, childcontext, setchildcontext, NodeTrait

# Context that records the distribution on the right-hand side of every `~`
# it encounters, then defers to its child context for the actual evaluation.
Base.@kwdef struct PriorExtractorContext{D,Ctx} <: AbstractContext
    priors::D=OrderedDict{VarName,Any}()
    context::Ctx=SamplingContext()
end

NodeTrait(::PriorExtractorContext) = IsParent()
childcontext(context::PriorExtractorContext) = context.context
setchildcontext(parent::PriorExtractorContext, child) = PriorExtractorContext(parent.priors, child)

function tilde_assume(context::PriorExtractorContext, right, vn, vi)
    setprior!(context, vn, right)
    return tilde_assume(childcontext(context), right, vn, vi)
end

function dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi)
    setprior!(context, vn, right)
    return dot_tilde_assume(childcontext(context), right, left, vn, vi)
end

# A single variable with a single distribution.
function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution)
    context.priors[vn] = dist
end

# Several variables sharing one distribution.
function setprior!(context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution)
    for vn in vns
        context.priors[vn] = dist
    end
end

# Several variables, each with its own distribution.
function setprior!(context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dists::AbstractArray{<:Distribution})
    # TODO: Support broadcasted expressions properly.
    for (vn, dist) in zip(vns, dists)
        context.priors[vn] = dist
    end
end

"""
    extract_priors(model::Model)

Extract the priors from a model. This is done by sampling from the model and
recording the distributions that are used to generate the samples.
"""
function extract_priors(model::Model)
    context = PriorExtractorContext()
    evaluate!!(model, VarInfo(), context)
    return context.priors
end
```

you should be able to do so:

```julia
julia> model = DynamicPPL.TestUtils.DEMO_MODELS[1]
Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#378")), (), (), Tuple{Vector{Float64}, DataType}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_dot_assume_dot_observe, (x = [1.5, 2.0], var"##arg#378" = Vector{Float64}), NamedTuple(), DefaultContext())

julia> extract_priors(model)
OrderedDict{VarName, Any} with 4 entries:
  s[1] => InverseGamma{Float64}(…
  s[2] => InverseGamma{Float64}(…
  m[1] => Normal{Float64}(μ=0.0, σ=3.78366)
  m[2] => Normal{Float64}(μ=0.0, σ=1.02669)
```

For this particular model, the prior is indeed hierarchical, so the two sigmas differ: they depend on the particular realization that was sampled. If you want something that extracts the actual graph of the model, that is less easy to do (still possible in many cases, it just requires some more complex machinery).
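For reference, the suggestion above to sample from the prior directly could look roughly like the sketch below. The model `demo` and its data are made up for illustration and are not taken from this thread; only `sample(model, Prior(), 1000)` comes from the comment itself.

```julia
using Turing
using Statistics

# Hypothetical hierarchical model, for illustration only:
# the prior on `m` depends on the sampled value of `s`.
@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = demo([1.5, 2.0])

# Draw 1000 samples from the prior; the observations in `x` are ignored.
chain = sample(model, Prior(), 1000)

# Pull out the draws for each parameter for plotting or summarizing.
s_draws = vec(chain[:s])
m_draws = vec(chain[:m])

println("prior mean of s ≈ ", mean(s_draws))
println("prior mean of m ≈ ", mean(m_draws))
```

Because `m` is drawn conditionally on `s` in every iteration, the draws in `m_draws` reflect the marginal prior of `m`, which is what you would usually want to plot for a hierarchical model.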
-
Based on this PR it would seem like
-
I would just like to mention this Discourse post here, which I might have posted in the wrong place.
Assuming the following model:
Is there a way to extract the priors set in the model? Something like:
I couldn't find any solutions in the Turing documentation, so thanks for any pointers!