Skip to content

Commit

Permalink
fixes for RxInfer
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Oct 31, 2023
1 parent 7e733d8 commit deff015
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/distributions/normal_family/normal_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ function BayesBase.mean_invcov(dist::Normal)
return mean, inv(var)
end

function BayesBase.weightedmean_invcov(dist::FullNormal)
function BayesBase.weightedmean_invcov(dist::Union{FullNormal, MvNormal})
mean, var = mean_cov(dist)
invcov = cholinv(var)
return invcov * mean, invcov
end

function BayesBase.mean_invcov(dist::FullNormal)
function BayesBase.mean_invcov(dist::Union{FullNormal, MvNormal})
mean, cov = mean_cov(dist)
return mean, cholinv(cov)
end

function BayesBase.mean_std(dist::FullNormal)
function BayesBase.mean_std(dist::Union{FullNormal, MvNormal})
mean, cov = mean_cov(dist)
return mean, cholsqrt(cov)
end
Expand Down
8 changes: 5 additions & 3 deletions test/distributions/dirichlet_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ end
@testitem "Dirichlet: mean(::typeof(log))" begin
include("distributions_setuptests.jl")

@test mean(log, Dirichlet([1.0, 1.0, 1.0])) [-1.5000000000000002, -1.5000000000000002, -1.5000000000000002]
@test mean(log, Dirichlet([1.1, 2.0, 2.0])) [-1.9517644694670657, -1.1052251939575213, -1.1052251939575213]
@test mean(log, Dirichlet([3.0, 1.2, 5.0])) [-1.2410879175727905, -2.4529121492634465, -0.657754584239457]
import Base.Broadcast: BroadcastFunction

@test mean(BroadcastFunction(log), Dirichlet([1.0, 1.0, 1.0])) [-1.5000000000000002, -1.5000000000000002, -1.5000000000000002]
@test mean(BroadcastFunction(log), Dirichlet([1.1, 2.0, 2.0])) [-1.9517644694670657, -1.1052251939575213, -1.1052251939575213]
@test mean(BroadcastFunction(log), Dirichlet([3.0, 1.2, 5.0])) [-1.2410879175727905, -2.4529121492634465, -0.657754584239457]
end

@testitem "Dirichlet: ExponentialFamilyDistribution" begin
Expand Down

0 comments on commit deff015

Please sign in to comment.