diff --git a/Project.toml b/Project.toml
index 7e779ad64..f1ff2c786 100644
--- a/Project.toml
+++ b/Project.toml
@@ -72,7 +72,7 @@ LinearSolve = "2"
 Lux = "1"
 Markdown = "1.10"
 ModelingToolkit = "9.42"
-Mooncake = "0.4.50"
+Mooncake = "0.4.52"
 NLsolve = "4.5.1"
 NonlinearSolve = "3.0.1"
 Optimization = "4"
diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl
index c16eb3c0d..dd93ef2ac 100644
--- a/src/adjoint_common.jl
+++ b/src/adjoint_common.jl
@@ -465,13 +465,9 @@ end
 
 function get_paramjac_config(::MooncakeVJP, pf, p, f, y, _t)
     dy_mem = zero(y)
-    dy_mem_grad = Mooncake.zero_tangent(dy_mem)
-    pf_grad = Mooncake.zero_tangent(pf)
-    y_grad = Mooncake.zero_tangent(y)
-    p_grad = Mooncake.zero_tangent(p)
     λ_mem = zero(y)
-    rule = Mooncake.build_rrule(pf, dy_mem, y, p, _t)
-    return rule, pf, pf_grad, dy_mem, dy_mem_grad, y_grad, p_grad, λ_mem
+    cache = Mooncake.prepare_pullback_cache(pf, dy_mem, y, p, _t)
+    return cache, pf, λ_mem, dy_mem
 end
 
 function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing,
@@ -536,14 +532,11 @@ function get_pf(::MooncakeVJP, prob, _f)
 end
 
 function mooncake_run_ad(paramjac_config, y, p, t, λ)
-    rule, pf, pf_grad, dy_mem, dy_mem_grad, y_grad, p_grad, λ_mem = paramjac_config
-    _pf = Mooncake.CoDual(pf, pf_grad)
-    _dy_mem = Mooncake.CoDual(dy_mem, dy_mem_grad)
-    _y = Mooncake.CoDual(y, Mooncake.set_to_zero!!(y_grad))
-    _p = Mooncake.CoDual(p, Mooncake.set_to_zero!!(p_grad))
-    _t = Mooncake.zero_codual(t)
+    cache, pf, λ_mem, dy_mem = paramjac_config
     λ_mem .= λ
-    dy, _ = Mooncake.__value_and_pullback!!(rule, λ_mem, _pf, _dy_mem, _y, _p, _t)
+    # Use the returned pullback (pf, dy_mem, y, p, t order), not private cache fields.
+    dy, grads = Mooncake.value_and_pullback!!(cache, λ_mem, pf, dy_mem, y, p, t)
+    _, _, y_grad, p_grad, _ = grads
     return dy, y_grad, p_grad
 end
 