
besselk Enzyme compatibility w.r.t. order #96

Closed
3 tasks done
cgeoga opened this issue Apr 27, 2023 · 12 comments
@cgeoga
Contributor

cgeoga commented Apr 27, 2023

A narrow issue, but sorting this out will probably provide a pretty good template for other issues. The current issues are:

@cgeoga
Contributor Author

cgeoga commented Apr 27, 2023

Here is a fix for the asymptotic expansion + Levin code that gives correct derivatives for half-integer orders, using a custom forward rule in Enzyme. This uses the Bessels branch airy_levin2 and Enzyme#master (and my own formatting preferences, but I promise to go back to 4-space indentation before opening a PR):

using Bessels, Enzyme
using Enzyme.EnzymeRules
using Base.Cartesian: @nexprs, @ntuple
using Bessels.Math               # provides FloatTypes
using Bessels.Math: levin_scale
using SIMDMath: Vec, fmadd       # Vec/fmadd are used in levin_transform below

@inline @generated function levin_transform(s::NTuple{N, T}, 
                                            w::NTuple{N, T}) where {N, T <: FloatTypes}
  len = N - 1
  :(
    begin
      @nexprs $N i -> a_{i} = Vec{2, T}((s[i] / w[i], 1 / w[i]))
      @nexprs $len k -> (@nexprs ($len-k) i -> a_{i} = fmadd(a_{i}, levin_scale(one(T), i, k-1), a_{i+1}))
      return (a_1[1] / a_1[2])
    end
  )
end
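For reference (and for checking the SIMD variant), here is a plain scalar transcription of the same recurrence. This is a sketch: `levin_transform_scalar` is my own name, it assumes `levin_scale` is in scope as imported above, and it mirrors the index ranges of the `@generated` version exactly as posted:

```julia
function levin_transform_scalar(s::NTuple{N, T}, w::NTuple{N, T}) where {N, T}
  # numerators s[i]/w[i] and denominators 1/w[i], exactly the two lanes
  # packed into each Vec{2, T} in the SIMD version
  a = collect(s[i] / w[i] for i in 1:N)
  b = collect(inv(w[i])   for i in 1:N)
  for k in 1:N-1
    for i in 1:(N-1-k)
      c = levin_scale(one(T), i, k - 1)
      a[i] = muladd(a[i], c, a[i+1])
      b[i] = muladd(b[i], c, b[i+1])
    end
  end
  return a[1] / b[1]
end
```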

# Here is the sauce: a manual rule that does the Levin transform for the
# values and dvalues separately. This is because if v+1/2 is an integer, the
# sequence of _values_ will be exactly converged and you'll get a
# divide-by-zero NaN, but the _dvalues_ will NOT, and you want all those
# terms for the sake of accuracy.
function EnzymeRules.forward(func::Const{typeof(levin_transform)}, 
                             ::Type{<:Duplicated}, 
                             s::Duplicated,
                             w::Duplicated)
  (sv, dv, N) = (s.val, s.dval, length(s.val))
  ls  = (sv[N-1] == sv[N]) ? sv[N] : levin_transform(sv, w.val)
  dls = (dv[N-1] == dv[N]) ? dv[N] : levin_transform(dv, w.dval)
  Duplicated(ls, dls)
end

@generated function besselkx_levin(v, x::T, ::Val{N}) where {T <: FloatTypes, N}
  :(
    begin
      s = zero(T)
      t = one(T)
      @nexprs $N i -> begin
          s += t
          t *= (4*v^2 - (2i - 1)^2) / (8 * x * i)
          s_{i} = s
          w_{i} = t
        end
        sequence = @ntuple $N i -> s_{i}
        weights = @ntuple $N i -> w_{i}
      return levin_transform(sequence, weights) * sqrt(π / 2x)
    end
  )
end

# A reference so that you can check that the half-integer derivatives work.
using FiniteDifferences
dbeskx_dv_ref(v, x) = central_fdm(10,1)(_v->besselkx(_v, x), v)

Which you can see works with

julia> autodiff(Forward, _v->besselkx_levin(_v, 13.1, Val(20)), Duplicated, Duplicated(1.5, 1.0))
(0.372710911777806, 0.041092931130497966)

julia> dbeskx_dv_ref(1.5, 13.1)
0.04109293113049592

This works so well that we probably should just close #91. A bummer to have wasted that time, but introducing that extra machinery to get better partial information while accommodating an early return ended up slower and less accurate than this more elegant custom rule. That's what I get for trying to avoid learning something new.

The one remaining annoyance here is that the autodiff call allocates. But I'm sure we can figure that out.

@cgeoga cgeoga changed the title besselk enzyme compatibility w.r.t. order besselk Enzyme compatibility w.r.t. order Apr 27, 2023
@heltonmc
Member

So cool 🔥 I personally love the simplicity of this. So let's get the path forward here. A PR would contain a module (name TBD, something to do with AD) that will be used as an extension or weak dep, so that whenever Enzyme is loaded together with Bessels it will kick in. In that module we can put together custom rules for these functions that are pretty elegant. That module would have a dependency on EnzymeCore (https://github.com/EnzymeAD/Enzyme.jl/tree/main/lib/EnzymeCore)? I don't think we have to have it depend completely on Enzyme?

I have pushed the new versions of this at #95, so let's get that merged ASAP. I think it is ready now. Then we can fork the master branch after that is merged to add the new module? Does that sound good @cgeoga?

Let's keep #91 going for now :) I don't think it's wasted and I really haven't had time to go through that explicitly to see where they could be useful. But I do agree that the sequence transformation approach is working really well both for the scalar and derivative evaluation.

@heltonmc
Member

Also, could you give some examples for the power series stuff when you get a chance? Is that something to do with the implementation? Does the power series for besseli yield similar results, or is this just something to do with the limit near integers for the Bessel functions of the second kind?

@heltonmc
Member

heltonmc commented May 4, 2023

Ahh, so definitely the integer case is the issue, and that does seem challenging without going into the Temme series. Though we are very close here, I think, and I will work on updating besselk for the general case so that it should be differentiable; we will just have to think about how we want to compute orders 0 and 1 for small values.

I was looking at the current power series a little bit and I think this could be really simplified

function besselk_power_series(v, x::ComplexOrReal{T}) where T
    MaxIter = 5000
    gam = gamma(v)
    ngam = π / (sinpi(-abs(v)) * gam * v)

    s1, s2 = zero(T), zero(T)
    t1, t2 = one(T), one(T)

    for k in 1:MaxIter
        s1 += t1
        s2 += t2
        t1 *= x^2 / (4k * (k - v))
        t2 *= x^2 / (4k * (k + v))
        abs(t1) < eps(T) && break
    end

    xpv = (x/2)^v
    s = gam * s1 + xpv^2 * ngam * s2
    return s / (2*xpv)
end
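A quick way to sanity-check the simplified series (a sketch of my own: it assumes SpecialFunctions is available as a reference and that `besselk_power_series` above is in scope inside the Bessels context, where `ComplexOrReal` is defined):

```julia
using SpecialFunctions: besselk

# Compare against SpecialFunctions for a few small-x, non-integer orders,
# where the power series should be accurate.
for v in (0.3, 1.4, 2.7), x in (0.5, 1.0, 1.5)
    ref = besselk(v, x)
    val = besselk_power_series(v, x)
    println((v, x), "  relative error: ", abs(val - ref) / abs(ref))
end
```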

Which is significantly faster than the current version, with similar accuracy. One thing that was interesting is that it's not deciding to auto-vectorize it, probably because it judges those divisions to be expensive.

But ya the to-do for this is definitely the integer order power series.

EDIT: Ya, it's not vectorizing, but I think it's making a judgement that the termination criterion within the loop makes the overhead not worth it. We can manually vectorize it like so:

using SIMDMath: Vec, fadd, fmul, fdiv
function besselk_power_series(v, x::T) where T
    MaxIter = 5000
    gam = gamma(v)
    ngam = π / (sinpi(-abs(v)) * gam * v)

    s = Vec((0.0, 0.0))
    t = Vec((1.0, 1.0))

    for k in 1:20
        s = fadd(s, t)
        t = fmul(t, x^2)
        t = fdiv(t, Vec((4k * (k - v), 4k * (k + v))))
        #abs(t1) < eps(T) && break
    end

    xpv = (x/2)^v
    _s = gam * s[1] + xpv^2 * ngam * s[2]
    return _s / (2*xpv)
end

This is about 20% faster than the above version, but this version does a set number of terms unless we want to make vector comparisons work.

@heltonmc
Member

heltonmc commented May 7, 2023

Just wanted to comment here as I looked into what was going on with the incorrect derivatives returned by Enzyme for the besselk power series. I thought it might be the gamma function, but that actually returns good derivatives. It looks like it's actually the sinpi function. See EnzymeAD/Enzyme.jl#443. This pops up a lot, so it would be good to fix that!

@cgeoga
Contributor Author

cgeoga commented May 14, 2023

@heltonmc and @oscardssmith: I think I have a differentiable and stable power-series-like method that works for integer orders and gives correct order derivatives at those values. This is very sloppy in its current form, but the idea is effectively to do a Temme-like trick: compute (K_{_v}(x), K_{_v+1}(x)) for _v = v - floor(v) and then use the forward recurrence. I know the forward recurrence is slow, but the value of the trick here is that (like with my original Temme implementation) you effectively isolate the difficulty of the integer-order case into a single coefficient. Here is a rough draft of the function:

using SpecialFunctions # for now, can use the Bessels.gamma function of course.

const _g = MathConstants.eulergamma # because I don't like unicode, but will fix in the PR

const C0 = (-2*_g, -2)
const C2 = (_g*pi^2 - 2*_g^3 + 2*trigamma(1),
            -(6*_g^2 + pi^2),
            -6*_g,
            -2.0)./6
const C4 = (-9*_g*pi^4 - 20*_g^3*pi^2 - 12*_g^5 + 12*polygamma(4,1) + 20*pi^2*trigamma(1) + 120*_g^2*trigamma(1),
            -3*(20*_g^4 + 20*_g^2*pi^2 + 3*pi^4 - 80*_g*trigamma(1) + 120*_g^2*trigamma(1)),
            -60*(2*_g^3 + _g*pi^2 - 2*trigamma(1)),
            -20*(6*_g^2 + pi^2),
            -60*_g,
            -12)./720

# An important observation: If you plug the function for f_0 that I am using
# here, which for reference is
#
# f_0(v, x) = (x^v)*gamma(-v)  + (x^(-v))*gamma(v)
#           = (x^v)*(gamma(-v) + (x^(-2*v))*gamma(v)),
#
# into wolfram alpha and ask for an expansion around v=0, you see that what we
# actually have is a bivariate polynomial in (v^2, log(x)). That log(x) is
# actually a bit annoying, because if we wanted to just use a bivariate
# cheby expansion or something, the log term means that we actually have a huge
# domain. So this current version uses the actual coefficients from wolfram.
function f0_local_expansion_v0(v, x)
  lx = log(x)
  c0 = evalpoly(lx, C0)
  c2 = evalpoly(lx, C2)
  c4 = evalpoly(lx, C4)
  evalpoly(v, (c0,0.0,c2,0.0,c4))/2
end
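One way to sanity-check the expansion near v = 0 (a sketch: `f0_exact` is my own helper name, written directly from the formula in the comment above, and it is only valid away from v == 0 where the gamma calls are finite):

```julia
using SpecialFunctions: gamma

# Exact f_0 from the comment above, divided by 2 to match the code's convention.
f0_exact(v, x) = (gamma(-v) * x^v + gamma(v) * x^(-v)) / 2

# For small nonzero v the local expansion and the exact form should agree
# to many digits.
for v in (1e-4, 1e-5), x in (0.3, 0.7)
    println((v, x), "  ", f0_local_expansion_v0(v, x), "  ", f0_exact(v, x))
end
```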

# Strategy: because of the order recurrence, we only need a code that gives
# (K_{v}(x), K_{v+1}(x)) for v in (-1/2,1/2]. Then we can use a simple forward
# recursion.
#
# This is slower than just using the straight power series, BUT the standard
# power series is not valid for integer orders, and can be numerically unstable
# for very near-integer orders. The point of this domain reduction in the order
# is so that we can isolate the problem of the power series' singularity when
# the order approaches an integer to a single coefficient.
function besselk_powerseries_temme_basal(v, x)
  @assert -0.5 < v <= 0.5 "This internal routine is only for the base-sized order v."
  # pre-computations:
  # TODO (cg 2023/05/14 18:47): organize all of this more thoughtfully.
  z  = x/2
  zz = z*z
  # special branch for the initial f_0 coefficient when iszero(v):
  if abs(v) < 1e-5
    fk = f0_local_expansion_v0(v, x/2)
  else
    p1 =  gamma(-v)*(z^v)
    p2 = -gamma(v)*(z^(-v))
    fk = (p1 - p2)/2
  end
  # TODO (cg 2023/05/14 18:05): the gamma(1+v) and gamma(1-v) in some sense
  # should be computed inside the above if/else block.
  (pk, qk, _ck, factk) = ((z^(-v))*gamma(1+v)/2, (z^v)*gamma(1-v)/2, 1.0, 1.0)
  (out_v, out_vp1) = (0.0, 0.0)
  # TODO (cg 2023/05/14 18:45): Should this be compile-time unrolled like our
  # other newer implementations?
  max_iter = 50
  for k in 1:max_iter
    # Add to the series. 
    ck       = _ck/factk
    term_v   = ck*fk
    term_vp1 = ck*(pk - (k-1)*fk)
    out_v   += term_v
    out_vp1 += term_vp1
    # TODO (cg 2023/05/14 18:47): take convergence check seriously.
    ((abs(term_v) < eps()) && (abs(term_vp1) < eps())) && break
    # Increment/update each term:
    # TODO (cg 2023/05/14 18:47): Ask @heltonmc to turn this into one muladd
    # somehow
    fk     = (k*fk + pk + qk)/(k^2 - v^2)
    pk    /= (k-v)
    qk    /= (k+v)
    _ck   *= zz
    factk *= k
  end
  (out_v, out_vp1/z)
end

# Stable and differentiable for integer orders.
function besselk_powerseries_temme(v, x)
  v < zero(v) && return besselk_powerseries_temme(-v, x)
  flv = Int(floor(v))
  _v  = v - flv
  # get (K_{_v}(x), K_{_v+1}(x)):
  (kv, kvp1) = besselk_powerseries_temme_basal(_v, x)
  abs(v) < 1/2 && return kv
  # do the forward recurrence:
  twodx = 2/x
  # TODO (cg 2023/05/14 18:47): Ask @heltonmc to turn this into one muladd
  # somehow, or use the existing forward rec function that I don't understand.
  for _ in 1:(flv-1)
    _v += 1
    (kv, kvp1) = (kvp1, muladd(twodx*_v, kvp1, kv))
  end
  kvp1
end
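A quick check of the order derivative at an integer order, along the lines of the ForwardDiff check mentioned below (a sketch: assumes ForwardDiff and FiniteDifferences are available, and the values/derivatives are just printed for comparison rather than asserted):

```julia
using ForwardDiff, FiniteDifferences

v, x = 2.0, 0.8
ad = ForwardDiff.derivative(_v -> besselk_powerseries_temme(_v, x), v)
fd = central_fdm(10, 1)(_v -> besselk_powerseries_temme(_v, x), v)
println("AD: ", ad, "   FD reference: ", fd)
```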

As you will see and find annoying, I haven't done the thing you guys do of turning everything into a single muladd somehow. So that is on the todo list, along with some general re-organization and refinement. But even as-is, this is not painfully slow. And at least with ForwardDiff (the only one I have checked so far), it gives correct derivatives at integer orders.

Would love to hear thoughts! If you think fine modulo some cleanup, happy to create a PR.

@heltonmc
Member

heltonmc commented May 16, 2023

This is excellent and really the only thing holding back the modified Bessel function implementations from being differentiable. My opinion is that if this is accurate and reasonably fast, let's get it merged into the code base. I have a branch that I have been waiting on this to merge into to complete besselk.

So for me the structure for this power series would look like a combination of the version I posted above with a "near_int" check at the top level. So something like...

function besselk_power_series(v, x::ComplexOrReal{T}) where T
    isnearint(v) && return besselk_power_series_int(v, x)
    MaxIter = 5000
    gam = gamma(v)
    ngam = π / (sinpi(-abs(v)) * gam * v)

    s1, s2 = zero(T), zero(T)
    t1, t2 = one(T), one(T)

    for k in 1:MaxIter
        s1 += t1
        s2 += t2
        t1 *= x^2 / (4k * (k - v))
        t2 *= x^2 / (4k * (k + v))
        abs(t1) < eps(T) && break
    end

    xpv = (x/2)^v
    s = gam * s1 + xpv^2 * ngam * s2
    return s / (2*xpv)
end
function besselk_power_series_int(v, x)
# your function here
end
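`isnearint` above is a placeholder; one possible definition, with a hypothetical cutoff of 1e-3 (the real threshold needs to be measured, per the edits below):

```julia
# Hypothetical helper: true when v is within tol of an integer, so we can
# dispatch to the integer-safe series. The 1e-3 default is a guess.
isnearint(v, tol=1e-3) = abs(v - round(v)) < tol
```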

I have not really gone through the implementation on this, but again I think optimization and tricks can come at the end. Let's check for accuracy, merge, and go from there. My only initial thing was that in f0_local_expansion_v0 it'll probably be best to just store the Float64 values, and then we can SIMD that better. Also, the last line in that function should just be evalpoly(v*v, (c0, c2, c4))/2, because it's an even series. That will be slightly faster and cleaner.

But this is awesome 😄 I'm excited to get this merged and clean up the whole function.

Edit: So what we need to determine is how close to an integer does the simpler version present numerical cancellation. Probably like within 0.001 of an Int need to dispatch to this method.

Edit2: I think unrolling this at compile time will depend. If the number of terms is pretty similar for the range of inputs that we are considering then we could consider unrolling but if there is a wide variability then maybe it's best to have the convergence check.

Edit3: Ya, I think the things to improve before merging would be to get this down to one gamma call, and also to reconsider the powers in (pk, qk, _ck, factk) = ((z^(-v))*gamma(1+v)/2, (z^v)*gamma(1-v)/2, 1.0, 1.0). It's probably not necessary to compute both z^(-v) and z^v; that line will be very expensive. I'm assuming the compiler is not making that simplification because the results probably won't be strictly equal, but we should do this.

@cgeoga
Contributor Author

cgeoga commented May 16, 2023

That all sounds good to me. Considering that it is related, how would you feel about me adding it to the EnzymeRules PR? Or would you prefer it to be separate?

Also, will make those improvements you mention in the edit. I can certainly do that much before making the PR, even if I can't go full muladd like you guys can.

@heltonmc
Member

Go for it! Let's get it merged and we can go from there, so I can work off that branch. We should then have a fully differentiable besselk, at least for real arguments.

Sounds good! I think those are the biggest things that stand out. I am literally unloading all my stuff from the move right now, so I won't have much time right now to comb through the fine details, but those I think are the biggest things. Again, this method will only be for x < 1.5, nu near an integer, and nu < 25. Important nonetheless, but we can further optimize the finer details. Is there a paper that we can link to as well, or is some of it your own personal modifications? :)

@cgeoga
Contributor Author

cgeoga commented May 16, 2023

Awesome! Just pushed up a first commit that brings things in. A couple new Float32 test failures and I still need to write the Enzyme tests for that new method, but I'll do those things tomorrow.

With regard to a paper, unfortunately I didn't really need to do anything new or cool. It is just the initial term from the power series in my paper that is already very generously mentioned (but simplified a lot, as you have also mentioned), and then the rest of the coefficients/structure is right out of the Temme paper. So I would say it is really just the Temme method with a better-designed $f_0(v,x)$ coefficient.

@heltonmc
Member

I've merged #101, which should be the first step toward besselk compatibility with AD. Undoubtedly there will be inaccuracies, and we will have to track down where they are coming from, whether through cutoffs etc. But that should make the general besselk routine differentiable at any value.

@cgeoga
Contributor Author

cgeoga commented May 24, 2023

Amazing! Will mark this issue as done and we can chase down small inaccuracies in subsequent issues/PRs.

@cgeoga cgeoga closed this as completed May 24, 2023