add Kaiming He initialization, fixed Xavier initialization #311
base: master
Conversation
src/distributions.jl
Outdated
@@ -37,7 +37,7 @@ function xavier(a...)
         fanout = size(w, ndims(w))
         fanin = div(length(w), fanout)
     end
-    s = convert(eltype(w), sqrt(2 / (fanin + fanout)))
+    s = convert(eltype(w), sqrt(6 / (fanin + fanout)))
I think our version is specialized for conv layers with ReLU activation. The part you changed is called the gain. You may want to update your PR so that the xavier function accepts the gain as a parameter, with a default value of 6.
To be honest, I barely know the theoretical background. I guess you are referring to the "Delving Deep into Rectifiers" paper when you say it is specialized for conv layers with ReLU activation. The paper states that n_l * var(w_l) = 2 should hold, where n_l is the average number of units per layer. You can check that for x = xavier(200, 300), (200 + 300) / 2 * var(x) ≈ 0.33, whereas this value should be 1.0 for Xavier and 2.0 for ReLU activation. I also compared xavier with TensorFlow's equivalent initializer: the variance of TF's Xavier is consistently ~3 times that of xavier, and TF's Kaiming (ReLU-specialized Xavier) is ~6 times that of xavier.
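For reference, the check above can be reproduced with something like the following (a sketch; it assumes Knet's xavier is in scope and uses Julia's Statistics standard library):

    using Knet, Statistics        # assumes Knet exports xavier

    x = xavier(200, 300)          # 200×300 weight matrix
    n_avg = (200 + 300) / 2       # average number of units per layer
    n_avg * var(x)                # ≈ 0.33 with the sqrt(2/(fanin+fanout)) scale;
                                  # should be ≈ 1.0 for Xavier, ≈ 2.0 for ReLU (Kaiming)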
As for your suggestion: I am very new to Julia and couldn't find a way to change the arguments while staying compatible with pre-existing models. However, there could be another distribution that takes both gain and n as arguments (as in TF).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use keyword arguments for options.
xavier(a...; gain = 6)
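A minimal sketch of such a signature (hypothetical; the real xavier in src/distributions.jl also handles 1-D and conv-shaped weights):

    # Hypothetical sketch, not the actual Knet code: xavier with a gain keyword.
    # gain = 6 gives the uniform Xavier bound sqrt(6 / (fanin + fanout)).
    function xavier(a...; gain = 6)
        w = rand(a...)
        fanout = size(w, ndims(w))
        fanin = div(length(w), fanout)
        s = convert(eltype(w), sqrt(gain / (fanin + fanout)))
        return 2s .* w .- s       # uniform in (-s, s)
    end

Calling xavier(200, 300; gain = 2) would then reproduce the old sqrt(2 / (fanin + fanout)) scale, while the default gain = 6 gives the corrected Xavier bound.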
I changed the default Kaiming gain value to sqrt(2) (that is, the gain for ReLU activation units), since this is how it is done in the original description. If we also want to take other activation functions into account, we can change the gain accordingly. This way the Xavier and Kaiming initializations give the same variance as in PyTorch.
@ozanarkancan do you think there is a problem with this PR?
@ekinakyurek @denizyuret The branch can be merged; however, changing the initialization method will possibly break the replicability of experiments that use the current implementation. This should be stated somewhere...
Is there a way to do this in a backward-compatible manner? For example, by keeping the default arguments consistent with the prior implementation.
Xavier initialization is x ~ U(-sqrt(6.0 / (fan_in + fan_out)), +sqrt(6.0 / (fan_in + fan_out))),
or x ~ N(mean = 0, std = sqrt(2.0 / (fan_in + fan_out))).
Kaiming initialization is x ~ U(-sqrt(3.0 / fan_in), +sqrt(3.0 / fan_in)),
or x ~ N(mean = 0, std = sqrt(1.0 / fan_in)).
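For illustration, these formulas translate directly into Julia (a sketch only; the function names below are not from this PR):

    # Illustrative sketches of the four formulas above (gain = 1 for the Kaiming pair;
    # multiply the bound/std by a gain such as sqrt(2) for the ReLU-specialized version).

    xavier_uniform(fan_in, fan_out, dims...) =
        (2 .* rand(dims...) .- 1) .* sqrt(6.0 / (fan_in + fan_out))

    xavier_normal(fan_in, fan_out, dims...) =
        randn(dims...) .* sqrt(2.0 / (fan_in + fan_out))

    kaiming_uniform(fan_in, dims...) =
        (2 .* rand(dims...) .- 1) .* sqrt(3.0 / fan_in)

    kaiming_normal(fan_in, dims...) =
        randn(dims...) .* sqrt(1.0 / fan_in)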