Skip to content

Commit

Permalink
Make use of new Mooncake feature
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Nov 28, 2024
1 parent 0733b50 commit eaaba53
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ LinearSolve = "2"
Lux = "1"
Markdown = "1.10"
ModelingToolkit = "9.42"
Mooncake = "0.4.50"
Mooncake = "0.4.52"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1"
Optimization = "4"
Expand Down
19 changes: 6 additions & 13 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -465,13 +465,9 @@ end

function get_paramjac_config(::MooncakeVJP, pf, p, f, y, _t)
dy_mem = zero(y)
dy_mem_grad = Mooncake.zero_tangent(dy_mem)
pf_grad = Mooncake.zero_tangent(pf)
y_grad = Mooncake.zero_tangent(y)
p_grad = Mooncake.zero_tangent(p)
λ_mem = zero(y)
rule = Mooncake.build_rrule(pf, dy_mem, y, p, _t)
return rule, pf, pf_grad, dy_mem, dy_mem_grad, y_grad, p_grad, λ_mem
cache = Mooncake.prepare_pullback_cache(pf, dy_mem, y, p, _t)
return cache, pf, λ_mem, dy_mem
end

function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing,
Expand Down Expand Up @@ -536,14 +532,11 @@ function get_pf(::MooncakeVJP, prob, _f)
end

function mooncake_run_ad(paramjac_config, y, p, t, λ)
rule, pf, pf_grad, dy_mem, dy_mem_grad, y_grad, p_grad, λ_mem = paramjac_config
_pf = Mooncake.CoDual(pf, pf_grad)
_dy_mem = Mooncake.CoDual(dy_mem, dy_mem_grad)
_y = Mooncake.CoDual(y, Mooncake.set_to_zero!!(y_grad))
_p = Mooncake.CoDual(p, Mooncake.set_to_zero!!(p_grad))
_t = Mooncake.zero_codual(t)
cache, pf, λ_mem, dy_mem = paramjac_config
λ_mem .= λ
dy, _ = Mooncake.__value_and_pullback!!(rule, λ_mem, _pf, _dy_mem, _y, _p, _t)
dy, _ = Mooncake.value_and_pullback!!(cache, λ_mem, pf, dy_mem, y, p, t)
y_grad = cache.tangents[3]
p_grad = cache.tangents[4]
return dy, y_grad, p_grad
end

Expand Down

0 comments on commit eaaba53

Please sign in to comment.