Skip to content

Commit

Permalink
Major refactoring, breaking changes
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Oct 2, 2020
1 parent 9f4fe67 commit 6dc87a0
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 153 deletions.
1 change: 1 addition & 0 deletions src/EmpiricalDistributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using Distributions
using StatsBase


include("hist_funcs.jl")
include("uv_binned_dist.jl")
include("mv_binned_dist.jl")

Expand Down
76 changes: 76 additions & 0 deletions src/hist_funcs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# This file is a part of EmpiricalDistributions.jl, licensed under the MIT License (MIT).


"""
    _pdf(h::Histogram{T,N}, xs::NTuple{N,Real})

Return the weight of the bin of the density histogram `h` that contains the
point `xs`, or `zero(T)` if the bin index computed for `xs` is outside the
bounds of `h.weights`.
"""
function _pdf(h::Histogram{T,N}, xs::NTuple{N,Real}) where {T,N}
    @assert h.isdensity # Implementation requires normalized histogram

    bin = StatsBase.binindex(h, xs)
    # The bin index may fall outside the weights array; such points get zero density.
    inside = checkbounds(Bool, h.weights, bin...)
    density::T = inside ? (@inbounds h.weights[bin...]) : zero(T)
    return density
end


"""
    _mean(h::StatsBase.Histogram{<:Real,N}; T::DataType = Float64)

Return a length-`N` `Vector{T}` with the weighted mean of the bin midpoints of
`h` along each dimension. Weights are normalized by their total sum.
"""
function _mean(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64) where {N}
    @assert !h.isdensity # Implementation currently assumes non-normalized histogram

    norm::T = inv(sum(h.weights))
    centers = StatsBase.midpoints.(h.edges)
    acc::Vector{T} = zeros(T, N)
    for bin in CartesianIndices(h.weights)
        w = h.weights[bin]
        for d in 1:N
            # Contribution of this bin: normalized weight times midpoint along d.
            acc[d] += norm * centers[d][bin[d]] * w
        end
    end
    return acc
end


# Index of the largest element of `A`: a plain integer for vectors, a tuple of
# per-dimension indices for other arrays.
_findmaxidx_tuple_or_int(A::AbstractVector{<:Real}) = argmax(A)
_findmaxidx_tuple_or_int(A::AbstractArray{<:Real}) = Tuple(argmax(A))

"""
    _mode(h::StatsBase.Histogram; T::DataType = Float64)

Return, as a `Vector{T}`, the center of the bin of the density histogram `h`
that has the largest weight. The center is the average of the bin's lower and
upper edge along each dimension.
"""
function _mode(h::StatsBase.Histogram; T::DataType = Float64)
    @assert h.isdensity # Implementation requires normalized histogram

    peak = _findmaxidx_tuple_or_int(h.weights)
    # Edges at index i and i + 1 bound bin i along each dimension.
    lower = map(getindex, h.edges, peak)
    upper = map(getindex, h.edges, peak .+ 1)
    return T[(lower .+ upper) ./ 2...]
end


"""
    _var(h::StatsBase.Histogram{<:Real,N}; T::DataType = Float64, mean = _mean(h, T = T))

Return a length-`N` `Vector{T}` with the weighted variance of the bin midpoints
of `h` along each dimension, taken around `mean`. Weights are normalized by
their total sum.

# Keywords
- `T`: element type of the result.
- `mean`: indexable collection holding the per-dimension mean. Defaults to
  `_mean(h, T = T)`, so callers that already have the mean can pass it in to
  avoid recomputing it.
"""
function _var(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = _mean(h, T = T)) where {N}
    @assert !h.isdensity # Implementation currently assumes non-normalized histogram

    s_inv::T = inv(sum(h.weights))
    v::Vector{T} = zeros(T, N)
    mps = StatsBase.midpoints.(h.edges)
    for i in CartesianIndices(h.weights)
        w = h.weights[i]
        for idim in 1:N
            # Normalized weight times squared deviation of the bin midpoint.
            v[idim] += s_inv * (mps[idim][i[idim]] - mean[idim])^2 * w
        end
    end
    return v
end


