Add easy way to use Mooncake.jl for gradients #26
Conversation
Codecov Report

```
@@            Coverage Diff             @@
##           master      #26      +/-   ##
==========================================
- Coverage   79.92%   73.55%   -6.37%
==========================================
  Files           7        9       +2
  Lines         269      329      +60
==========================================
+ Hits          215      242      +27
- Misses         54       87      +33
```
This is cool. Thanks for engaging with Mooncake on this!
In answer to your questions:
- I do not hate this at all. It's not like it's lots of code duplication, and you having ownership over the type seems very reasonable to me.
- Sadly, the correct thing to do for now is what DI currently does: simply ignore the terms you do not want. I hope this will change at some point, but I have no idea when that will happen.
edit: also, you're correct, `CoDual(x, NoTangent())` won't work in general. Mooncake's semantics are: if you see a `CoDual(x, dx)`, it must be true that either
1. `typeof(dx) == tangent_type(typeof(x))`, or
2. `typeof(dx) == fdata_type(tangent_type(typeof(x)))`.

Anything else is incorrect, and will cause you problems. Which one of these you should find depends on context. In all of the functionality that you are using, it should always be 1.
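As a quick sketch of what case 1 means in practice (assuming Mooncake's `tangent_type`, `zero_tangent` and `CoDual`, which are used elsewhere in this thread):

```julia
using Mooncake

x = [1.0, 2.0]
dx = Mooncake.zero_tangent(x)    # tangent with the canonical type for x
@assert typeof(dx) == Mooncake.tangent_type(typeof(x))
cd = Mooncake.CoDual(x, dx)      # satisfies case 1 above

# By contrast, Mooncake.CoDual(x, Mooncake.NoTangent()) would pair an Array
# with the wrong tangent type, violating the invariant.
```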
ext/FluxMooncakeExt.jl (outdated)

```julia
_moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(_moonstrip, dx)
_moonstrip(dx::AbstractArray{Mooncake.NoTangent}) = nothing
_moonstrip(dx::AbstractArray{<:Number}) = dx
_moonstrip(dx::AbstractArray{<:Integer}) = nothing
```
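For illustration, here is a self-contained sketch of how this kind of dispatch flattens a gradient; the `NoTangent` struct and the `moonstrip` name are stand-ins for this sketch, not the extension's actual definitions:

```julia
struct NoTangent end  # stand-in for Mooncake.NoTangent in this sketch

moonstrip(dx::Union{Tuple, NamedTuple, AbstractArray}) = map(moonstrip, dx)
moonstrip(dx::AbstractArray{NoTangent}) = nothing
moonstrip(dx::AbstractArray{<:Number}) = dx
moonstrip(dx::AbstractArray{<:Integer}) = nothing

grad = (weight = [0.1, 0.2], bias = [NoTangent(), NoTangent()], step = [1, 2])
moonstrip(grad)  # (weight = [0.1, 0.2], bias = nothing, step = nothing)
```

Method specificity does the work: `AbstractArray{NoTangent}` and `AbstractArray{<:Integer}` beat the generic `Union` method, so non-differentiable leaves become `nothing` while float arrays pass through unchanged.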
To the best of my knowledge, this shouldn't ever be a case that you see.
Ok, I didn't try very hard! Do you know if there are other types which can occur, besides those handled here?
Do you place any constraints on the types of the things that you take gradients w.r.t.?
I ask because there's nothing to stop people adding a new type, and declaring a non-standard tangent type for it (i.e. while some subtype of `Tangent` is the thing returned by `tangent_type` by default for `struct`s, there's nothing to stop people making `tangent_type` return something else for a type that they own). So in principle you could see literally anything. In practice, assuming that you're working with `Array`s, and `struct`s / `mutable struct`s / `Tuple`s / `NamedTuple`s of `Array`s, I think you should be fine.
My honest advice would be to do what you're doing at the minute. i.e. check that it works for models that you care about, and ensure that there's a good error message so that a user knows where to ask for help if they encounter something you weren't expecting.
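One way to follow that advice is a catch-all method that fails loudly; a sketch only, with an illustrative method name and message:

```julia
# Fallback for tangent types the extension does not recognise.
moonstrip(dx) = error(
    "Unexpected tangent type $(typeof(dx)) received from Mooncake. " *
    "Please open an issue so this case can be handled.")
```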
Ok, sounds good!
No constraints, but in reality it's going to be arrays & all kinds of structs.
```julia
    zero && Mooncake.set_to_zero!!(x.dval)
end
coduals = map(x -> Mooncake.CoDual(x.val, x.dval), args)
val, _ = Mooncake.__value_and_gradient!!(rule, Mooncake.zero_codual(f), coduals...)
```
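For reference, the public counterpart of the double-underscore internal call above looks roughly like this, assuming Mooncake's documented `build_rrule` / `value_and_gradient!!` interface:

```julia
using Mooncake

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]

# Build a reverse rule once, then evaluate value and gradient together.
rule = Mooncake.build_rrule(f, x)
val, (df, dx) = Mooncake.value_and_gradient!!(rule, f, x)
# val is f(x); dx is the gradient w.r.t. x, and df the gradient w.r.t. f itself.
```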
Is there any chance at all that `f` will contain trainable parameters, or does Flux insist that you not do that?
Yeah, `Flux.gradient(f, x, y)` has length 2, à la Zygote. I agree that's not the maximally flexible thing, and very occasionally you end up with `gradient(|>, x, f)`... but in real use it seems like never.
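i.e., with the existing Zygote backend, something like:

```julia
using Flux

m = Dense(2 => 1)
x = rand(Float32, 2)

# One gradient entry per explicit argument; f itself gets no slot.
gs = Flux.gradient((m, x) -> sum(m(x)), m, x)
length(gs)  # 2
```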
Cool
Thanks for having a look! Ok, it is of course easy to hide undesired inputs; that's all that's possible with Zygote too. I could check, but do you know whether …
It's always computed, I'm afraid, since …
This is parallel to FluxML/Flux.jl#2471 for Enzyme. See compintell/Mooncake.jl#361 for some benchmarks (without using this).

Questions for @willtebbutt are:
- `CoDual(x, NoTangent())` does not seem to work.

Re the first point, rather than having one struct per package, another option might be to have one universal thing: