Skip to content

Commit

Permalink
Merge pull request #217 from ReactiveBayes/fix-wishart-wishartfast-prod
Browse files Browse the repository at this point in the history
Fix add wishart wishartfast prod
  • Loading branch information
wouterwln authored Oct 24, 2024
2 parents 4ac105d + 9f311fd commit 5fe84b4
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/distributions/wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::WishartFast, rig
return WishartFast(df, invV)
end

BayesBase.default_prod_rule(::Type{<:Wishart}, ::Type{<:WishartFast}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::Wishart, right::WishartFast)
return prod(PreserveTypeProd(Distribution), convert(WishartFast, left), right)
end

function BayesBase.insupport(ef::ExponentialFamilyDistribution{WishartFast}, x::Matrix)
return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x)
end
Expand Down
12 changes: 12 additions & 0 deletions src/distributions/wishart_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,18 @@ function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishartFa
return InverseWishartFast(df, V)
end

BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishartFast}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishartFast)
return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), right)
end

BayesBase.default_prod_rule(::Type{<:InverseWishart}, ::Type{<:InverseWishart}) = PreserveTypeProd(Distribution)

function BayesBase.prod(::PreserveTypeProd{Distribution}, left::InverseWishart, right::InverseWishart)
return prod(PreserveTypeProd(Distribution), convert(InverseWishartFast, left), convert(InverseWishartFast, right))
end

function BayesBase.insupport(ef::ExponentialFamilyDistribution{InverseWishartFast}, x::Matrix)
return size(getindex(unpack_parameters(ef), 2)) == size(x) && isposdef(x)
end
Expand Down
34 changes: 34 additions & 0 deletions test/distributions/wishart_inverse_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,37 @@ end
end
end
end

@testitem "InverseWishart: prod between InverseWishart and InverseWishartFast" begin
include("distributions_setuptests.jl")

import ExponentialFamily: InverseWishartFast
import Distributions: InverseWishart

for Sleft in rand(InverseWishart(10, Array(Eye(2))), 2), Sright in rand(InverseWishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5)
let left = InverseWishart(νleft, Sleft), right = InverseWishart(νleft, Sleft), right_fast = convert(InverseWishartFast, right)
# Test commutativity of the product
prod_result1 = prod(PreserveTypeProd(Distribution), left, right_fast)
prod_result2 = prod(PreserveTypeProd(Distribution), right_fast, left)

@test prod_result1.ν prod_result2.ν
@test prod_result1.S prod_result2.S

# Test that the product preserves type
@test prod_result1 isa InverseWishartFast
@test prod_result2 isa InverseWishartFast

# prod stays if we convert fisrt and then do product
left_fast = convert(InverseWishartFast, left)
prod_fast = prod(ClosedProd(), left_fast, right_fast)

@test prod_fast.ν prod_result1.ν
@test prod_fast.S prod_result2.S

# prod for Inverse Wishart is defenied
prod_result_not_fast = prod(PreserveTypeProd(Distribution), left, right)
@test prod_result_not_fast.ν prod_result1.ν
@test prod_result_not_fast.S prod_result1.S
end
end
end
30 changes: 30 additions & 0 deletions test/distributions/wishart_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,33 @@ end
end
end
end

@testitem "Wishart: prod between Wishart and WishartFast" begin
include("distributions_setuptests.jl")

import ExponentialFamily: WishartFast
import Distributions: Wishart

for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5)
let left = Wishart(νleft, Sleft), right = WishartFast(νright, Sright)
# Test commutativity of the product
prod_result1 = prod(PreserveTypeProd(Distribution), left, right)
prod_result2 = prod(PreserveTypeProd(Distribution), right, left)

@test prod_result1.ν prod_result2.ν
@test prod_result1.invS prod_result2.invS

# Test that the product preserves type
@test prod_result1 isa WishartFast
@test prod_result2 isa WishartFast

# prod stays the same if we convert fisrt and then do product
left_fast = convert(WishartFast, left)
prod_fast = prod(ClosedProd(), left_fast, right)

@test prod_fast.ν prod_result1.ν
@test prod_fast.invS prod_result2.invS
end
end
end

0 comments on commit 5fe84b4

Please sign in to comment.