Skip to content

Commit

Permalink
Merge pull request #122 from fjebaker/fergus/convolutions
Browse files Browse the repository at this point in the history
Convolutions
  • Loading branch information
fjebaker authored Jun 29, 2024
2 parents 4b69a4a + eab089e commit dc082cf
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/SpectralFitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct Cash <: AbstractStatistic end
include("units.jl")
SpectralUnits.@reexport using .SpectralUnits

include("utils.jl")
include("print-utilities.jl")
include("support.jl")

Expand All @@ -60,6 +61,7 @@ include("meta-models/wrappers.jl")
include("meta-models/table-models.jl")
include("meta-models/surrogate-models.jl")
include("meta-models/caching.jl")
include("meta-models/functions.jl")

include("poisson.jl")

Expand Down
3 changes: 2 additions & 1 deletion src/fitting/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ function fit(
alg::LevenbergMarquadt;
verbose = false,
max_iter = 1000,
autodiff = supports_autodiff(config) ? :forward : :finite,
method_kwargs...,
)
@assert fit_statistic(config) == ChiSquared() "Least squares only for χ2 statistics."
Expand All @@ -66,7 +67,7 @@ function fit(
alg;
verbose = verbose,
max_iter = max_iter,
autodiff = supports_autodiff(config) ? :forward : :finite,
autodiff = autodiff,
method_kwargs...,
)
params = LsqFit.coef(lsq_result)
Expand Down
25 changes: 24 additions & 1 deletion src/julia-models/additive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,27 @@ end
end
end

export PowerLaw, BlackBody, GaussianLine
struct DeltaLine{W<:Number,T} <: AbstractSpectralModel{T,Additive}
_width::W
"Normalisation."
K::T
"Energy at which the delta function spikes."
E::T
end

function DeltaLine(; K = FitParam(1.0), E = FitParam(5.0), width = 1e-2)
DeltaLine{typeof(width),typeof(K)}(width, K, E)
end

Reflection.get_closure_symbols(::Type{<:DeltaLine}) = (:_width,)

Reflection.get_parameter_symbols(model::Type{<:DeltaLine}) = fieldnames(model)[2:end]

@inline function invoke!(flux, energy, model::DeltaLine{T}) where {T}
# we can't actually have a diract delta because that would ruin
# the ability to run dual numbers through the system. What we can do instead
# is have a miniscule gaussian
invoke!(flux, energy, GaussianLine(promote(model.K, model.E, model._width)...))
end

export PowerLaw, BlackBody, GaussianLine, DeltaLine
2 changes: 2 additions & 0 deletions src/julia-models/convolutional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,5 @@ function _or_else(value::Union{Nothing,T}, v::T)::T where {T}
value
end
end

export AsConvolution
18 changes: 13 additions & 5 deletions src/meta-models/caching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,29 @@ function AutoCache(model::AbstractSpectralModel{T,K}; abstol = 1e-3) where {T,K}
AutoCache(model, cache, abstol)
end

function _reinterpret_dual(::Type, v::AbstractArray, n::Int)
function _reinterpret_dual(
M::Type{<:AbstractSpectralModel},
::Type,
v::AbstractArray,
n::Int,
)
needs_resize = n > length(v)
if needs_resize
@warn "AutoCache: Growing dual buffer..."
@warn "$(Base.typename(M).name): Growing dual buffer..."
resize!(v, n)
end
view(v, 1:n), needs_resize
end
function _reinterpret_dual(
M::Type{<:AbstractSpectralModel},
DualType::Type{<:ForwardDiff.Dual},
v::AbstractArray{T},
n::Int,
) where {T}
n_elems = div(sizeof(DualType), sizeof(T)) * n
needs_resize = n_elems > length(v)
if needs_resize
@warn "AutoCache: Growing dual buffer..."
@warn "$(Base.typename(M).name): Growing dual buffer..."
resize!(v, n_elems)
end
reinterpret(DualType, view(v, 1:n_elems)), needs_resize
Expand All @@ -58,8 +64,10 @@ function invoke!(output, domain, model::AutoCache{M,T,K}) where {M,T,K}
_new_params = parameter_tuple(model.model)
_new_limits = (first(domain), last(domain))

output_cache, out_resized = _reinterpret_dual(D, model.cache.cache, length(output))
param_cache, _ = _reinterpret_dual(D, model.cache.params, length(_new_params))
output_cache, out_resized =
_reinterpret_dual(typeof(model), D, model.cache.cache, length(output))
param_cache, _ =
_reinterpret_dual(typeof(model), D, model.cache.params, length(_new_params))

same_domain = model.cache.domain_limits == _new_limits

Expand Down
77 changes: 77 additions & 0 deletions src/meta-models/functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
struct AsConvolution{M,T,V,P} <: AbstractModelWrapper{M,T,Convolutional}
model::M
# the domain on which we evaluate this model
domain::V
# an additional output cache
cache::NTuple{2,Vector{P}}
function AsConvolution(
model::AbstractSpectralModel{T},
domain::V,
cache::NTuple{2,Vector{P}},
) where {T,V,P}
new{typeof(model),T,V,P}(model, domain, cache)
end
end

function AsConvolution(
model::AbstractSpectralModel{T};
domain = collect(range(0, 2, 100)),
) where {T}
output = invokemodel(domain, model)
AsConvolution(model, domain, (output, deepcopy(output)))
end