"""
    _cov(h::StatsBase.Histogram{<:Real,N}; T::DataType = Float64, mean = _mean(h, T = T))

Return the `N`-by-`N` `Matrix{T}` of weighted covariances of the bin midpoints
of `h`, taken around `mean`. Weights are normalized by their total sum.

# Keywords
- `T`: element type of the result.
- `mean`: indexable collection holding the per-dimension mean. Defaults to
  `_mean(h, T = T)`, so callers that already have the mean can pass it in to
  avoid recomputing it.
"""
function _cov(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = _mean(h, T = T)) where {N}
    @assert !h.isdensity # Implementation currently assumes non-normalized histogram

    s_inv::T = inv(sum(h.weights))
    c::Matrix{T} = zeros(T, N, N)
    mps = StatsBase.midpoints.(h.edges)
    for i in CartesianIndices(h.weights)
        w = h.weights[i]
        for idim in 1:N
            dev_i = mps[idim][i[idim]] - mean[idim]
            for jdim in 1:N
                dev_j = mps[jdim][i[jdim]] - mean[jdim]
                # Normalized weight times the product of midpoint deviations.
                c[idim, jdim] += s_inv * dev_i * dev_j * w
            end
        end
    end
    return c
end
121 changes: 51 additions & 70 deletions src/mv_binned_dist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


"""
UvBinnedDist <: Distribution{Univariate,Continuous}
MvBinnedDist <: Distribution{Multivariate,Continuous}
Wraps a multi-dimensional histogram and presents it as a binned multivariate
distribution.
Expand All @@ -11,16 +11,21 @@ Constructor:
MvBinnedDist(h::Histogram{<:Real,N})
"""
struct MvBinnedDist{T, N} <: Distributions.Distribution{Multivariate,Continuous}
h::StatsBase.Histogram{<:Real, N}
edges::NTuple{N, <:AbstractVector{T}}
cart_inds::CartesianIndices{N, NTuple{N, Base.OneTo{Int}}}

probabilty_edges::AbstractVector{T}

μ::AbstractVector{T}
var::AbstractVector{T}
cov::AbstractMatrix{T}
struct MvBinnedDist{
T <: Real,
N,
H <: Histogram{<:Real, N},
VT <: AbstractVector{T},
MT <: AbstractMatrix{T}
} <: Distributions.Distribution{Multivariate,Continuous}
hist::H
_edges::NTuple{N, <:AbstractVector{T}}
_cart_inds::CartesianIndices{N, NTuple{N, Base.OneTo{Int}}}
_probability_edges::VT
_mean::VT
_mode::VT
_var::VT
_cov::MT
end

export MvBinnedDist
Expand All @@ -37,83 +42,45 @@ function MvBinnedDist(h::StatsBase.Histogram{<:Real, N}, T::DataType = Float64)
probabilty_edges[i+1] = v > 1 ? 1 : v
end

mean = _mean(h)
var = _var(h, mean = mean)
cov = _cov(h, mean = mean)
mean_est = _mean(h)
mode_est = _mode(nh)
var_est = _var(h, mean = mean_est)
cov_est = _cov(h, mean = mean_est)

