-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
282cdea
commit dc41a2a
Showing
3 changed files
with
127 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,87 @@ | ||
module ExpFamilyPCA | ||
|
||
# Write your package code here. | ||
using Optim | ||
using CompressedBeliefMDPs | ||
|
||
# TODO: make this an imutable struct | ||
mutable struct EPCA <: CompressedBeliefMDPs.Compressor | ||
n::Int # number of samples | ||
d::Int # size of each sample | ||
l::Int # number of components | ||
A::Matrix # n x l matrix | ||
V::Matrix # l x d basis matrix | ||
|
||
G # convex function that induces g, F, f, and Bregman | ||
g # g(θ) = G'(θ) | ||
F # F(g(θ)) + G(θ) = g(θ)θ | ||
f # f(x) = F'(x) | ||
Bregman # generalized Bregman divergence induced by F | ||
|
||
μ0::Real # for numerical stability; must be in the range of g | ||
ϵ::Real # controls weight of stabilizing term in loss function | ||
|
||
EPCA() = new() | ||
end | ||
|
||
|
||
# TODO: implement this with Symbolics of SymEnginer | ||
# """ | ||
# EPCA(G) | ||
|
||
# Return the EPCA induced by a convex function G. | ||
# """ | ||
# function EPCA(G) | ||
# return nothing | ||
# end | ||
|
||
# TODO: move this logic | ||
function EPCA(l::Int, μ0::Real; ϵ::Float64=0.01) | ||
epca = EPCA() | ||
epca.l = l | ||
epca.μ0 = μ0 | ||
epca.ϵ = ϵ | ||
return epca | ||
end | ||
|
||
|
||
function CompressedBeliefMDPs.fit!(epca::EPCA, X; verbose=false, maxiter::Int=50) | ||
@assert epca.l > 0 | ||
epca.n, epca.d = size(X) | ||
epca.A = zeros(epca.n, epca.l) | ||
epca.V = rand(epca.l, epca.d) | ||
|
||
L(A, V) = sum(epca.Bregman(X, epca.g(A * V)) + epca.ϵ * epca.Bregman(epca.μ0, epca.g(A * V))) | ||
|
||
for _ in 1:maxiter | ||
if verbose println("Loss: ", L(epca.A, epca.V)) end | ||
epca.V = Optim.minimizer(optimize(V->L(epca.A, V), epca.V)) | ||
epca.A = Optim.minimizer(optimize(A->L(A, epca.V), epca.A)) | ||
end | ||
end | ||
|
||
# TODO: make sure this works for both matrices and vectors!! also update the signature in compressed belief pomdps | ||
function CompressedBeliefMDPs.compress(epca::EPCA, X; maxiter=50, verbose=false) | ||
n, d = size(X) | ||
@assert d == epca.d | ||
 = zeros(n, epca.l) | ||
L(A, V) = sum(epca.Bregman(X, epca.g(A * V)) + epca.ϵ * epca.Bregman(epca.μ0, epca.g(A * V))) | ||
for _ in 1:maxiter | ||
if verbose println("Loss: ", L(Â, epca.V)) end | ||
 = Optim.minimizer(optimize(A->L(A, epca.V), Â)) | ||
end | ||
return  * epca.V | ||
end | ||
|
||
CompressedBeliefMDPs.decompress(epca::EPCA, compressed) = epca.g(compressed) | ||
|
||
|
||
export | ||
PoissonPCA | ||
include("poisson.jl") | ||
|
||
export | ||
BernoulliPCA | ||
include("bernoulli.jl") | ||
|
||
|
||
end # module ExpFamilyPCA |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
""" | ||
Best with binary data. | ||
""" | ||
function BernoulliPCA(l::Int; μ0::Real=0.5, kwargs...) | ||
epca = EPCA(l, μ0; kwargs...) | ||
# TODO: eventually replace this w/ symbolic diff | ||
ϵ = 10e-20 | ||
@. begin | ||
G(θ) = log(1 + exp(θ)) | ||
g(θ) = exp(θ) / (1 + exp(θ)) | ||
F(x) = x * log(x) + (1 - x) * log(1 - x) | ||
f(x) = log(x / (1 - x)) | ||
# TODO: look into when this value is negative | ||
Bregman(p, q) = p * log((p + ϵ) / (q + ϵ)) + (1 - p) * log((1 - p + ϵ) / (1 - q + ϵ)) # with additive smoothing | ||
end | ||
epca.G = G | ||
epca.g = g | ||
epca.F = F | ||
epca.f = f | ||
epca.Bregman = Bregman | ||
return epca | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
function PoissonPCA(l::Int; μ0::Real=0, kwargs...) | ||
epca = EPCA(l, μ0; kwargs...) | ||
# TODO: eventually replace this w/ symbolic diff | ||
# ϵ = 10e-20 | ||
ϵ = eps() | ||
@. begin | ||
G(θ) = exp(θ) | ||
g(θ) = exp(θ) | ||
F(x) = x * log(x) - x | ||
f(x) = log(x) | ||
Bregman(p, q) = p * log((p + ϵ) / (q + ϵ)) + q - p # with additive smoothing | ||
end | ||
epca.G = G | ||
epca.g = g | ||
epca.F = F | ||
epca.f = f | ||
epca.Bregman = Bregman | ||
return epca | ||
end | ||
|
||
|
||
# TODO: include a normalized Poisson w/ link function in footnote 5 of long paper |