function invoke!(output, domain, model::AsConvolution{M,T}) where {M,T}
D = promote_type(eltype(domain), T)
model_output, _ =
_reinterpret_dual(typeof(model), D, model.cache[1], length(model.domain) - 1)
convolution_cache, _ = _reinterpret_dual(
typeof(model),
D,
model.cache[2],
length(output) + length(model_output) - 1,
)

# invoke the child model
invoke!(model_output, model.domain, model.model)

# do the convolution
convolve!(convolution_cache, output, model_output)

# overwrite the output
shift = div(length(model_output), 2)
@views output .= convolution_cache[1+shift:length(output)+shift]
end

function Reflection.get_parameter_symbols(
::Type{<:AsConvolution{M}},
) where {M<:AbstractSpectralModel{T,K}} where {T,K}
syms = Reflection.get_parameter_symbols(M)
if K === Additive
# we need to lose the normalisation parameter
(syms[2:end]...,)
else
syms
end
end

function Reflection.make_constructor(
M::Type{<:AsConvolution{Model}},
closures::Vector,
params::Vector,
T::Type,
) where {Model<:AbstractSpectralModel{Q,K}} where {Q,K}
num_closures = fieldcount(M) - 1 # ignore the `model` field
my_closures = closures[1:num_closures]

model_params = if K === Additive
# insert a dummy normalisation to the constructor
vcat(:(one($T)), params)
else
params
end

model_constructor =
Reflection.make_constructor(Model, closures[num_closures+1:end], model_params, T)
:($(Base.typename(M).name)($(model_constructor), $(my_closures...)))
end
47 changes: 47 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
function _convolve_implementation!(
output::AbstractVector{T},
vec_A::AbstractVector{T},
kernel::AbstractVector{T},
) where {T<:Number}
# Based on https://discourse.julialang.org/t/97658/15
J = length(vec_A)
K = length(kernel)
@assert length(output) == J + K - 1 "Ouput is $(length(output)); should be $(J + K - 1)"

# do the kernel's side first
for i = 1:K-1
total = zero(T)
for k = 1:K
ib = (i >= k)
oa = ib ? vec_A[i-k+1] : zero(T)
total += kernel[k] * oa
end
output[i] = total
end
# now the middle
for i = K:J-1
total = zero(T)
for k = 1:K
oa = vec_A[i-k+1]
total += kernel[k] * oa
end
output[i] = total
end
# and finally the end
for i = J:(J+K-1)
total = zero(T)
for k = 1:K
ib = (i < J + k)
oa = ib ? vec_A[i-k+1] : zero(T)
total += kernel[k] * oa
end
output[i] = total
end
output
end

convolve!(output, A, kernel) = _convolve_implementation!(output, A, kernel)
function convolve(A, kernel)
output = zeros(eltype(A), length(A) + length(kernel) - 1)
convolve!(output, A, kernel)
end
72 changes: 72 additions & 0 deletions test/models/test-as-convolution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using SpectralFitting
using Test

include("../dummies.jl")

# put a couple of delta emission lines together
lines = DeltaLine(; E = FitParam(3.0), K = FitParam(2.0)) + DeltaLine(; E = FitParam(7.0))

# construct the convolutional wrapper
base_model = GaussianLine(; μ = FitParam(1.0), σ = FitParam(0.3))
conv = AsConvolution(base_model)

model = conv(lines)

domain = collect(range(0.0, 10.0, 150))

plot(domain[1:end-1], invokemodel(domain, lines))
plot(domain[1:end-1], invokemodel(domain, model))

output = invokemodel(domain, model)

@test sum(output) 3.2570820013702395 atol = 1e-4
@test output[10] 0.0036345342427057687 atol = 1e-4
@test output[40] 0.055218163108951814 atol = 1e-4

# simulate a model spectrum
dummy_data = make_dummy_dataset((E) -> (E^(-3.0)); units = u"counts / (s * keV)")
sim = simulate(model, dummy_data; seed = 42)

model.μ_1.frozen = true
model.K_1.frozen = true
model.K_2.frozen = true
model.E_1.frozen = true
model.E_2.frozen = true

# change the width
model.σ_1.value = 0.1
model

begin
prob = FittingProblem(model => sim)
result = fit(prob, LevenbergMarquadt())
end
@test result.χ2 76.15221077389369 atol = 1e-3

# put a couple of delta emission lines together
lines =
DeltaLine(; E = FitParam(3.0), K = FitParam(2.0), width = 0.1) +
DeltaLine(; E = FitParam(7.0))
model = conv(lines)

sim = simulate(model, dummy_data; seed = 42)

# now see if we can fit the delta line
model.μ_1.frozen = true
model.K_1.frozen = true
model.K_2.frozen = true
model.E_1.frozen = true
model.E_2.frozen = true
model.σ_1.frozen = true

model.E_2.frozen = false
model.E_2.value = 2.0
model.K_2.frozen = true
# model.K_2.value = 2.0

model
begin
prob = FittingProblem(model => sim)
result = fit(prob, LevenbergMarquadt(); verbose = true)
end
@test result.χ2 75.736 atol = 1e-3

0 comments on commit dc082cf

Please sign in to comment.