Skip to content

Commit

Permalink
Merge pull request #116 from fjebaker/fergus/frozen-cache
Browse files Browse the repository at this point in the history
Frozen parameters in parameter cache
  • Loading branch information
fjebaker authored Jun 25, 2024
2 parents e15f87b + 96a02c0 commit f85b322
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SpectralFitting"
uuid = "f2c56810-742e-4b72-8bf4-27af3bb81a12"
authors = ["Fergus Baker <fergusbkr@gmail.com>"]
version = "0.5.8"
version = "0.5.9"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down
6 changes: 1 addition & 5 deletions src/abstract-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,5 @@ function make_diff_parameter_cache(
N = isnothing(param_diff_cache_size) ? length(vals) : param_diff_cache_size
diffcache = DiffCache(vals, ForwardDiff.pickchunksize(N))

# embed current parameter values inside of the dual cache
# else all frozens will be zero
get_tmp(diffcache, ForwardDiff.Dual(one(eltype(vals)))) .= vals

ParameterCache(free_mask, diffcache)
ParameterCache(free_mask, diffcache, vals[.!free_mask])
end
30 changes: 17 additions & 13 deletions src/param-cache.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# these should work for both single models and composite models
# so that the parameters can all be allocated in one go
# for multi model: give a view to each cache
struct ParameterCache{M<:AbstractArray,V}
struct ParameterCache{M<:AbstractArray,V,T<:Number}
free_mask::M # bit vector or a view into one
parameters::V
frozen_values::Vector{T}
end

function _make_free_mask(params::AbstractArray{<:FitParam})
Expand All @@ -16,15 +17,20 @@ end

function ParameterCache(params::AbstractArray{<:FitParam})
free_mask = _make_free_mask(params)
ParameterCache(free_mask, map(get_value, params))
frozen = params[.!free_mask]
ParameterCache(free_mask, map(get_value, params), map(get_value, frozen))
end

function _update_conditional!(parameters, mask, new_parameters, condition)
function _update_conditional!(parameters, mask, new_parameters, frozen)
j::Int = 1
k::Int = 1
for (i, free) in enumerate(mask)
if condition(free)
if free
parameters[i] = new_parameters[j]
j += 1
else
parameters[i] = frozen[k]
k += 1
end
end
end
Expand All @@ -35,15 +41,13 @@ _get_parameters(cache::ParameterCache{M,V}, params) where {M<:AbstractArray,V<:D

function update_free_parameters!(cache::ParameterCache, params)
@assert count(cache.free_mask) == length(params)
_update_conditional!(_get_parameters(cache, params), cache.free_mask, params, ==(true))
_update_conditional!(
_get_parameters(cache, params),
cache.free_mask,
params,
cache.frozen_values,
)
cache
end

function update_frozen_parameters!(cache::ParameterCache, params)
parameters = _get_parameters(cache, params)
@assert length(parameters) - count(cache.free_mask) == length(params)
_update_conditional!(parameters, cache.free_mask, params, ==(false))
cache
end

export ParameterCache, update_free_parameters!, update_frozen_parameters!
export ParameterCache, update_free_parameters!
5 changes: 3 additions & 2 deletions test/parameters/test-free-frozen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ model = DummyMultiplicative() * PowerLaw(a = FitParam(4.0, frozen = true))
cache = make_parameter_cache(model)
SpectralFitting.update_free_parameters!(cache, [2.0, 0.5])
@test cache.parameters == [2.0, 4.0, 0.5, 5.0]
SpectralFitting.update_frozen_parameters!(cache, [50.0, 0.0])
@test cache.parameters == [2.0, 50.0, 0.5, 0.0]


# TODO: test to check that auto diff gradients with various sizes can be propagated correctly

0 comments on commit f85b322

Please sign in to comment.