diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d25b40d..ea065097 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,23 @@ ## [unreleased] +- #185: + + - adds a dynamic variable `emmy.calculus.derivative/*mode*` that allows the + user to switch between forward and reverse mode automatic differentiation + + - adds a new `emmy.calculus.derivative/gradient` that acts like + `emmy.tape/gradient` but is capable of taking multiple variables + + - adds new operators `emmy.calculus.derivative/{D-forward, D-reverse}` and + operator-returning-functions `emmy.calculus.derivative/{partial-forward, + partial-reverse}` that allow the user to explicitly invoke forward-mode or + reverse-mode automatic differentiation. `D` and `partial` still default to + forward-mode + + - modifies `emmy.tape/gradient` to correctly error when passed invalid + selectors, just like `emmy.dual/derivative`. + - #183: - adds `emmy.{autodiff, tape}` to `emmy.sci`'s exported namespace set diff --git a/src/emmy/calculus/derivative.cljc b/src/emmy/calculus/derivative.cljc index 53c1a742..57868007 100644 --- a/src/emmy/calculus/derivative.cljc +++ b/src/emmy/calculus/derivative.cljc @@ -22,9 +22,8 @@ ;; ## Single and Multivariable Calculus ;; -;; These functions put together the pieces laid out -;; in [[emmy.dual]] and declare an interface for taking -;; derivatives. +;; These functions put together the pieces laid out in [[emmy.dual]] and declare +;; an interface for taking derivatives. 
;; The result of applying the derivative `(D f)` of a multivariable function `f` ;; to a sequence of `args` is a structure of the same shape as `args` with all @@ -37,7 +36,8 @@ ;; To generate the result: ;; ;; - For a single non-structural argument, return `(d/derivative f)` -;; - else, bundle up all arguments into a single [[s/Structure]] instance `xs` +;; - else, bundle up all arguments into a single [[emmy.structure/Structure]] +;; instance `xs` ;; - Generate `xs'` by replacing each entry in `xs` with `((d/derivative f') ;; entry)`, where `f'` is a function of ONLY that entry that ;; calls `(f (assoc-in xs path entry))`. In other words, replace each entry @@ -49,7 +49,7 @@ ;; above. ;; ;; [[jacobian]] handles this main logic. [[jacobian]] can only take a structural -;; input. [[euclidean]] and [[multivariate]] below widen handle, respectively, +;; input. [[euclidean]] and [[multivariate]] below handle, respectively, ;; optionally-structural and multivariable arguments. (defn- deep-partial @@ -109,12 +109,12 @@ (u/illegal (str "Bad selectors " selectors " for structure " input)))))) (defn- euclidean - "Slightly more general version of [[jacobian]] that can handle a single - non-structural input; dispatches to either [[jacobian]] or [[derivative]] - depending on the input type. + "Slightly more general version of [[jacobian]] that can handle a single input; + dispatches to either [[jacobian]] or [[derivative]] depending on whether or + not the input is structural. If you pass non-empty `selectors`, the returned function will throw if it - receives a non-structural, non-numerical argument." + receives a non-structural, non-scalar argument." 
([f] (euclidean f [])) ([f selectors] (let [selectors (vec selectors)] @@ -143,6 +143,28 @@ (str "Selectors " selectors " not allowed for non-structural input " input))))))) +(defn- multi + "Given + + - some higher-order function `op` that transforms a function of a single + variable into another function of a single variable + - function `f` capable of taking multiple arguments + + returns a new function that acts like `(op f)` but can take multiple + arguments. + + When passed multiple arguments, the returned function packages them into a + single [[emmy.structure/up]] instance. Any [[emmy.matrix/Matrix]] present in + the argument list will be converted into a `down` of `up`s (a row of columns)." + [op f] + (-> (fn + ([] 0) + ([x] ((op f) x)) + ([x & more] + ((multi op (fn [xs] (apply f xs))) + (matrix/seq-> (cons x more))))) + (f/with-arity (f/arity f) {:from ::multi}))) + (defn- multivariate "Slightly wider version of [[euclidean]]. Accepts: @@ -152,24 +174,39 @@ And returns a new function that computes either the full [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) - or the entry at `selectors`. + or the entry at `selectors` using [forward-mode automatic + differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Forward_accumulation). Any multivariable function will have its argument vector coerced into an `up` - structure. Any [[matrix/Matrix]] in a multiple-arg function call will be + structure. Any [[emmy.matrix/Matrix]] in a multiple-arg function call will be converted into a `down` of `up`s (a row of columns). - Single-argument functions don't transform their arguments." + Arguments to single-variable functions are not transformed." 
([f] (multivariate f [])) ([f selectors] - (let [d #(euclidean % selectors) - df (d f) - df* (d (fn [args] (apply f args)))] - (-> (fn - ([] 0) - ([x] (df x)) - ([x & more] - (df* (matrix/seq-> (cons x more))))) - (f/with-arity (f/arity f) {:from ::multivariate}))))) + (let [d #(euclidean % selectors)] + (multi d f)))) + +(defn gradient + "Accepts: + + - some function `f` of potentially many arguments + - optionally, a sequence of selectors meant to index into the structural + argument, or argument vector, of `f` + + And returns a new function that computes either the + full [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) + or the entry at `selectors` using [reverse-mode automatic + differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation). + + Any multivariable function will have its argument vector coerced into an `up` + structure. Any [[emmy.matrix/Matrix]] in a multiple-arg function call will be + converted into a `down` of `up`s (a row of columns). + + Arguments to single-variable functions are not transformed." + ([f] (gradient f [])) + ([f selectors] + (multi #(tape/gradient % selectors) f))) ;; ## Generic [[g/partial-derivative]] Installation ;; @@ -192,24 +229,66 @@ ;; passed to the structure of functions, instead of separately for every entry ;; in the structure. ;; +;; A dynamic variable controls whether this process uses forward-mode or +;; reverse-mode AD. +;; ;; TODO: I think this is going to cause problems for, say, a Structure of ;; PowerSeries, where there is actually a cheap `g/partial-derivative` ;; implementation for the components. I vote to back out this `::s/structure` ;; installation. 
+(def ^:dynamic *mode* d/FORWARD-MODE) + (doseq [t [::v/function ::s/structure]] (defmethod g/partial-derivative [t v/seqtype] [f selectors] - (multivariate f selectors)) + (if (= *mode* d/FORWARD-MODE) + (multivariate f selectors) + (gradient f selectors))) (defmethod g/partial-derivative [t nil] [f _] - (multivariate f []))) + (if (= *mode* d/FORWARD-MODE) + (multivariate f []) + (gradient f [])))) ;; ## Operators ;; -;; This section exposes various differential operators as [[o/Operator]] -;; instances. +;; This section exposes various differential operators +;; as [[emmy.operator/Operator]] instances. + +(def ^{:arglists '([f])} + D-forward + "Forward-mode derivative operator. Takes some function `f` and returns a + function whose value at some point can multiply an increment in the arguments + to produce the best linear estimate of the increment in the function value. -(def D + For univariate functions, [[D-forward]] computes a derivative. For vector-valued + functions, [[D-forward]] computes + the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) + of `f`." + (o/make-operator + (fn [x] + (binding [*mode* d/FORWARD-MODE] + (g/partial-derivative x []))) + g/derivative-symbol)) + +(def ^{:arglists '([f])} + D-reverse + "Reverse-mode derivative operator. Takes some function `f` and returns a + function whose value at some point can multiply an increment in the arguments + to produce the best linear estimate of the increment in the function value. + + For univariate functions, [[D-reverse]] computes a derivative. For vector-valued + functions, [[D-reverse]] computes + the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) + of `f`." + (o/make-operator + (fn [x] + (binding [*mode* d/REVERSE-MODE] + (g/partial-derivative x []))) + g/derivative-symbol)) + +(def ^{:arglists '([f])} + D "Derivative operator. 
Takes some function `f` and returns a function whose value at some point can multiply an increment in the arguments to produce the best linear estimate of the increment in the function value. @@ -222,8 +301,7 @@ The related [[emmy.env/Grad]] returns a function that produces a structure of the opposite orientation as [[D]]. Both of these functions use forward-mode automatic differentiation." - (o/make-operator #(g/partial-derivative % []) - g/derivative-symbol)) + D-forward) (defn D-as-matrix [F] (fn [s] @@ -232,13 +310,34 @@ ((D F) s) s))) -(defn partial +(defn partial-forward + "Returns an operator that, when applied to a function `f`, produces a function + that uses forward-mode automatic differentiation to compute the partial + derivative of `f` at the (zero-based) slot index provided via `selectors`." + [& selectors] + (o/make-operator + (fn [x] + (binding [*mode* d/FORWARD-MODE] + (g/partial-derivative x selectors))) + `(~'partial ~@selectors))) + +(defn partial-reverse "Returns an operator that, when applied to a function `f`, produces a function - that computes the partial derivative of `f` at the (zero-based) slot index - provided via `selectors`." + that uses reverse-mode automatic differentiation to compute the partial + derivative of `f` at the (zero-based) slot index provided via `selectors`." [& selectors] - (o/make-operator #(g/partial-derivative % selectors) - `(~'partial ~@selectors))) + (o/make-operator + (fn [x] + (binding [*mode* d/REVERSE-MODE] + (g/partial-derivative x selectors))) + `(~'partial ~@selectors))) + +(def ^{:arglists '([& selectors])} + partial + "Returns an operator that, when applied to a function `f`, produces a function + that uses forward-mode automatic differentiation to compute the partial + derivative of `f` at the (zero-based) slot index provided via `selectors`." 
+ partial-forward) ;; ## Derivative Utilities ;; diff --git a/src/emmy/tape.cljc b/src/emmy/tape.cljc index 8df7e7b6..3f6e281c 100644 --- a/src/emmy/tape.cljc +++ b/src/emmy/tape.cljc @@ -512,16 +512,20 @@ (u/illegal (str "Selectors " selectors " not allowed for non-structural input " x))) - (let [tag (d/fresh-tag) - inputs (if (empty? selectors) - (tapify x tag) - (update-in x selectors tapify tag)) - output (d/with-active-tag tag f [inputs]) + input (if-let [piece (get-in x selectors)] + (if (empty? selectors) + (tapify piece tag) + (assoc-in x selectors (tapify piece tag))) + ;; The call to `get-in` will return nil if the + ;; `selectors` don't index correctly into the supplied + ;; `input`, triggering this exception. + (u/illegal + (str "Bad selectors " selectors " for structure " x))) + output (d/with-active-tag tag f [input]) completed (d/extract-tangent output tag d/REVERSE-MODE)] - (if (empty? selectors) - (interpret inputs completed tag) - (interpret (get-in inputs selectors) completed tag))))))) + (-> (get-in input selectors) + (interpret completed tag))))))) (defmethod g/zero-like [::tape] [_] 0) (defmethod g/one-like [::tape] [_] 1) diff --git a/test/emmy/calculus/derivative_test.cljc b/test/emmy/calculus/derivative_test.cljc index ceb935e8..c146d809 100644 --- a/test/emmy/calculus/derivative_test.cljc +++ b/test/emmy/calculus/derivative_test.cljc @@ -446,13 +446,27 @@ (testing "space" (let [g (af/literal-function 'g [0 0] 0) h (af/literal-function 'h [0 0] 0)] - (is (= '(+ (((partial 0) g) x y) (((partial 0) h) x y)) - (simplify (((partial 0) (+ g h)) 'x 'y)))) - (is (= '(+ (* (((partial 0) g) x y) (h x y)) (* (((partial 0) h) x y) (g x y))) - (simplify (((partial 0) (* g h)) 'x 'y)))) - (is (= '(+ (* (((partial 0) g) x y) (h x y) (expt (g x y) (+ (h x y) -1))) - (* (((partial 0) h) x y) (log (g x y)) (expt (g x y) (h x y)))) - (simplify (((partial 0) (g/expt g h)) 'x 'y)))))) + (is (zero? 
+ (simplify + (g/- (g/+ (((partial 0) g) 'x 'y) + (((partial 0) h) 'x 'y)) + (((partial 0) (+ g h)) 'x 'y))))) + (is (zero? + (simplify + (g/- + (g/+ (g/* (((partial 0) g) 'x 'y) (h 'x 'y)) + (g/* (((partial 0) h) 'x 'y) (g 'x 'y))) + (((partial 0) (* g h)) 'x 'y))))) + (is (zero? + (simplify + (g/- + (g/+ (g/* (((partial 0) g) 'x 'y) + (h 'x 'y) + (g/expt (g 'x 'y) (+ (h 'x 'y) -1))) + (g/* (((partial 0) h) 'x 'y) + (g/log (g 'x 'y)) + (g/expt (g 'x 'y) (h 'x 'y)))) + (((partial 0) (g/expt g h)) 'x 'y))))))) (testing "operators" (is (= '(down 1 1 1 1 1 1 1 1 1 1) @@ -485,9 +499,12 @@ f3 (fn [x y] (* (tan x) (log y))) f4 (fn [x y] (* (tan x) (sin y))) f5 (fn [x y] (/ (tan x) (sin y)))] - (is (= '(down (* (log y) (cos x)) - (/ (sin x) y)) - (simplify ((D f2) 'x 'y)))) + (is (= '(down 0 0) + (simplify + (g/- (s/down + (g/* (g/log 'y) (g/cos 'x)) + (g// (g/sin 'x) 'y)) + ((D f2) 'x 'y))))) (is (= '(down (/ (log y) (expt (cos x) 2)) (/ (tan x) y)) (simplify ((D f3) 'x 'y)))) @@ -1616,9 +1633,6 @@ "symbolic-taylor-series keeps the arguments symbolic, even when they are numbers.")))) -;; TODO enable when we add our gradient impl in the next PR. - -#_ (deftest mixed-mode-tests (testing "multiple input, vector output" (let [f (fn [a b c d e f] @@ -1686,3 +1700,7 @@ (deftest forward-mode-tests (all-tests d/D d/partial)) + +(deftest reverse-mode-tests + (all-tests d/D-reverse + d/partial-reverse))