ProductOf{InverseWishart{Float64, PDMats.PDMat{Float64, Matrix{Float64}}}, ExponentialFamily.InverseWishartFast{Float64, Matrix{Float64}}} does not exist or is not dispatched correctly. #216

Closed · wmkouw opened this issue Oct 24, 2024 · 3 comments · Fixed by #217
Labels: bug (Something isn't working), enhancement (New feature or request)


@wmkouw
Member

wmkouw commented Oct 24, 2024

Expected behaviour: the product ProductOf{InverseWishart{Float64, PDMats.PDMat{Float64, Matrix{Float64}}}, ExponentialFamily.InverseWishartFast{Float64, Matrix{Float64}}} should resolve to a closed-form distribution:

InverseWishartFast{Float64, Matrix{Float64}}(
ν: ..
S: ..
)

Instead, I get:

The expression `q(Q)` has an undefined functional form of type `ProductOf{InverseWishart{Float64, PDMats.PDMat{Float64, Matrix{Float64}}}, ExponentialFamily.InverseWishartFast{Float64, Matrix{Float64}}}`. 
This is likely because the inference backend does not support the product of these distributions. 
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `q(Q)`.

This error occurs when specifying a state-space model with an unknown process noise covariance matrix:

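using RxInfer, LinearAlgebra  # `dot` comes from LinearAlgebra

# A, C, σ, T and `observations` are assumed to be defined beforehand.
@model function LGDS_Q(y, priors, A, C, σ, T)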
    "State estimation in a linear Gaussian dynamical system with unknown process noise"
    
    z_0 ~ priors[:z_0]
    Q ~ priors[:Q]
    
    z_kmin1 = z_0
    for k in 1:T
        
        z[k] ~ MvNormalMeanCovariance(A * z_kmin1, Q)
        y[k] ~ NormalMeanVariance(dot(C, z[k]), σ^2)
        z_kmin1 = z[k]
        
    end
end

priors = Dict(:z_0 => MvNormalMeanCovariance(zeros(2), diageye(2)),
              :Q  => InverseWishart(10, diageye(2)))

init = @initialization begin
    q(z) = MvNormalMeanCovariance(zeros(2), diageye(2))
    q(Q) = InverseWishart(10, diageye(2))
end

constraints = @constraints begin
    q(z_0,z,Q) = q(z_0, z)q(Q)
end

results = infer(
    model          = LGDS_Q(priors=priors, A=A,C=C, σ=σ, T=T),
    data           = (y = [observations[k] for k in 1:T],),
    constraints    = constraints,
    iterations     = 100,
    options        = (limit_stack_depth = 100,),
    initialization = init,
    free_energy    = true,
    showprogress   = true,
)

After a discussion in the lab, we found that this is caused by a missing conversion to the "Fast" variant (here InverseWishartFast) when the prior is taken from a dictionary of parametric specifications.

However, on inspecting wishart.jl, it seems that products of Wishart and InverseWishart are no longer supported. I think it is very important to be able to experiment with distribution products when designing a model (e.g., to find out which products are closed and which are not). So I advocate for implementing the product rules for the Wishart and InverseWishart types. The easiest approach is probably to just convert types, but prod(ClosedProd(), Wishart(..), Wishart(..)) and other type combinations within the Wishart family should also be possible.
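For reference, the same-family products requested above are closed, which follows directly from the unnormalised densities. This is a derivation sketch assuming the standard parametrisation for a p×p positive-definite matrix X, not the ExponentialFamily.jl implementation; the Wishart result is a proper density only when ν₁ + ν₂ > 2p.

```math
\begin{aligned}
\mathcal{W}(X \mid \nu, V) &\propto |X|^{(\nu - p - 1)/2} \exp\!\left(-\tfrac{1}{2}\operatorname{tr}(V^{-1} X)\right),\\
\mathcal{IW}(X \mid \nu, S) &\propto |X|^{-(\nu + p + 1)/2} \exp\!\left(-\tfrac{1}{2}\operatorname{tr}(S X^{-1})\right),\\
\mathcal{W}(X \mid \nu_1, V_1)\,\mathcal{W}(X \mid \nu_2, V_2) &\propto \mathcal{W}\!\left(X \mid \nu_1 + \nu_2 - p - 1,\; (V_1^{-1} + V_2^{-1})^{-1}\right),\\
\mathcal{IW}(X \mid \nu_1, S_1)\,\mathcal{IW}(X \mid \nu_2, S_2) &\propto \mathcal{IW}\!\left(X \mid \nu_1 + \nu_2 + p + 1,\; S_1 + S_2\right).
\end{aligned}
```

A mixed Wishart × InverseWishart product contains both a tr(·X) and a tr(·X⁻¹) term, so in general it is neither a Wishart nor an inverse-Wishart density.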

wmkouw added the bug (Something isn't working) and enhancement (New feature or request) labels on Oct 24, 2024
@Nimrais
Member

Nimrais commented Oct 24, 2024

Ah yes, I see. We already have a similar issue: #192.

Thanks for reporting! Indeed, the conversion alone should be enough.

@wouterwln
Member

Conversion is enough: InverseWishartFast is just InverseWishart without the argument checks, so any InverseWishart distribution that can be constructed can also be converted into an InverseWishartFast distribution. I'll open a PR in a couple of minutes.
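To illustrate why a plain conversion suffices, here is a minimal sketch that uses only Distributions.jl. `UncheckedInverseWishart` is a hypothetical stand-in for `ExponentialFamily.InverseWishartFast` (the real type and its conversion methods live in ExponentialFamily.jl and may differ). Because the `InverseWishart` constructor has already validated its parameters, the conversion only needs to copy ν and S into the unchecked type.

```julia
using Distributions, LinearAlgebra

# Hypothetical stand-in for ExponentialFamily.InverseWishartFast: it stores the same
# two parameters (degrees of freedom ν and scale matrix S) but performs no checks.
struct UncheckedInverseWishart{T<:Real, A<:AbstractMatrix{T}}
    ν::T
    S::A
end

# Converting from Distributions.InverseWishart just copies the parameters; the
# InverseWishart constructor has already checked them (ν > p - 1, S positive definite).
function Base.convert(::Type{UncheckedInverseWishart}, d::InverseWishart)
    ν, Ψ = params(d)   # Ψ is stored as a PDMats positive-definite matrix
    S = Matrix(Ψ)      # plain dense matrix, mirroring InverseWishartFast{Float64, Matrix{Float64}}
    return UncheckedInverseWishart(convert(eltype(S), ν), S)
end

# Usage: the same prior as in the model above, converted to the unchecked form.
prior = InverseWishart(10, Matrix(1.0I, 2, 2))
fast  = convert(UncheckedInverseWishart, prior)
```

The design keeps validation in one place: only the checked constructor can reject bad parameters, and the fast type trusts whatever it is converted from.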

@Nimrais
Member

Nimrais commented Oct 24, 2024

@wouterwln I hadn't seen your message. Could you also merge this into your branch, so that it fixes another issue as well? https://github.com/ReactiveBayes/ExponentialFamily.jl/pull/217/files

wouterwln linked a pull request on Oct 24, 2024 that will close this issue