From 6cda516f77d8c77dda52b2589457c1d9d8dedf5b Mon Sep 17 00:00:00 2001 From: Raphael-Tresor <40422324+Raphael-Tresor@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:56:16 +0200 Subject: [PATCH 1/3] fix-rand-matrixDirichlet --- src/distributions/matrix_dirichlet.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distributions/matrix_dirichlet.jl b/src/distributions/matrix_dirichlet.jl index 90e4eff6..607c5633 100644 --- a/src/distributions/matrix_dirichlet.jl +++ b/src/distributions/matrix_dirichlet.jl @@ -75,7 +75,7 @@ end function BayesBase.rand!(rng::AbstractRNG, dist::MatrixDirichlet, container::AbstractMatrix{T}) where {T <: Real} samples = vmap(d -> rand(rng, Dirichlet(convert(Vector, d))), eachcol(dist.a)) - @views for row in 1:isqrt(length(container)) + @views for row in 1:size(container,2) b = container[:, row] b[:] .= samples[row] end From 5740b52ff08e09bbc5cc93bb6a8cffcf1d76dac5 Mon Sep 17 00:00:00 2001 From: Raphael-Tresor <40422324+Raphael-Tresor@users.noreply.github.com> Date: Fri, 27 Sep 2024 11:38:32 +0200 Subject: [PATCH 2/3] ad test rand MatrixDirichlet --- src/distributions/matrix_dirichlet.jl | 2 +- test/distributions/matrix_dirichlet_tests.jl | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/distributions/matrix_dirichlet.jl b/src/distributions/matrix_dirichlet.jl index 607c5633..903677e2 100644 --- a/src/distributions/matrix_dirichlet.jl +++ b/src/distributions/matrix_dirichlet.jl @@ -75,7 +75,7 @@ end function BayesBase.rand!(rng::AbstractRNG, dist::MatrixDirichlet, container::AbstractMatrix{T}) where {T <: Real} samples = vmap(d -> rand(rng, Dirichlet(convert(Vector, d))), eachcol(dist.a)) - @views for row in 1:size(container,2) + @views for row in 1:size(container, 2) b = container[:, row] b[:] .= samples[row] end diff --git a/test/distributions/matrix_dirichlet_tests.jl b/test/distributions/matrix_dirichlet_tests.jl index 6021933c..3fcf5249 100644 --- a/test/distributions/matrix_dirichlet_tests.jl +++ b/test/distributions/matrix_dirichlet_tests.jl @@ -148,3 +148,13 @@ end @test promote_variate_type(Multivariate, MatrixDirichlet) === Dirichlet @test promote_variate_type(Matrixvariate, MatrixDirichlet) === MatrixDirichlet end + +@testitem "MatrixDirichlet: rand" begin + include("distributions_setuptests.jl") + + @test_throws DimensionMismatch sum(rand(MatrixDirichlet(ones(3, 5))), dims = 1) ≈ [1.0;; 1.0;; 1.0] + + @test sum(rand(MatrixDirichlet(ones(3, 5))), dims = 1) ≈ [1.0;; 1.0;; 1.0;; 1.0;; 1.0] + @test sum(rand(MatrixDirichlet(ones(5, 3))), dims = 1) ≈ [1.0;; 1.0;; 1.0] + @test sum(rand(MatrixDirichlet(ones(5, 5))), dims = 1) ≈ [1.0;; 1.0;; 1.0;; 1.0;; 1.0] +end From 335551bc20dbed8f99537c8acbb1c4705565b754 Mon Sep 17 00:00:00 2001 From: Wouter Nuijten Date: Fri, 27 Sep 2024 14:22:43 +0200 Subject: [PATCH 3/3] Reduce allocations in matrixdirichlet rand! --- src/distributions/matrix_dirichlet.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/distributions/matrix_dirichlet.jl b/src/distributions/matrix_dirichlet.jl index 903677e2..21e99da2 100644 --- a/src/distributions/matrix_dirichlet.jl +++ b/src/distributions/matrix_dirichlet.jl @@ -74,12 +74,9 @@ function BayesBase.rand(rng::AbstractRNG, dist::MatrixDirichlet{T}, nsamples::In end function BayesBase.rand!(rng::AbstractRNG, dist::MatrixDirichlet, container::AbstractMatrix{T}) where {T <: Real} - samples = vmap(d -> rand(rng, Dirichlet(convert(Vector, d))), eachcol(dist.a)) - @views for row in 1:size(container, 2) - b = container[:, row] - b[:] .= samples[row] + @views for (i, col) in enumerate(eachcol(dist.a)) + rand!(rng, Dirichlet(col), container[:, i]) end - return container end