Skip to content

Commit

Permalink
prepare for the registration
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Oct 3, 2023
1 parent 4657e2c commit c634d21
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ExponentialFamily"
uuid = "62312e5e-252a-4322-ace9-a5f4bf9b357b"
authors = ["Ismail Senoz <i.senoz@tue.nl>, Dmitry Bagaev <d.v.bagaev@tue.nl>, Mykola Lukashchuk <m.lukashchuk@tue.nl>"]
authors = ["Ismail Senoz <i.senoz@tue.nl>", "Albert Podusenko <a.podusenko@tue.nl>", "Dmitry Bagaev <d.v.bagaev@tue.nl>"]
version = "1.0.0"

[deps]
Expand Down
72 changes: 62 additions & 10 deletions src/exponential_family.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Base: map
Return the transformation function that maps the parameters in the mean parameters space to the natural parameters space for a distribution of type `T`.
The transformation function is of signature `(params_in_mean_space, [ conditioner ]) -> params_in_natural_space`.
See also: [`NaturalToMean`](@ref)
See also: [`NaturalToMean`](@ref), [`NaturalParametersSpace`](@ref), [`MeanParametersSpace`](@ref), [`getmapping`](@ref)
"""
struct MeanToNatural{T} end

Expand All @@ -35,7 +35,7 @@ end
Return the transformation function that maps the parameters in the natural parameters space to the mean parameters space for a distribution of type `T`.
The transformation function is of signature `(params_in_natural_space, [ conditioner ]) -> params_in_mean_space`.
See also: [`MeanToNatural`](@ref)
See also: [`MeanToNatural`](@ref), [`NaturalParametersSpace`](@ref), [`MeanParametersSpace`](@ref), [`getmapping`](@ref)
"""
struct NaturalToMean{T} end

Expand All @@ -54,7 +54,7 @@ Some functions (such as `logpartition` or `fisherinformation`) accept an additio
Use `map(MeanParametersSpace() => NaturalParametersSpace(), T, parameters, conditioner)` to map the `parameters` and the `conditioner` of a distribution of type `T`
from the mean parametrization to the corresponding natural parametrization.
See also: [`NaturalParametersSpace`](@ref)
See also: [`NaturalParametersSpace`](@ref), [`getmapping`](@ref), [`NaturalToMean`](@ref), [`MeanToNatural`](@ref)
"""
struct MeanParametersSpace end

Expand All @@ -66,7 +66,7 @@ Some functions (such as `logpartition` or `fisherinformation`) accept an additio
Use `map(NaturalParametersSpace() => MeanParametersSpace(), T, parameters, conditioner)` to map the `parameters` and the `conditioner` of a distribution of type `T`
from the natural parametrization to the corresponding mean parametrization.
See also: [`MeanParametersSpace`](@ref)
See also: [`MeanParametersSpace`](@ref), [`getmapping`](@ref), [`NaturalToMean`](@ref), [`MeanToNatural`](@ref)
"""
struct NaturalParametersSpace end

