Commit

add forward, reverse-mode AD operators, enable tests for reverse-mode (#185)

- adds a dynamic variable `emmy.calculus.derivative/*mode*` that allows the
  user to switch between forward and reverse mode automatic differentiation

- adds a new `emmy.calculus.derivative/gradient` that acts like
  `emmy.tape/gradient` but is capable of taking multiple variables

- adds new operators `emmy.calculus.derivative/{D-forward, D-reverse}` and
  operator-returning-functions `emmy.calculus.derivative/{partial-forward,
  partial-reverse}` that allow the user to explicitly invoke forward-mode or
  reverse-mode automatic differentiation. `D` and `partial` still default to
  forward-mode

- modifies `emmy.tape/gradient` to correctly error when passed invalid
  selectors, just like `emmy.dual/derivative`.
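
The mode switch described above can be sketched in plain Clojure. This is a minimal illustration and not Emmy's implementation: `derivative-fn`, the tagged return vectors, and the `:forward`/`:reverse` keywords are hypothetical stand-ins for the real forward- and reverse-mode code paths and mode constants. The point it tries to show is that the mode is read while the operator is applied to `f`, so the choice is captured before the returned function is ever called.

```clojure
;; Hypothetical stand-in for the dynamic mode variable; it defaults to
;; forward mode.
(def ^:dynamic *mode* :forward)

(defn derivative-fn
  "Picks an implementation based on the value of `*mode*` at the time this
  function is called. The returned closures only tag their result; they do
  NOT compute derivatives (assumption made for brevity)."
  [f]
  (if (= *mode* :forward)
    (fn [x] [:forward-mode (f x)])
    (fn [x] [:reverse-mode (f x)])))

(defn D-forward
  "Pins forward mode while the implementation is selected."
  [f]
  (binding [*mode* :forward]
    (derivative-fn f)))

(defn D-reverse
  "Pins reverse mode while the implementation is selected."
  [f]
  (binding [*mode* :reverse]
    (derivative-fn f)))

;; The default operator is simply an alias for the forward-mode one.
(def D D-forward)

((D inc) 1)          ;; => [:forward-mode 2]
((D-reverse inc) 1)  ;; => [:reverse-mode 2]

;; The explicit operators override any outer binding ...
(binding [*mode* :reverse] ((D inc) 1))              ;; => [:forward-mode 2]
;; ... while a direct call to the dispatching function honors it.
(binding [*mode* :reverse] ((derivative-fn inc) 1))  ;; => [:reverse-mode 2]
```

As the diff reads, `D` and `partial` are aliases for `D-forward` and `partial-forward`, so rebinding `*mode*` would presumably affect only direct calls into the generic partial-derivative dispatch, which is what the last two lines of the sketch mimic.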
sritchie authored Aug 13, 2024
1 parent d4c4c07 commit 622ba21
Showing 4 changed files with 192 additions and 54 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,23 @@

## [unreleased]

- #185:

- adds a dynamic variable `emmy.calculus.derivative/*mode*` that allows the
user to switch between forward and reverse mode automatic differentiation

- adds a new `emmy.calculus.derivative/gradient` that acts like
`emmy.tape/gradient` but is capable of taking multiple variables

- adds new operators `emmy.calculus.derivative/{D-forward, D-reverse}` and
operator-returning-functions `emmy.calculus.derivative/{partial-forward,
partial-reverse}` that allow the user to explicitly invoke forward-mode or
reverse-mode automatic differentiation. `D` and `partial` still default to
forward-mode

- modifies `emmy.tape/gradient` to correctly error when passed invalid
selectors, just like `emmy.dual/derivative`.

- #183:

- adds `emmy.{autodiff, tape}` to `emmy.sci`'s exported namespace set
165 changes: 132 additions & 33 deletions src/emmy/calculus/derivative.cljc
@@ -22,9 +22,8 @@

;; ## Single and Multivariable Calculus
;;
;; These functions put together the pieces laid out
;; in [[emmy.dual]] and declare an interface for taking
;; derivatives.
;; These functions put together the pieces laid out in [[emmy.dual]] and declare
;; an interface for taking derivatives.

;; The result of applying the derivative `(D f)` of a multivariable function `f`
;; to a sequence of `args` is a structure of the same shape as `args` with all
@@ -37,7 +36,8 @@
;; To generate the result:
;;
;; - For a single non-structural argument, return `(d/derivative f)`
;; - else, bundle up all arguments into a single [[s/Structure]] instance `xs`
;; - else, bundle up all arguments into a single [[emmy.structure/Structure]]
;; instance `xs`
;; - Generate `xs'` by replacing each entry in `xs` with `((d/derivative f')
;; entry)`, where `f'` is a function of ONLY that entry that
;; calls `(f (assoc-in xs path entry))`. In other words, replace each entry
@@ -49,7 +49,7 @@
;; above.
;;
;; [[jacobian]] handles this main logic. [[jacobian]] can only take a structural
;; input. [[euclidean]] and [[multivariate]] below widen handle, respectively,
;; input. [[euclidean]] and [[multivariate]] below handle, respectively,
;; optionally-structural and multivariable arguments.

(defn- deep-partial
@@ -109,12 +109,12 @@
(u/illegal (str "Bad selectors " selectors " for structure " input))))))

(defn- euclidean
"Slightly more general version of [[jacobian]] that can handle a single
non-structural input; dispatches to either [[jacobian]] or [[derivative]]
depending on the input type.
"Slightly more general version of [[jacobian]] that can handle a single input;
dispatches to either [[jacobian]] or [[derivative]] depending on whether or
not the input is structural.
If you pass non-empty `selectors`, the returned function will throw if it
receives a non-structural, non-numerical argument."
receives a non-structural, non-scalar argument."
([f] (euclidean f []))
([f selectors]
(let [selectors (vec selectors)]
@@ -143,6 +143,28 @@
(str "Selectors " selectors
" not allowed for non-structural input " input)))))))

(defn- multi
"Given
- some higher-order function `op` that transforms a function of a single
variable into another function of a single variable
- function `f` capable of taking multiple arguments
returns a new function that acts like `(op f)` but can take multiple
arguments.
When passed multiple arguments, the returned function packages them into a
single `[[emmy.structure/up]]` instance. Any [[emmy.matrix/Matrix]] present in
the argument list will be converted into a `down` of `up`s (a row of columns)."
[op f]
(-> (fn
([] 0)
([x] ((op f) x))
([x & more]
((multi op (fn [xs] (apply f xs)))
(matrix/seq-> (cons x more)))))
(f/with-arity (f/arity f) {:from ::multi})))

(defn- multivariate
"Slightly wider version of [[euclidean]]. Accepts:
@@ -152,24 +174,39 @@
And returns a new function that computes either the
full [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
or the entry at `selectors`.
or the entry at `selectors` using [forward-mode automatic
differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Forward_accumulation).
Any multivariable function will have its argument vector coerced into an `up`
structure. Any [[matrix/Matrix]] in a multiple-arg function call will be
structure. Any [[emmy.matrix/Matrix]] in a multiple-arg function call will be
converted into a `down` of `up`s (a row of columns).
Single-argument functions don't transform their arguments."
Arguments to single-variable functions are not transformed."
([f] (multivariate f []))
([f selectors]
(let [d #(euclidean % selectors)
df (d f)
df* (d (fn [args] (apply f args)))]
(-> (fn
([] 0)
([x] (df x))
([x & more]
(df* (matrix/seq-> (cons x more)))))
(f/with-arity (f/arity f) {:from ::multivariate})))))
(let [d #(euclidean % selectors)]
(multi d f))))

(defn gradient
"Accepts:
- some function `f` of potentially many arguments
- optionally, a sequence of selectors meant to index into the structural
argument, or argument vector, of `f`
And returns a new function that computes either the
full [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
or the entry at `selectors` using [reverse-mode automatic
differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation).
Any multivariable function will have its argument vector coerced into an `up`
structure. Any [[emmy.matrix/Matrix]] in a multiple-arg function call will be
converted into a `down` of `up`s (a row of columns).
Arguments to single-variable functions are not transformed."
([f] (gradient f []))
([f selectors]
(multi #(tape/gradient % selectors) f)))

;; ## Generic [[g/partial-derivative]] Installation
;;
@@ -192,24 +229,66 @@
;; passed to the structure of functions, instead of separately for every entry
;; in the structure.
;;
;; A dynamic variable controls whether or not this process uses forward-mode or
;; reverse-mode AD.
;;
;; TODO: I think this is going to cause problems for, say, a Structure of
;; PowerSeries, where there is actually a cheap `g/partial-derivative`
;; implementation for the components. I vote to back out this `::s/structure`
;; installation.

(def ^:dynamic *mode* d/FORWARD-MODE)

(doseq [t [::v/function ::s/structure]]
(defmethod g/partial-derivative [t v/seqtype] [f selectors]
(multivariate f selectors))
(if (= *mode* d/FORWARD-MODE)
(multivariate f selectors)
(gradient f selectors)))

(defmethod g/partial-derivative [t nil] [f _]
(multivariate f [])))
(if (= *mode* d/FORWARD-MODE)
(multivariate f [])
(gradient f []))))

;; ## Operators
;;
;; This section exposes various differential operators as [[o/Operator]]
;; instances.
;; This section exposes various differential operators
;; as [[emmy.operator/Operator]] instances.

(def ^{:arglists '([f])}
D-forward
"Forward-mode derivative operator. Takes some function `f` and returns a
function whose value at some point can multiply an increment in the arguments
to produce the best linear estimate of the increment in the function value.
(def D
For univariate functions, [[D-forward]] computes a derivative. For vector-valued
functions, [[D-forward]] computes
the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
of `f`."
(o/make-operator
(fn [x]
(binding [*mode* d/FORWARD-MODE]
(g/partial-derivative x [])))
g/derivative-symbol))

(def ^{:arglists '([f])}
D-reverse
"Reverse-mode derivative operator. Takes some function `f` and returns a
function whose value at some point can multiply an increment in the arguments
to produce the best linear estimate of the increment in the function value.
For univariate functions, [[D-reverse]] computes a derivative. For vector-valued
functions, [[D-reverse]] computes
the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
of `f`."
(o/make-operator
(fn [x]
(binding [*mode* d/REVERSE-MODE]
(g/partial-derivative x [])))
g/derivative-symbol))

(def ^{:arglists '([f])}
D
"Derivative operator. Takes some function `f` and returns a function whose value
at some point can multiply an increment in the arguments to produce the best
linear estimate of the increment in the function value.
@@ -222,8 +301,7 @@
The related [[emmy.env/Grad]] returns a function that produces a structure of
the opposite orientation as [[D]]. Both of these functions use forward-mode
automatic differentiation."
(o/make-operator #(g/partial-derivative % [])
g/derivative-symbol))
D-forward)

(defn D-as-matrix [F]
(fn [s]
@@ -232,13 +310,34 @@
((D F) s)
s)))

(defn partial
(defn partial-forward
"Returns an operator that, when applied to a function `f`, produces a function
that uses forward-mode automatic differentiation to compute the partial
derivative of `f` at the (zero-based) slot index provided via `selectors`."
[& selectors]
(o/make-operator
(fn [x]
(binding [*mode* d/FORWARD-MODE]
(g/partial-derivative x selectors)))
`(~'partial ~@selectors)))

(defn partial-reverse
"Returns an operator that, when applied to a function `f`, produces a function
that computes the partial derivative of `f` at the (zero-based) slot index
provided via `selectors`."
that uses reverse-mode automatic differentiation to compute the partial
derivative of `f` at the (zero-based) slot index provided via `selectors`."
[& selectors]
(o/make-operator #(g/partial-derivative % selectors)
`(~'partial ~@selectors)))
(o/make-operator
(fn [x]
(binding [*mode* d/REVERSE-MODE]
(g/partial-derivative x selectors)))
`(~'partial ~@selectors)))

(def ^{:arglists '([& selectors])}
partial
"Returns an operator that, when applied to a function `f`, produces a function
that uses forward-mode automatic differentiation to compute the partial
derivative of `f` at the (zero-based) slot index provided via `selectors`."
partial-forward)

;; ## Derivative Utilities
;;
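
The argument-packaging behavior of the private `multi` helper in this file can be sketched without Emmy. This is a simplified illustration under stated assumptions: a plain Clojure vector stands in for the `up` structure produced by `matrix/seq->`, the arity metadata added by `f/with-arity` is omitted, and `record-op` is a hypothetical operator invented only to make the packaging visible.

```clojure
(defn multi
  "Simplified stand-in for the helper in the diff. `op` turns a function of
  one argument into another function of one argument; the result accepts
  any number of arguments."
  [op f]
  (fn
    ([] 0)
    ;; A single argument is passed through untouched.
    ([x] ((op f) x))
    ;; Several arguments are packed into one vector, and `f` is wrapped so
    ;; that it unpacks that vector again before running.
    ([x & more]
     (let [packed (vec (cons x more))
           f*     (fn [xs] (apply f xs))]
       ((op f*) packed)))))

(defn record-op
  "Hypothetical operator that reports what the wrapped function received."
  [g]
  (fn [x] {:input x :output (g x)}))

(def wrapped (multi record-op +))

(wrapped)        ;; => 0
(wrapped 5)      ;; => {:input 5, :output 5}
(wrapped 1 2 3)  ;; => {:input [1 2 3], :output 6}
```

Factoring the packaging out this way is what lets both `multivariate` and `gradient` share it in the diff: each only supplies its own single-argument `op`.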
20 changes: 12 additions & 8 deletions src/emmy/tape.cljc
@@ -512,16 +512,20 @@
(u/illegal
(str "Selectors " selectors
" not allowed for non-structural input " x)))

(let [tag (d/fresh-tag)
inputs (if (empty? selectors)
(tapify x tag)
(update-in x selectors tapify tag))
output (d/with-active-tag tag f [inputs])
input (if-let [piece (get-in x selectors)]
(if (empty? selectors)
(tapify piece tag)
(assoc-in x selectors (tapify piece tag)))
;; The call to `get-in` will return nil if the
;; `selectors` don't index correctly into the supplied
;; `input`, triggering this exception.
(u/illegal
(str "Bad selectors " selectors " for structure " x)))
output (d/with-active-tag tag f [input])
completed (d/extract-tangent output tag d/REVERSE-MODE)]
(if (empty? selectors)
(interpret inputs completed tag)
(interpret (get-in inputs selectors) completed tag)))))))
(-> (get-in input selectors)
(interpret completed tag)))))))

(defmethod g/zero-like [::tape] [_] 0)
(defmethod g/one-like [::tape] [_] 1)
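
The selector check added to this hunk relies on `get-in` returning `nil` when a path does not exist. A minimal sketch of that idea follows, using nested vectors in place of Emmy structures and `ex-info` in place of `u/illegal`; the function name `focus` is hypothetical.

```clojure
(defn focus
  "Returns the entry of `x` at `selectors`, or throws if the path does not
  index into `x`. An empty selector sequence returns `x` itself."
  [x selectors]
  (if-let [piece (get-in x selectors)]
    piece
    ;; `get-in` returned nil, so the selectors are treated as invalid.
    (throw (ex-info (str "Bad selectors " selectors " for structure " x)
                    {:selectors selectors :structure x}))))

(focus [[1 2] [3 4]] [1 0])  ;; => 3
(focus 3 [])                 ;; => 3

;; An out-of-range path throws instead of silently returning nil.
(try
  (focus [[1 2] [3 4]] [2 0])
  (catch Exception e
    (ex-message e)))
;; => "Bad selectors [2 0] for structure [[1 2] [3 4]]"
```

One consequence of using `if-let` here is that an entry whose value is `nil` or `false` would also be reported as a bad selector; numeric entries, including `0`, are truthy in Clojure and pass the check.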
44 changes: 31 additions & 13 deletions test/emmy/calculus/derivative_test.cljc
@@ -446,13 +446,27 @@
(testing "space"
(let [g (af/literal-function 'g [0 0] 0)
h (af/literal-function 'h [0 0] 0)]
(is (= '(+ (((partial 0) g) x y) (((partial 0) h) x y))
(simplify (((partial 0) (+ g h)) 'x 'y))))
(is (= '(+ (* (((partial 0) g) x y) (h x y)) (* (((partial 0) h) x y) (g x y)))
(simplify (((partial 0) (* g h)) 'x 'y))))
(is (= '(+ (* (((partial 0) g) x y) (h x y) (expt (g x y) (+ (h x y) -1)))
(* (((partial 0) h) x y) (log (g x y)) (expt (g x y) (h x y))))
(simplify (((partial 0) (g/expt g h)) 'x 'y))))))
(is (zero?
(simplify
(g/- (g/+ (((partial 0) g) 'x 'y)
(((partial 0) h) 'x 'y))
(((partial 0) (+ g h)) 'x 'y)))))
(is (zero?
(simplify
(g/-
(g/+ (g/* (((partial 0) g) 'x 'y) (h 'x 'y))
(g/* (((partial 0) h) 'x 'y) (g 'x 'y)))
(((partial 0) (* g h)) 'x 'y)))))
(is (zero?
(simplify
(g/-
(g/+ (g/* (((partial 0) g) 'x 'y)
(h 'x 'y)
(g/expt (g 'x 'y) (+ (h 'x 'y) -1)))
(g/* (((partial 0) h) 'x 'y)
(g/log (g 'x 'y))
(g/expt (g 'x 'y) (h 'x 'y))))
(((partial 0) (g/expt g h)) 'x 'y)))))))

(testing "operators"
(is (= '(down 1 1 1 1 1 1 1 1 1 1)
@@ -485,9 +499,12 @@
f3 (fn [x y] (* (tan x) (log y)))
f4 (fn [x y] (* (tan x) (sin y)))
f5 (fn [x y] (/ (tan x) (sin y)))]
(is (= '(down (* (log y) (cos x))
(/ (sin x) y))
(simplify ((D f2) 'x 'y))))
(is (= '(down 0 0)
(simplify
(g/- (s/down
(g/* (g/log 'y) (g/cos 'x))
(g// (g/sin 'x) 'y))
((D f2) 'x 'y)))))
(is (= '(down (/ (log y) (expt (cos x) 2))
(/ (tan x) y))
(simplify ((D f3) 'x 'y))))
@@ -1616,9 +1633,6 @@
"symbolic-taylor-series keeps the arguments symbolic, even when they
are numbers."))))

;; TODO enable when we add our gradient impl in the next PR.

#_
(deftest mixed-mode-tests
(testing "multiple input, vector output"
(let [f (fn [a b c d e f]
@@ -1686,3 +1700,7 @@

(deftest forward-mode-tests
(all-tests d/D d/partial))

(deftest reverse-mode-tests
(all-tests d/D-reverse
d/partial-reverse))
