Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Oct 3, 2023
1 parent bb32410 commit 230d456
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/distributions/wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function Distributions.mean(::typeof(cholinv), distribution::Wishart)
return mean(InverseWishart(ν, cholinv(S)))
end

vague(::Type{<:Wishart}, dims::Int) = Wishart(dims, huge .* Eye(dims))
vague(::Type{<:Wishart}, dims::Int) = Wishart(dims, huge .* Array(Eye(dims)))

Base.ndims(dist::Wishart) = size(dist, 1)

Expand Down
2 changes: 1 addition & 1 deletion src/distributions/wishart_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ function Distributions._logpdf(d::InverseWishartFast, X::AbstractMatrix{<:Real})
return Distributions.logkernel(dist, X) + dist.logc0
end

vague(::Type{<:InverseWishart}, dims::Integer) = InverseWishart(dims + 2, tiny .* Eye(dims))
vague(::Type{<:InverseWishart}, dims::Integer) = InverseWishart(dims + 2, tiny .* Array(Eye(dims)))

Base.ndims(dist::InverseWishart) = size(dist, 1)

Expand Down
4 changes: 2 additions & 2 deletions test/distributions/test_mv_normal_wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ include("../testutils.jl")
end

@testset "ExponentialFamilyDistribution{MvNormalWishart}" begin
@testset for dim in (3), invS in rand(Wishart(10, Eye(dim)), 4)
@testset for dim in (3), invS in rand(Wishart(10, Array(Eye(dim))), 4)
ν = dim + 2
@testset let (d = MvNormalWishart(rand(dim), invS, rand(), ν))
ef = test_exponentialfamily_interface(
Expand Down Expand Up @@ -54,7 +54,7 @@ include("../testutils.jl")
end

@testset "prod with ExponentialFamilyDistribution{MvNormalWishart}" begin
for Sleft in rand(Wishart(10, Eye(2)), 2), Sright in rand(Wishart(10, Eye(2)), 2), νright in (6, 7), νleft in (4, 5)
for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5)
let left = MvNormalWishart(rand(2), Sleft, rand(), νleft), right = MvNormalWishart(rand(2), Sright, rand(), νright)
@test test_generic_simple_exponentialfamily_product(
left,
Expand Down
8 changes: 4 additions & 4 deletions test/distributions/test_wishart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,16 @@ import StatsFuns: logmvgamma
end

@testset "ExponentialFamilyDistribution{WishartFast}" begin
@testset for dim in (3), invS in rand(Wishart(10, Eye(dim)), 2)
@testset for dim in (3), invS in rand(Wishart(10, Array(Eye(dim))), 2)
ν = dim + 2
@testset let (d = WishartFast(ν, invS))
ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_hessian = false)
(η1, η2) = unpack_parameters(WishartFast, getnaturalparameters(ef))

for x in Eye(dim)
for x in (Eye(dim), Diagonal(ones(dim)), Array(Eye(dim)))
@test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure()
@test @inferred(basemeasure(ef, x)) === 1.0
@test @inferred(sufficientstatistics(ef, x)) === (logdet(x), x)
@test all(@inferred(sufficientstatistics(ef, x)) .≈ (logdet(x), x))
@test @inferred(logpartition(ef)) -(η1 + (dim + 1) / 2) * logdet(-η2) + logmvgamma(dim, η1 + (dim + 1) / 2)
end
end
Expand All @@ -113,7 +113,7 @@ import StatsFuns: logmvgamma
end

@testset "prod with ExponentialFamilyDistribution{Wishart}" begin
for Sleft in rand(Wishart(10, Eye(2)), 2), Sright in rand(Wishart(10, Eye(2)), 2), νright in (6, 7), νleft in (4, 5)
for Sleft in rand(Wishart(10, Array(Eye(2))), 2), Sright in rand(Wishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5)
let left = WishartFast(νleft, Sleft), right = WishartFast(νright, Sright)
@test test_generic_simple_exponentialfamily_product(
left,
Expand Down
12 changes: 6 additions & 6 deletions test/distributions/test_wishart_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ include("../testutils.jl")
end

@testset "ExponentialFamilyDistribution{InverseWishartFast}" begin
@testset for dim in (3), S in rand(InverseWishart(10, Eye(dim)), 2)
@testset for dim in (3), S in rand(InverseWishart(10, Array(Eye(dim))), 2)
ν = dim + 4
@testset let (d = InverseWishartFast(ν, S))
ef = test_exponentialfamily_interface(d; option_assume_no_allocations = false, test_fisherinformation_against_hessian = false)
(η1, η2) = unpack_parameters(InverseWishartFast, getnaturalparameters(ef))

for x in Eye(dim)
for x in (Eye(dim), Diagonal(ones(dim)), Array(Eye(dim)))
@test @inferred(isbasemeasureconstant(ef)) === ConstantBaseMeasure()
@test @inferred(basemeasure(ef, x)) === 1.0
@test @inferred(sufficientstatistics(ef, x)) === (logdet(x), inv(x))
@test all(@inferred(sufficientstatistics(ef, x)) .≈ (logdet(x), inv(x)))
@test @inferred(logpartition(ef)) (η1 + (dim + 1) / 2) * logdet(-η2) + logmvgamma(dim, -(η1 + (dim + 1) / 2))
end
end
Expand Down Expand Up @@ -116,7 +116,7 @@ include("../testutils.jl")
samples = rand(rng, InverseWishart(ν, S), Int(1e6))
@test isapprox(mean(logdet, InverseWishartFast(ν, S)), mean(logdet.(samples)), atol = 1e-2)

ν, S = 4.0, Eye(3)
ν, S = 4.0, Array(Eye(3))
samples = rand(rng, InverseWishart(ν, S), Int(1e6))
@test isapprox(mean(logdet, InverseWishartFast(ν, S)), mean(logdet.(samples)), atol = 1e-2)
end
Expand All @@ -127,7 +127,7 @@ include("../testutils.jl")
samples = rand(rng, InverseWishart(ν, S), Int(1e6))
@test isapprox(mean(inv, InverseWishartFast(ν, S)), mean(inv.(samples)), atol = 1e-2)

ν, S = 4.0, Eye(3)
ν, S = 4.0, Array(Eye(3))
samples = rand(rng, InverseWishart(ν, S), Int(1e6))
@test isapprox(mean(inv, InverseWishartFast(ν, S)), mean(inv.(samples)), atol = 1e-2)
end
Expand Down Expand Up @@ -185,7 +185,7 @@ include("../testutils.jl")
end

@testset "prod with ExponentialFamilyDistribution{InverseWishartFast}" begin
for Sleft in rand(InverseWishart(10, Eye(2)), 2), Sright in rand(InverseWishart(10, Eye(2)), 2), νright in (6, 7), νleft in (4, 5)
for Sleft in rand(InverseWishart(10, Array(Eye(2))), 2), Sright in rand(InverseWishart(10, Array(Eye(2))), 2), νright in (6, 7), νleft in (4, 5)
let left = InverseWishartFast(νleft, Sleft), right = InverseWishartFast(νright, Sright)
@test test_generic_simple_exponentialfamily_product(
left,
Expand Down

0 comments on commit 230d456

Please sign in to comment.