
Add easy way to use Mooncake.jl for gradients #26

Merged: 7 commits into master from mooncake, Nov 26, 2024

Conversation

@mcabbott (Member) commented Nov 22, 2024

This is parallel to FluxML/Flux.jl#2471 for Enzyme. See compintell/Mooncake.jl#361 for some benchmarks (without using this).

Questions for @willtebbutt are:

  • Do you hate this? Should Mooncake own the struct? It's not far from CoDual, but (1) we want a one-argument method, (2) Flux wants nicer printing, and (3) CoDual is marked private & discouraged.
  • How do I implement Const, to take gradient with respect to only some arguments? DI seems to just compute them all & discard. CoDual(x, NoTangent()) does not seem to work.

Re the first point, rather than having one struct per package, another option might be to have one universal thing:

GradWrap(Enzyme, model)    # returns something that behaves like Duplicated(model)
GradWrap(Mooncake, model)  # instead of Moonduo
GradWrap(Reactant, model)  # for Reactant.jl, would store compiled functions not just gradients, mutable

Flux.gradient(loss, ::GradWrap, ...)  # uses selected package
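
For concreteness, one rough sketch of what such a universal wrapper could look like (hypothetical code, not an existing API; the Symbol tag and field names are illustrative):

struct GradWrap{pkg, M, G}     # pkg is a Symbol tag such as :Mooncake or :Enzyme
    val::M                     # the model itself
    dval::G                    # shadow storage for gradients, like Duplicated's dval
end
GradWrap(pkg::Module, model) = GradWrap{nameof(pkg), typeof(model), typeof(model)}(model, deepcopy(model))
# deepcopy is just a placeholder: a real implementation would zero dval,
# e.g. the way this PR calls Mooncake.set_to_zero!! on Moonduo's dval.

# Flux.gradient could then dispatch on the tag:
# Flux.gradient(loss, m::GradWrap{:Mooncake}, xs...)   # forwards to Mooncake
# Flux.gradient(loss, m::GradWrap{:Enzyme}, xs...)     # wraps as Enzyme.Duplicated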

codecov bot commented Nov 23, 2024

Codecov Report

Attention: Patch coverage is 45.00000% with 33 lines in your changes missing coverage. Please review.

Project coverage is 73.55%. Comparing base (21693da) to head (139ff3e).
Report is 1 commit behind head on master.

Files with missing lines   Patch %   Lines
ext/FluxMooncakeExt.jl     51.02%    24 Missing ⚠️
src/mooncake.jl            18.18%     9 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master      #26      +/-   ##
==========================================
- Coverage   79.92%   73.55%   -6.37%     
==========================================
  Files           7        9       +2     
  Lines         269      329      +60     
==========================================
+ Hits          215      242      +27     
- Misses         54       87      +33     


@willtebbutt (Member) left a comment:

This is cool. Thanks for engaging with Mooncake on this!

In answer to your questions:

  1. I do not hate this at all. It's not like it's lots of code duplication, and you having ownership over the type seems very reasonable to me.
  2. Sadly, the correct thing to do at the moment is what DI does: simply ignore the terms you do not want. I hope this will change at some point, but I have no idea when that will happen.

edit: also, you're correct, CoDual(x, NoTangent()) won't work in general. Mooncake's semantics are: if you see a CoDual(x, dx), it must be true that either

  1. typeof(dx) == tangent_type(typeof(x)), or
  2. typeof(dx) == fdata_type(tangent_type(typeof(x))).

Anything else is incorrect, and will cause you problems. Which of these you will encounter depends on context. In all of the functionality that you are using, it should always be 1.
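
To make rule 1 concrete, a tiny sketch using Mooncake's tangent_type / zero_tangent (arrays are their own tangent type in Mooncake):

using Mooncake
x = randn(3)
dx = Mooncake.zero_tangent(x)    # a Vector{Float64}, matching x
@assert typeof(dx) == Mooncake.tangent_type(typeof(x))
cd = Mooncake.CoDual(x, dx)      # valid: dx satisfies rule 1
# CoDual(x, Mooncake.NoTangent()) breaks the invariant, since NoTangent is not
# tangent_type(Vector{Float64}) -- which is why it doesn't work in general.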

_moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx)  # recurse into containers
_moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing                    # no gradient information to keep
_moonstrip(dx::AbstractArray{<:Number}) = dx                                   # plain numeric gradient, keep as-is
_moonstrip(dx::AbstractArray{<:Integer}) = nothing                             # integer arrays are non-differentiable (more specific, so beats the Number method)
@willtebbutt (Member):

To the best of my knowledge, this shouldn't ever be a case that you see.

@mcabbott (Member, Author):

Ok, I didn't try very hard! Do you know if there are other types which can occur, besides those handled here?

@willtebbutt (Member), Nov 25, 2024:

Do you place any constraints on the types of the things that you take gradients w.r.t.?

I ask because there's nothing to stop people adding a new type and declaring a non-standard tangent type for it: while tangent_type returns some subtype of Tangent by default for structs, there's nothing to stop someone making it return something else for a type they own. So in principle you could see literally anything. In practice, assuming you're working with Arrays, and structs / mutable structs / Tuples / NamedTuples of Arrays, I think you should be fine.

My honest advice would be to do what you're doing at the minute: check that it works for the models you care about, and ensure there's a good error message so that a user knows where to ask for help if they encounter something you weren't expecting.
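
For illustration, the kind of user-side override being described might look like this (hypothetical type; a full integration needs more methods than this, the point is only the dispatch hook):

using Mooncake
struct Angle            # a user-owned type
    θ::Float64
end
Mooncake.tangent_type(::Type{Angle}) = Float64   # non-standard: a raw Float64 tangent
# Generic gradient-handling code (like _moonstrip above) would then meet a bare
# Float64 where it expected a Tangent struct, hence the advice to fail loudly.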

@mcabbott (Member, Author):

Ok, sounds good!

No constraints, but in reality it's going to be arrays & all kinds of structs.

    zero && Mooncake.set_to_zero!!(x.dval)   # if requested, reset each stored gradient in place
end
coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args)   # pair each value with its tangent
val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...)   # f itself gets a zero tangent
@willtebbutt (Member):

Is there any chance at all that f will contain trainable parameters, or does Flux insist that you not do that?

@mcabbott (Member, Author):

Yes, Flux.gradient(f, x, y) has length 2, à la Zygote. I agree that's not the maximally flexible thing, and very occasionally you end up with gradient(|>, x, f)... but in real use that seems to never happen.
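
What "length 2" means here, in Flux's standard explicit-gradient API:

using Flux
loss(m, x) = sum(abs2, m(x))
m = Dense(2 => 3)
x = randn(Float32, 2, 5)
∇m, ∇x = Flux.gradient(loss, m, x)   # one gradient per explicit argument, no slot for loss itself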

@willtebbutt (Member):

Cool

@mcabbott (Member, Author):

Thanks for having a look!

Ok, it is of course easy to hide undesired inputs; that's all that's possible with Zygote too.

I could check, but do you know whether gradient(x -> f(x, y), x) always does the work to compute dy, or does it notice that y is global & propagate its constantness along? Zygote always computes it. If Mooncake does not, then we (and DI) could transform gradient(f, CoDual(x, ...), Const(y)) into that.

@willtebbutt (Member) commented Nov 25, 2024:

It's always computed I'm afraid -- since x -> f(x, y) is just a closure, and therefore has a y field, Mooncake treats it like any other struct, and computes the gradient w.r.t. the parameters it closes over.
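
A small example of this behaviour, assuming Mooncake's documented build_rrule / value_and_gradient!! interface (note the let block, so that y is captured as a field rather than referenced as a global):

using Mooncake
g = let y = randn(3)
    x -> sum(abs2, x .- y)    # typeof(g) is a closure struct with a field `y`
end
x = randn(3)
rule = Mooncake.build_rrule(g, x)
val, (dg, dx) = Mooncake.value_and_gradient!!(rule, g, x)
# dg is the tangent of g itself, including the work done for the captured y,
# even if the caller only wanted dx.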

@mcabbott mcabbott marked this pull request as ready for review November 26, 2024 03:56
@mcabbott mcabbott merged commit 3b70dd0 into master Nov 26, 2024
2 of 3 checks passed
@mcabbott mcabbott deleted the mooncake branch November 26, 2024 04:02