Skip to content

Commit

Permalink
Fix constant, but aliasing, return (#1080)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Sep 25, 2023
1 parent 3c7f8c7 commit 96d5efb
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion lib/EnzymeTestUtils/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeTestUtils"
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
authors = ["Seth Axen <seth@sethaxen.com>", "William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.1.1"
version = "0.1.2"

[deps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down
16 changes: 9 additions & 7 deletions lib/EnzymeTestUtils/src/finite_difference_calls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end
_fd_forward(fdm, f, ::Type{<:Const}, y, activities) = ()

#=
_fd_reverse(fdm, f, ȳ, activities)
_fd_reverse(fdm, f, ȳ, activities, active_return)
Call `FiniteDifferences.j′vp` on `f` with the arguments `xs` determined by `activities`.
Expand All @@ -51,14 +51,14 @@ Call `FiniteDifferences.j′vp` on `f` with the arguments `xs` determined by `ac
- `f`: The function to differentiate.
- `ȳ`: The cotangent of the primal output `y=f(xs...)`.
- `activities`: activities that would be passed to `Enzyme.autodiff`
- `active_return`: whether the return is non-constant
# Returns
- `x̄s`: Derivatives of output `s` w.r.t. `xs` estimated by finite differencing.
=#
function _fd_reverse(fdm, f, ȳ, activities)
function _fd_reverse(fdm, f, ȳ, activities, active_return)
xs = map(x -> x.val, activities)
ignores = map(a -> a isa Const, activities)
f2 = _wrap_reverse_function(f, xs, ignores)
f2 = _wrap_reverse_function(active_return, f, xs, ignores)
all(ignores) && return map(zero_tangent, xs)
ignores = collect(ignores)
is_batch = _any_batch_duplicated(map(typeof, activities)...)
Expand Down Expand Up @@ -137,7 +137,7 @@ All arguments are copied before being passed to `f`, so that `fnew` is non-mutat
- `ignores`: Collection of `Bool`s, the same length as `xs`.
If `ignores[i] === true`, then `xs[i]` is ignored; `∂xs[i] === NoTangent()`.
=#
function _wrap_reverse_function(f, xs, ignores)
function _wrap_reverse_function(active_return, f, xs, ignores)
function fnew(sigargs...)
callargs = Any[]
retargs = Any[]
Expand Down Expand Up @@ -165,8 +165,10 @@ function _wrap_reverse_function(f, xs, ignores)

# we will now explicitly zero all objects returned, and replace any of the args with this
# zero, if the input and output alias.
for k in keys(zeros)
zeros[k] = zero_tangent(k)
if active_return
for k in keys(zeros)
zeros[k] = zero_tangent(k)
end
end

return (origRet, Base.deepcopy_internal(retargs, zeros)...)
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeTestUtils/src/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function test_reverse(
end
end
# call finitedifferences, avoid mutating original arguments
dx_fdm = _fd_reverse(fdm, call_with_kwargs, ȳ, activities)
dx_fdm = _fd_reverse(fdm, call_with_kwargs, ȳ, activities, !(ret_activity <: Const))
# call autodiff, allow mutating original arguments
c_act = Const(call_with_kwargs)
forward, reverse = autodiff_thunk(
Expand Down

2 comments on commit 96d5efb

@wsmoses
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir="lib/EnzymeTestUtils"

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/92140

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a EnzymeTestUtils-v0.1.2 -m "<description of version>" 96d5efb395e2a5a407d6e9e2eec95cbc819d7a6f
git push origin EnzymeTestUtils-v0.1.2

Please sign in to comment.