return MvBinnedDist{T, N}(
return MvBinnedDist(
nh,
collect.(nh.edges),
CartesianIndices(nh.weights),
probabilty_edges,
mean,
var,
cov
mean_est,
mode_est,
var_est,
cov_est
)
end


# Converting to a `Histogram` returns the histogram stored in the `hist` field.
function Base.convert(::Type{Histogram}, d::MvBinnedDist)
    return d.hist
end


# Length, size and element type are all read from the type parameters `T` and
# `N`, so the instance itself is not inspected.
Base.length(::MvBinnedDist{T, N}) where {T, N} = N
Base.size(::MvBinnedDist{T, N}) where {T, N} = (N,)
Base.eltype(::MvBinnedDist{T, N}) where {T, N} = T

# Moments are stored on the distribution object; these accessors return the
# stored fields.
Statistics.mean(d::MvBinnedDist) = d.μ
Statistics.var(d::MvBinnedDist) = d.var
Statistics.cov(d::MvBinnedDist) = d.cov


# Weighted mean of the bin midpoints of `h` along each dimension, returned as a
# length-N Vector{T}. Weights are normalized by their total sum.
function _mean(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64) where {N}
    norm::T = inv(sum(h.weights))
    centers = StatsBase.midpoints.(h.edges)
    acc::Vector{T} = zeros(T, N)
    for bin in CartesianIndices(h.weights)
        w = h.weights[bin]
        for d in 1:N
            acc[d] += norm * centers[d][bin[d]] * w
        end
    end
    return acc
end


# Weighted variance of the bin midpoints of `h` along each dimension, taken
# around `mean` and returned as a length-N Vector{T}. Weights are normalized by
# their total sum. `mean` defaults to `_mean(h, T = T)`; pass it explicitly to
# avoid recomputing it.
function _var(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = _mean(h, T = T)) where {N}
    s_inv::T = inv(sum(h.weights))
    v::Vector{T} = zeros(T, N)
    mps = StatsBase.midpoints.(h.edges)
    for i in CartesianIndices(h.weights)
        w = h.weights[i]
        for idim in 1:N
            # Normalized weight times squared deviation of the bin midpoint.
            v[idim] += s_inv * (mps[idim][i[idim]] - mean[idim])^2 * w
        end
    end
    return v
end


# Weighted N-by-N covariance matrix (Matrix{T}) of the bin midpoints of `h`,
# taken around `mean`. Weights are normalized by their total sum. `mean`
# defaults to `_mean(h, T = T)`; pass it explicitly to avoid recomputing it.
function _cov(h::StatsBase.Histogram{<:Real, N}; T::DataType = Float64, mean = _mean(h, T = T)) where {N}
    s_inv::T = inv(sum(h.weights))
    c::Matrix{T} = zeros(T, N, N)
    mps = StatsBase.midpoints.(h.edges)
    for i in CartesianIndices(h.weights)
        w = h.weights[i]
        for idim in 1:N
            dev_i = mps[idim][i[idim]] - mean[idim]
            for jdim in 1:N
                dev_j = mps[jdim][i[jdim]] - mean[jdim]
                # Normalized weight times the product of midpoint deviations.
                c[idim, jdim] += s_inv * dev_i * dev_j * w
            end
        end
    end
    return c
end
# Mean, mode, variance and covariance are stored on the distribution object;
# these accessors return the stored fields.
Statistics.mean(d::MvBinnedDist) = d._mean
StatsBase.mode(d::MvBinnedDist) = d._mode
Statistics.var(d::MvBinnedDist) = d._var
Statistics.cov(d::MvBinnedDist) = d._cov


function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractVector{<:Real}) where {T, N}
rand!(r, A)
next_inds::UnitRange{Int} = searchsorted(d.probabilty_edges::Vector{T}, A[1]::T)
next_inds::UnitRange{Int} = searchsorted(d._probability_edges::Vector{T}, A[1]::T)
cell_lin_index::Int = min(next_inds.start, next_inds.stop)
cell_car_index = d.cart_inds[cell_lin_index]
cell_car_index = d._cart_inds[cell_lin_index]
for idim in Base.OneTo(N)
i = cell_car_index[idim]
sub_int = d.edges[idim][i:i+1]
sub_int = d._edges[idim][i:i+1]
sub_int_width::T = sub_int[2] - sub_int[1]
A[idim] = sub_int[1] + sub_int_width * A[idim]
end
Expand All @@ -122,11 +89,25 @@ end

# Fill the matrix `A` with variates by applying the vector method of `_rand!`
# to every element of `nestedview(A)` (presumably one column view per variate;
# confirm against ArraysOfArrays), then return `A`.
function Distributions._rand!(r::AbstractRNG, d::MvBinnedDist{T,N}, A::AbstractMatrix{<:Real}) where {T, N}
    for variate in nestedview(A)
        Distributions._rand!(r, d, variate)
    end
    return A
end


# Similar to unroll_tuple in StaticArrays.jl:
#
# Return the first `L` elements of `A`, counted from `firstindex(A)`, as a
# tuple. The element accesses are unrolled at code-generation time and run
# under `@inbounds`, so the caller must guarantee that `A` holds at least `L`
# elements ("unsafe").
@generated function _unsafe_unroll_tuple(A::AbstractArray, ::Val{L}) where {L}
    exprs = [:(A[idx0 + $j]) for j = 0:(L-1)]
    quote
        # Inline hint spliced in as a raw meta expression rather than through
        # the internal `Base.@_inline_meta` macro, which is not public API and
        # is not available on newer Julia releases.
        $(Expr(:meta, :inline))
        idx0 = firstindex(A)
        @inbounds return $(Expr(:tuple, exprs...))
    end
end


function Distributions.pdf(d::MvBinnedDist{T, N}, x::AbstractArray{<:Real, 1}) where {T, N}
return @inbounds d.h.weights[StatsBase.binindex(d.h, Tuple(x))...]
function Distributions.pdf(d::MvBinnedDist{T,N}, x::AbstractVector{<:Real}) where {T,N}
length(eachindex(x)) == N || throw(ArgumentError("Length of variate doesn't match dimensionality of distribution"))
x_tpl = _unsafe_unroll_tuple(x, Val(N))
_pdf(d.hist, x_tpl)
end


Expand Down
Loading

0 comments on commit 6dc87a0

Please sign in to comment.