Expand Down Expand Up @@ -137,7 +137,7 @@ A structure to represent the attributes of an exponential family member.
- `logpartition::L`: The log-partition (cumulant) of the exponential family member.
- `support::P`: The support of the exponential family member.
See also: [`ExponentialFamilyDistribution`](@ref)
See also: [`ExponentialFamilyDistribution`](@ref), [`getbasemeasure`](@ref), [`getsufficientstatistics`](@ref), [`getlogpartition`](@ref), [`getsupport`](@ref)
"""
struct ExponentialFamilyDistributionAttributes{B, S, L, P}
basemeasure::B
Expand All @@ -160,6 +160,19 @@ value_support(::Type{ExponentialFamilyDistributionAttributes{B, S, L, P}}) where
`ExponentialFamilyDistribution` structure represents a generic exponential family distribution in natural parameterization.
Type `T` can be either a distribution type (e.g. from the `Distributions.jl` package) or a variate type (e.g. `Univariate`).
In the context of the package, exponential family distributions are represented in the form:
```math
pₓ(x ∣ η) = h(x) ⋅ exp[ η ⋅ T(x) - A(η) ]
```
Here:
- `h(x)` is the base measure.
- `T(x)` represents sufficient statistics.
- `A(η)` stands for the log partition.
- `η` denotes the natural parameters.
For a given member of exponential family:
- `getattributes` returns either `nothing` or `ExponentialFamilyDistributionAttributes`.
- `getbasemeasure` returns a positive a valued function.
Expand All @@ -171,6 +184,22 @@ Type `T` can be either a distribution type (e.g. from the `Distributions.jl` pac
!!! note
The `attributes` can be `nothing`. In which case the package will try to derive the corresponding attributes from the type `T`.
```jldoctest
julia> ef = convert(ExponentialFamilyDistribution, Bernoulli(0.5))
ExponentialFamily(Bernoulli)
julia> getsufficientstatistics(ef)
(identity,)
```
```jldoctest
julia> ef = convert(ExponentialFamilyDistribution, Laplace(1.0, 0.5))
ExponentialFamily(Laplace, conditioned on 1.0)
julia> logpdf(ef, 4.0)
-6.0
```
See also: [`getbasemeasure`](@ref), [`getsufficientstatistics`](@ref), [`getnaturalparameters`](@ref), [`getlogpartition`](@ref), [`getsupport`](@ref)
"""
struct ExponentialFamilyDistribution{T, P, C, A}
Expand Down Expand Up @@ -203,6 +232,15 @@ function ExponentialFamilyDistribution(
return ExponentialFamilyDistribution(T, naturalparameters, conditioner, nothing)
end

function Base.show(io::IO, ef::ExponentialFamilyDistribution{T}) where {T}
print(io, "ExponentialFamily(", T)
conditioner = getconditioner(ef)
if !isnothing(conditioner)
print(io, ", conditioned on ", conditioner)
end
print(io, ")")
end

"""
isproper(::ExponentialFamilyDistribution)
Expand Down Expand Up @@ -491,6 +529,20 @@ flatten_parameters(::Type{T}, params::Tuple) where {T} = flatten_parameters(para
This function returns the parameters of a distribution of type `T` in a vectorized (packed) form. For most of the distributions the packed versions are of the
same structure in any parameters space. For some distributions, however, it is necessary to indicate the `space` of the packaged parameters.
```jldoctest
julia> ExponentialFamily.pack_parameters((1, [2.0, 3.0], [4.0 5.0 6.0; 7.0 8.0 9.0]))
9-element Vector{Float64}:
1.0
2.0
3.0
4.0
7.0
5.0
8.0
6.0
9.0
```
"""
function pack_parameters end

Expand All @@ -513,19 +565,19 @@ end

function __pack_parameters_fast!(container::Vector, offset::Int, current::Int, lengths, front, tail::Tuple)
N = lengths[current]
__pack_copyto!(container, offset, front, 1, N)
__pack_copyto!(container, offset, front, N)
return __pack_parameters_fast!(container, offset + N, current + 1, lengths, Base.first(tail), Base.tail(tail))
end

function __pack_parameters_fast!(container::Vector, i::Int, k::Int, lengths, front, ::Tuple{})
N = lengths[k]
__pack_copyto!(container, i, front, 1, N)
__pack_copyto!(container, i, front, N)
return container
end

__pack_copyto!(dest, doffset, source, soffset, n) = copyto!(dest, doffset, source, soffset, n)
__pack_copyto!(dest::Array, doffset, source::Array, soffset, n) = unsafe_copyto!(dest, doffset, source, soffset, n)
__pack_copyto!(dest::Array, doffset, source::Number, soffset, n) = @inbounds(dest[doffset] = source)
__pack_copyto!(dest, doffset, source, n) = copyto!(dest, doffset, source, firstindex(source), n)
__pack_copyto!(dest::Array, doffset, source::Array, n) = unsafe_copyto!(dest, doffset, source, firstindex(source), n)
__pack_copyto!(dest::Array, doffset, source::Number, _) = @inbounds(dest[doffset] = source)

"""
unpack_parameters([ space ], ::Type{T}, parameters)
Expand Down
48 changes: 43 additions & 5 deletions src/prod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Base: prod, prod!, show, showerror
ClosedProd
`ClosedProd` is one of the strategies for `prod` function. This strategy uses either `PreserveTypeProd(Distribution)` or `PreserveTypeProd(ExponentialFamilyDistribution)`,
depending on the types of the input arguments. For example, if both inputs are of type `Distribution`, then `ClosedProd` would fallback to `PreserveTypeProd(Distribution)`.
depending on the types of the input arguments. For example, if both inputs are of type `Distribution`, then `ClosedProd` would fallback to `PreserveTypeProd(Distribution)`.
See also: [`prod`](@ref), [`PreserveTypeProd`](@ref), [`GenericProd`](@ref)
"""
Expand All @@ -22,14 +22,28 @@ struct ClosedProd end
There are multiple strategies for prod function, e.g. `ClosedProd`, `GenericProd` or `PreserveTypeProd`.
# Examples:
```jldoctest
julia> product = prod(PreserveTypeProd(Distribution), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(1.0, 1.0))
NormalWeightedMeanPrecision{Float64}(xi=0.0, w=2.0)
julia> mean(product), var(product)
(0.0, 0.5)
```
```jldoctest
using ExponentialFamily
julia> product = prod(PreserveTypeProd(NormalMeanVariance), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(1.0, 1.0))
NormalMeanVariance{Float64}(μ=0.0, v=0.5)
product = prod(ClosedProd(), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(1.0, 1.0))
julia> mean(product), var(product)
(0.0, 0.5)
```
mean(product), var(product)
```jldoctest
julia> product = prod(PreserveTypeProd(ExponentialFamilyDistribution), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(1.0, 1.0))
ExponentialFamily(NormalMeanVariance)
# output
julia> mean(product), var(product)
(0.0, 0.5)
```
Expand All @@ -48,6 +62,14 @@ prod(::ClosedProd, ::Missing, ::Missing) = missing
By default it uses the strategy from `default_prod_rule` and converts the output to the prespecified type but can be overwritten
for some distributions for better performance.
```jldoctest
julia> product = prod(PreserveTypeProd(NormalMeanVariance), NormalMeanVariance(-1.0, 1.0), NormalMeanVariance(1.0, 1.0))
NormalMeanVariance{Float64}(μ=0.0, v=0.5)
julia> mean(product), var(product)
(0.0, 0.5)
```
See also: [`prod`](@ref), [`ClosedProd`](@ref), [`PreserveTypeLeftProd`](@ref), [`PreserveTypeRightProd`](@ref), [`GenericProd`](@ref)
"""
struct PreserveTypeProd{T} end
Expand All @@ -65,6 +87,14 @@ prod(::PreserveTypeProd, ::Missing, ::Missing) = missing
An alias for the `PreserveTypeProd(L)` where `L` is the type of the `left` argument of the `prod` function.
```jldoctest
julia> product = prod(PreserveTypeLeftProd(), NormalMeanVariance(-1.0, 1.0), NormalMeanPrecision(1.0, 1.0))
NormalMeanVariance{Float64}(μ=0.0, v=0.5)
julia> mean(product), var(product)
(0.0, 0.5)
```
See also: [`prod`](@ref), [`PreserveTypeProd`](@ref), [`PreserveTypeRightProd`](@ref), [`GenericProd`](@ref)
"""
struct PreserveTypeLeftProd end
Expand All @@ -76,6 +106,14 @@ prod(::PreserveTypeLeftProd, left::L, right) where {L} = prod(PreserveTypeProd(L
An alias for the `PreserveTypeProd(R)` where `R` is the type of the `right` argument of the `prod` function.
```jldoctest
julia> product = prod(PreserveTypeRightProd(), NormalMeanVariance(-1.0, 1.0), NormalMeanPrecision(1.0, 1.0))
NormalMeanPrecision{Float64}(μ=0.0, w=2.0)
julia> mean(product), var(product)
(0.0, 0.5)
```
See also: [`prod`](@ref), [`PreserveTypeProd`](@ref), [`PreserveTypeLeftProd`](@ref), [`GenericProd`](@ref)
"""
struct PreserveTypeRightProd end
Expand Down

0 comments on commit c634d21

Please sign in to comment.