From c41c2150b8bb047513689b92d0087e02bb0132be Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Mon, 12 Aug 2024 07:10:00 -0600 Subject: [PATCH] enable gradient --- src/emmy/calculus/derivative.cljc | 89 +++++++++++++++++++------ test/emmy/calculus/derivative_test.cljc | 7 +- 2 files changed, 72 insertions(+), 24 deletions(-) diff --git a/src/emmy/calculus/derivative.cljc b/src/emmy/calculus/derivative.cljc index 53c1a742..f189ffb4 100644 --- a/src/emmy/calculus/derivative.cljc +++ b/src/emmy/calculus/derivative.cljc @@ -22,9 +22,8 @@ ;; ## Single and Multivariable Calculus ;; -;; These functions put together the pieces laid out -;; in [[emmy.dual]] and declare an interface for taking -;; derivatives. +;; These functions put together the pieces laid out in [[emmy.dual]] and declare +;; an interface for taking derivatives. ;; The result of applying the derivative `(D f)` of a multivariable function `f` ;; to a sequence of `args` is a structure of the same shape as `args` with all @@ -70,6 +69,8 @@ " at path " path " in input structure " structure))))) +;; TODO have this jacobian be its own thing, then augment it with multi. + (defn- jacobian "Takes: @@ -108,6 +109,8 @@ ;; correctly into the supplied `input`, triggering this exception. (u/illegal (str "Bad selectors " selectors " for structure " input)))))) +;; TODO can we do something like this for gradient, with gradient in both slots?? + (defn- euclidean "Slightly more general version of [[jacobian]] that can handle a single non-structural input; dispatches to either [[jacobian]] or [[derivative]] @@ -143,6 +146,15 @@ (str "Selectors " selectors " not allowed for non-structural input " input))))))) +(defn multi [op f] + (-> (fn + ([] 0) + ([x] ((op f) x)) + ([x & more] + ((multi op (fn [xs] (apply f xs))) + (matrix/seq-> (cons x more))))) + (f/with-arity (f/arity f) {:from ::multi}))) + (defn- multivariate "Slightly wider version of [[euclidean]]. Accepts: @@ -161,15 +173,13 @@ Single-argument functions don't transform their arguments." ([f] (multivariate f [])) ([f selectors] - (let [d #(euclidean % selectors) - df (d f) - df* (d (fn [args] (apply f args)))] - (-> (fn - ([] 0) - ([x] (df x)) - ([x & more] - (df* (matrix/seq-> (cons x more))))) - (f/with-arity (f/arity f) {:from ::multivariate}))))) + (let [d #(euclidean % selectors)] + (multi d f)))) + +(defn gradient + ([f] (gradient f [])) + ([f selectors] + (multi #(tape/gradient % selectors) f))) ;; ## Generic [[g/partial-derivative]] Installation ;; @@ -197,19 +207,42 @@ ;; implementation for the components. I vote to back out this `::s/structure` ;; installation. +(def ^:dynamic *mode* d/FORWARD-MODE) + (doseq [t [::v/function ::s/structure]] (defmethod g/partial-derivative [t v/seqtype] [f selectors] - (multivariate f selectors)) + (if (= *mode* d/FORWARD-MODE) + (multivariate f selectors) + (gradient f selectors))) (defmethod g/partial-derivative [t nil] [f _] - (multivariate f []))) + (if (= *mode* d/FORWARD-MODE) + (multivariate f []) + (gradient f [])))) ;; ## Operators ;; ;; This section exposes various differential operators as [[o/Operator]] ;; instances. -(def D +(def ^{:arglists '([f])} + D-forward + (o/make-operator + (fn [x] + (binding [*mode* d/FORWARD-MODE] + (g/partial-derivative x []))) + g/derivative-symbol)) + +(def ^{:arglists '([f])} + D-reverse + (o/make-operator + (fn [x] + (binding [*mode* d/REVERSE-MODE] + (g/partial-derivative x []))) + g/derivative-symbol)) + +(def ^{:arglists '([f])} + D "Derivative operator. Takes some function `f` and returns a function whose value at some point can multiply an increment in the arguments to produce the best linear estimate of the increment in the function value. @@ -222,8 +255,7 @@ The related [[emmy.env/Grad]] returns a function that produces a structure of the opposite orientation as [[D]]. Both of these functions use forward-mode automatic differentiation." - (o/make-operator #(g/partial-derivative % []) - g/derivative-symbol)) + D-forward) (defn D-as-matrix [F] (fn [s] @@ -232,13 +264,28 @@ ((D F) s) s))) -(defn partial +(defn partial-forward + [& selectors] + (o/make-operator + (fn [x] + (binding [*mode* d/FORWARD-MODE] + (g/partial-derivative x selectors))) + `(~'partial ~@selectors))) + +(defn partial-reverse + [& selectors] + (o/make-operator + (fn [x] + (binding [*mode* d/REVERSE-MODE] + (g/partial-derivative x selectors))) + `(~'partial ~@selectors))) + +(def ^{:arglists '([& selectors])} + partial "Returns an operator that, when applied to a function `f`, produces a function that computes the partial derivative of `f` at the (zero-based) slot index provided via `selectors`." - [& selectors] - (o/make-operator #(g/partial-derivative % selectors) - `(~'partial ~@selectors))) + partial-forward) ;; ## Derivative Utilities ;; diff --git a/test/emmy/calculus/derivative_test.cljc b/test/emmy/calculus/derivative_test.cljc index ceb935e8..c76075dd 100644 --- a/test/emmy/calculus/derivative_test.cljc +++ b/test/emmy/calculus/derivative_test.cljc @@ -1616,9 +1616,6 @@ "symbolic-taylor-series keeps the arguments symbolic, even when they are numbers.")))) -;; TODO enable when we add our gradient impl in the next PR. - -#_ (deftest mixed-mode-tests (testing "multiple input, vector output" (let [f (fn [a b c d e f] @@ -1686,3 +1683,7 @@ (deftest forward-mode-tests (all-tests d/D d/partial)) + +(deftest reverse-mode-tests + (all-tests d/D-reverse + d/partial-reverse))