besselk: Enzyme compatibility w.r.t. order (#96)
Here is a fix for the asymptotic expansion + Levin code that gives correct derivatives for half-integer orders, using a custom forward rule:

```julia
using Bessels, Enzyme
using Enzyme.EnzymeRules
using Base.Cartesian: @nexprs, @ntuple
using Bessels.Math
using Bessels.Math: levin_scale

@inline @generated function levin_transform(s::NTuple{N, T},
                                            w::NTuple{N, T}) where {N, T <: FloatTypes}
    len = N - 1
    :(
        begin
            @nexprs $N i -> a_{i} = Vec{2, T}((s[i] / w[i], 1 / w[i]))
            @nexprs $len k -> (@nexprs ($len-k) i -> a_{i} = fmadd(a_{i}, levin_scale(one(T), i, k-1), a_{i+1}))
            return (a_1[1] / a_1[2])
        end
    )
end

# Here is the sauce: a manual rule that does the Levin transform for the
# values and dvalues separately. This is because if v + 1/2 is an integer, the
# sequence of _values_ will be exactly converged and you'll have a
# divide-by-zero NaN problem, but the _dvalues_ will NOT do that, and you want
# all those terms for the sake of accuracy.
function EnzymeRules.forward(func::Const{typeof(levin_transform)},
                             ::Type{<:Duplicated},
                             s::Duplicated,
                             w::Duplicated)
    (sv, dv, N) = (s.val, s.dval, length(s.val))
    ls = (sv[N-1] == sv[N]) ? sv[N] : levin_transform(sv, w.val)
    dls = (dv[N-1] == dv[N]) ? dv[N] : levin_transform(dv, w.dval)
    Duplicated(ls, dls)
end
```
```julia
@generated function besselkx_levin(v, x::T, ::Val{N}) where {T <: FloatTypes, N}
    :(
        begin
            s = zero(T)
            t = one(T)
            @nexprs $N i -> begin
                s += t
                t *= (4*v^2 - (2i - 1)^2) / (8 * x * i)
                s_{i} = s
                w_{i} = t
            end
            sequence = @ntuple $N i -> s_{i}
            weights = @ntuple $N i -> w_{i}
            return levin_transform(sequence, weights) * sqrt(π / 2x)
        end
    )
end
```
A reference so that you can check that the half-integer derivatives work:

```julia
using FiniteDifferences
dbeskx_dv_ref(v, x) = central_fdm(10, 1)(_v -> besselkx(_v, x), v)
```

Which you can see works:

```julia
julia> autodiff(Forward, _v -> besselkx_levin(_v, 13.1, Val(20)), Duplicated, Duplicated(1.5, 1.0))
(0.372710911777806, 0.041092931130497966)

julia> dbeskx_dv_ref(1.5, 13.1)
0.04109293113049592
```

This works so well that we probably should just close #91. A bummer to have wasted that time, but introducing that extra stuff to get better partial information in a way that accommodates an early return was, in the end, slower and less accurate than this more elegant custom rule. That's what I get for trying to avoid learning something new. The one remaining annoyance here is that the autodiff call allocates. But I'm sure we can figure that out.
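For intuition on what the Levin transform is doing, here is a self-contained sketch in plain Python (so it can run independently of the package) of a Levin-type sequence transformation in its explicit binomial form: partial sums `s_j` and remainder estimates `w_j` go into a weighted-difference ratio that converges far faster than the raw partial sums. The helper name and indexing conventions here are choices of this sketch, not Bessels.jl internals.

```python
import math

def levin_t(a):
    """Levin-type acceleration of sum(a), using the next term a[j+1]
    as the remainder estimate for the j-th partial sum."""
    k = len(a) - 2                       # highest transform order we can form
    s, running = [], 0.0
    for term in a[:-1]:                  # partial sums s_0 .. s_k
        running += term
        s.append(running)
    w = a[1:]                            # remainder estimates w_0 .. w_k
    num = den = 0.0
    for j in range(k + 1):
        c = (-1)**j * math.comb(k, j) * ((j + 1) / (k + 1))**(k - 1)
        num += c * s[j] / w[j]
        den += c / w[j]
    return num / den

# alternating harmonic series: sum_{n>=1} (-1)^(n+1)/n = log(2)
a = [(-1)**(n + 1) / n for n in range(1, 14)]
accel = levin_t(a)
raw = sum(a)
print(abs(accel - math.log(2)), abs(raw - math.log(2)))
```

With only 13 terms the accelerated value is many digits closer to log(2) than the raw partial sum, which is the same effect being exploited for the Bessel asymptotic series.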
So cool 🔥 I personally love the simplicity of this. So let's get the path forward here. A PR would contain a module (name TBD, something to do with AD) that will be used as an extension or weak dep, and whenever Enzyme is loaded together with Bessels that will kick in. In that module we can put together custom rules for these functions that are pretty elegant. That module will have a dependency on EnzymeCore (https://github.com/EnzymeAD/Enzyme.jl/tree/main/lib/EnzymeCore)? I don't think we have to have it depend completely on Enzyme? I have pushed the new versions of this at #95, so let's get that merged ASAP. I think it is ready now. Then we can fork the master branch after that is merged to add the new module? Does that sound good @cgeoga? Let's keep #91 going for now :) I don't think it's wasted, and I really haven't had time to go through that explicitly to see where they could be useful. But I do agree that the sequence transformation approach is working really well both for the scalar and derivative evaluation.
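For reference, the weak-dependency setup being described is the Julia 1.9+ package-extension mechanism, which in Bessels' Project.toml would look roughly like the sketch below. The extension module name and the elided UUID are illustrative, not actual entries.

```toml
[weakdeps]
EnzymeCore = "..."  # EnzymeCore's package UUID goes here

[extensions]
# loaded automatically once both Bessels and EnzymeCore are in the session
BesselsEnzymeCoreExt = "EnzymeCore"
```

This way the custom rules only need EnzymeCore's types (Const, Duplicated, and the EnzymeRules interface) without taking on all of Enzyme as a hard dependency.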
Also, could you give some examples of the power series stuff when you get a chance? Is that something to do with the implementation? Does, say, the power series for
Ahh, so definitely the integer case is the issue, and that does seem challenging without going into the Temme series. Though we are very close here I think, and will work on updating. I was looking at the current power series a little bit and I think this could be really simplified:

```julia
function besselk_power_series(v, x::ComplexOrReal{T}) where T
    MaxIter = 5000
    gam = gamma(v)
    ngam = π / (sinpi(-abs(v)) * gam * v)
    s1, s2 = zero(T), zero(T)
    t1, t2 = one(T), one(T)
    for k in 1:MaxIter
        s1 += t1
        s2 += t2
        t1 *= x^2 / (4k * (k - v))
        t2 *= x^2 / (4k * (k + v))
        abs(t1) < eps(T) && break
    end
    xpv = (x/2)^v
    s = gam * s1 + xpv^2 * ngam * s2
    return s / (2*xpv)
end
```

Which is significantly faster than the current version with similar accuracy. One thing that was interesting is that it's not deciding to auto-vectorize, probably making the judgement that those divisions are expensive. But ya, the to-do for this is definitely the integer-order power series.

EDIT: Ya, it's not vectorizing, but I think it's making a judgement that the termination criterion within the loop makes the overhead not worth it. We can manually vectorize it like so:

```julia
using SIMDMath: Vec, fadd, fmul, fdiv

function besselk_power_series(v, x::T) where T
    MaxIter = 5000
    gam = gamma(v)
    ngam = π / (sinpi(-abs(v)) * gam * v)
    s = Vec((0.0, 0.0))
    t = Vec((1.0, 1.0))
    for k in 1:20
        s = fadd(s, t)
        t = fmul(t, x^2)
        t = fdiv(t, Vec((4k * (k - v), 4k * (k + v))))
        #abs(t1) < eps(T) && break
    end
    xpv = (x/2)^v
    _s = gam * s[1] + xpv^2 * ngam * s[2]
    return _s / (2*xpv)
end
```

This is about 20% faster than the above version, but this version does best with a set number of terms unless we want to make vector comparisons work.
Just wanted to comment here, as I looked into what was going on with the incorrect derivatives returned by Enzyme for the besselk power series. I thought it might be the gamma function, but that actually returns good derivatives. It looks like it's actually the
@heltonmc and @oscardssmith: I think I have a differentiable and stable power series-like method that works for integer orders and gives correct order derivatives at those values. This is very sloppy in its current form, but the idea is effectively to do a Temme-like trick:

```julia
using SpecialFunctions # for now; can use the Bessels.gamma function of course

const _g = MathConstants.eulergamma # because I don't like unicode, but will fix in the PR
const C0 = (-2*_g, -2)
const C2 = (_g*pi^2 - 2*_g^3 + 2*trigamma(1),
            -(6*_g^2 + pi^2),
            -6*_g,
            -2.0)./6
const C4 = (-9*_g*pi^4 - 20*_g^3*pi^2 - 12*_g^5 + 12*polygamma(4,1) + 20*pi^2*trigamma(1) + 120*_g^2*trigamma(1),
            -3*(20*_g^4 + 20*_g^2*pi^2 + 3*pi^4 - 80*_g*trigamma(1) + 120*_g^2*trigamma(1)),
            -60*(2*_g^3 + _g*pi^2 - 2*trigamma(1)),
            -20*(6*_g^2 + pi^2),
            -60*_g,
            -12)./720

# An important observation: if you plug the function for f_0 that I am using
# here, which for reference is
#
#   f_0(v, x) = (x^v)*gamma(-v) + (x^(-v))*gamma(v)
#             = (x^v)*(gamma(-v) + (x^(-2*v))*gamma(v)),
#
# into Wolfram Alpha and ask for an expansion around v=0, you see that what we
# actually have is a bivariate polynomial in (v^2, log(x)). That log(x) is
# actually a bit annoying, because now if we wanted to just use a bivariate
# Cheby expansion or something, the log term means that we actually have a huge
# domain. So this current version uses the actual coefficients from Wolfram.
function f0_local_expansion_v0(v, x)
    lx = log(x)
    c0 = evalpoly(lx, C0)
    c2 = evalpoly(lx, C2)
    c4 = evalpoly(lx, C4)
    evalpoly(v, (c0, 0.0, c2, 0.0, c4))/2
end
```
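The reason a special local expansion is needed near v = 0 is easy to see numerically. The following Python sketch (helper names are hypothetical, and the constant in `f0_limit` is the well-known limit -(eulergamma + log z), matching the constant term encoded in C0) shows the direct formula for f_0 losing digits to catastrophic cancellation as v shrinks, since both gamma terms blow up like 1/v with opposite signs:

```python
import math

EULERGAMMA = 0.5772156649015329

def f0_naive(v, z):
    # (gamma(-v)*z^v + gamma(v)*z^(-v)) / 2: both terms grow like 1/v as
    # v -> 0, so their sum loses roughly -log10(v) digits to cancellation
    return (math.gamma(-v) * z**v + math.gamma(v) * z**(-v)) / 2

def f0_limit(z):
    # analytic v -> 0 limit: -(eulergamma + log z), which is exactly the
    # constant term c0/2 of the local polynomial expansion
    return -(EULERGAMMA + math.log(z))

z = 0.7
for v in (1e-2, 1e-5, 1e-8, 1e-11):
    print(v, f0_naive(v, z) - f0_limit(z))
```

For moderate v the naive formula agrees with the limit to the expected O(v^2), but as v approaches machine-epsilon scales the cancellation error dominates, which is exactly what the `abs(v) < 1e-5` branch above avoids.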
```julia
# Strategy: because of the order recurrence, we only need a code that gives
# (K_{v}(x), K_{v+1}(x)) for v in (-1/2, 1/2]. Then we can use a simple forward
# recursion.
#
# This is slower than just using the straight power series, BUT the standard
# power series is not valid for integer orders, and can be numerically unstable
# for very near-integer orders. The point of this domain reduction in the order
# is so that we can isolate the problem of the power series' singularity when
# the order approaches an integer to a single coefficient.
function besselk_powerseries_temme_basal(v, x)
    @assert -0.5 < v <= 0.5 "This internal routine is only for the base-sized order v."
    # pre-computations:
    # TODO (cg 2023/05/14 18:47): organize all of this more thoughtfully.
    z  = x/2
    zz = z*z
    # special branch for the initial f_0 coefficient when iszero(v):
    if abs(v) < 1e-5
        fk = f0_local_expansion_v0(v, x/2)
    else
        p1 = gamma(-v)*(z^v)
        p2 = -gamma(v)*(z^(-v))
        fk = (p1 - p2)/2
    end
    # TODO (cg 2023/05/14 18:05): the gamma(1+v) and gamma(1-v) in some sense
    # should be computed inside the above if/else block.
    (pk, qk, _ck, factk) = ((z^(-v))*gamma(1+v)/2, (z^v)*gamma(1-v)/2, 1.0, 1.0)
    (out_v, out_vp1) = (0.0, 0.0)
    # TODO (cg 2023/05/14 18:45): Should this be compile-time unrolled like our
    # other newer implementations?
    max_iter = 50
    for k in 1:max_iter
        # Add to the series.
        ck = _ck/factk
        term_v   = ck*fk
        term_vp1 = ck*(pk - (k-1)*fk)
        out_v   += term_v
        out_vp1 += term_vp1
        # TODO (cg 2023/05/14 18:47): take the convergence check seriously.
        ((abs(term_v) < eps()) && (abs(term_vp1) < eps())) && break
        # Increment/update each term:
        # TODO (cg 2023/05/14 18:47): Ask @heltonmc to turn this into one muladd
        # somehow.
        fk = (k*fk + pk + qk)/(k^2 - v^2)
        pk /= (k-v)
        qk /= (k+v)
        _ck *= zz
        factk *= k
    end
    (out_v, out_vp1/z)
end
```
```julia
# Stable and differentiable for integer orders.
function besselk_powerseries_temme(v, x)
    v < zero(v) && return besselk_powerseries_temme(-v, x)
    flv = Int(floor(v))
    _v = v - flv
    # get (K_{_v}(x), K_{_v+1}(x)):
    (kv, kvp1) = besselk_powerseries_temme_basal(_v, x)
    abs(v) < 1/2 && return kv
    # do the forward recurrence:
    twodx = 2/x
    # TODO (cg 2023/05/14 18:47): Ask @heltonmc to turn this into one muladd
    # somehow, or use the existing forward rec function that I don't understand.
    for _ in 1:(flv-1)
        _v += 1
        (kv, kvp1) = (kvp1, muladd(twodx*_v, kvp1, kv))
    end
    kvp1
end
```

As you will see and find annoying, I haven't done the thing you guys do of turning everything into a single. Would love to hear thoughts! If you think it's fine modulo some cleanup, I'm happy to create a PR.
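The forward recurrence the domain reduction relies on is the standard three-term identity K_{v+1}(x) = (2v/x) K_v(x) + K_{v-1}(x). As a sanity check, here is a Python sketch verifying it against the integral representation of K_v; the quadrature helper and its parameters are assumptions of this sketch, not anything from the package:

```python
import math

def besselk_ref(v, x, n=40000, tmax=6.0):
    # reference values via K_v(x) = integral_0^inf exp(-x*cosh(t))*cosh(v*t) dt,
    # truncated at tmax and evaluated with the trapezoid rule
    h = tmax / n
    total = 0.5 * (math.exp(-x) + math.exp(-x*math.cosh(tmax))*math.cosh(v*tmax))
    for i in range(1, n):
        t = i * h
        total += math.exp(-x*math.cosh(t)) * math.cosh(v*t)
    return total * h

v, x = 0.3, 1.7
lhs = besselk_ref(v + 1, x)
rhs = (2*v/x) * besselk_ref(v, x) + besselk_ref(v - 1, x)
print(lhs, rhs)
```

Because the recurrence is exact, any residual difference is pure quadrature error, which is why walking from the basal pair (K_{_v}, K_{_v+1}) up to the target order does not degrade accuracy.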
This is excellent, and really the only thing holding back the modified Bessel function implementations from being differentiable. My opinion is that if this is accurate and reasonably fast, let's get this merged into the code base. I have a branch that I have been really waiting on this to merge into to complete.

So for me the structure for this power series would look like a combination of the version I posted above with a "near_int" check at the top level. So something like...

```julia
function besselk_power_series(v, x::ComplexOrReal{T}) where T
    isnearint(v) && return besselk_power_series_int(v, x)
    MaxIter = 5000
    gam = gamma(v)
    ngam = π / (sinpi(-abs(v)) * gam * v)
    s1, s2 = zero(T), zero(T)
    t1, t2 = one(T), one(T)
    for k in 1:MaxIter
        s1 += t1
        s2 += t2
        t1 *= x^2 / (4k * (k - v))
        t2 *= x^2 / (4k * (k + v))
        abs(t1) < eps(T) && break
    end
    xpv = (x/2)^v
    s = gam * s1 + xpv^2 * ngam * s2
    return s / (2*xpv)
end

function besselk_power_series_int(v, x)
    # your function here
end
```

I have not really gone through the implementation on this, but again I think optimization and tricks can come later at the end. Let's check for accuracy, merge, and go from there. My only initial thing was that in the

But this is awesome 😄 I'm excited to get this merged and clean up the whole function.

Edit: So what we need to determine is how close to an integer the simpler version starts to show numerical cancellation. Probably within 0.001 of an Int we need to dispatch to this method.

Edit2: I think unrolling this at compile time will depend. If the number of terms is pretty similar for the range of inputs that we are considering, then we could consider unrolling, but if there is wide variability then maybe it's best to have the convergence check.

Edit3: Ya, I think the things to improve before merging would be to get this down to one gamma call and also consider the powers in
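The proposed dispatch shape can be made concrete with a small sketch. Here it is in Python; `isnearint` and the 1e-3 threshold are just the guess from this discussion and would need to be tuned by measuring where the two-sum series actually starts losing digits:

```python
def isnearint(v, tol=1e-3):
    # true when v is within tol of an integer, where the standard power
    # series suffers cancellation; tol = 1e-3 is the discussion's guess
    return abs(v - round(v)) < tol

def besselk_power_series_dispatch(v, x):
    # placeholder structure: route near-integer orders to the Temme-style
    # series, everything else to the fast two-sum series
    if isnearint(v):
        return ("temme", v, x)    # stand-in for besselk_power_series_int(v, x)
    return ("two_sum", v, x)      # stand-in for the simple series above

print(besselk_power_series_dispatch(2.0004, 1.0)[0])  # -> temme
print(besselk_power_series_dispatch(0.4, 1.0)[0])     # -> two_sum
```

The nice property of this layout is that the branch is a cheap scalar check at the top, so the common non-integer path pays essentially nothing for the integer-order fix.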
That all sounds good to me. Considering that it is related, how would you feel about me adding it to the

Also, I will make those improvements you mention in the edit. I can certainly do that much before making the PR, even if I can't go full
Go for it! Let's get it merged and we can go from there so I can work off that branch. We should then have a fully differentiable

Sounds good! I think those are the biggest things that stand out. I am literally unloading all my stuff from the move right now, so I won't have that much time to comb through the fine details, but those I think are the biggest things. Again, this method will only be for
Awesome! Just pushed up a first commit that brings things in. A couple new

With regard to a paper, unfortunately I didn't really need to do anything new or cool: it is just the initial term from the power series in my paper that is already very generously mentioned (but simplified a lot, as you have also mentioned), and then the rest of the coefficients/structure is right out of the Temme paper. So I would say it is really just the Temme method with a better designed
I've merged #101, which should be the first steps on
Amazing! Will mark this issue as done, and we can chase down small inaccuracies in subsequent issues/PRs.
A narrow issue, but sorting this out will probably provide a pretty good template for other issues. The current issues are:

- `isinteger(v)`
- #102