Skip to content

Commit

Permalink
enable gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Aug 12, 2024
1 parent d4c4c07 commit c41c215
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 24 deletions.
89 changes: 68 additions & 21 deletions src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@

;; ## Single and Multivariable Calculus
;;
;; These functions put together the pieces laid out
;; in [[emmy.dual]] and declare an interface for taking
;; derivatives.
;; These functions put together the pieces laid out in [[emmy.dual]] and declare
;; an interface for taking derivatives.

;; The result of applying the derivative `(D f)` of a multivariable function `f`
;; to a sequence of `args` is a structure of the same shape as `args` with all
Expand Down Expand Up @@ -70,6 +69,8 @@
" at path " path
" in input structure " structure)))))

;; TODO have this jacobian be its own thing, then augment it with multi.

(defn- jacobian
"Takes:
Expand Down Expand Up @@ -108,6 +109,8 @@
;; correctly into the supplied `input`, triggering this exception.
(u/illegal (str "Bad selectors " selectors " for structure " input))))))

;; TODO can we do something like this for gradient, with gradient in both slots??

(defn- euclidean
"Slightly more general version of [[jacobian]] that can handle a single
non-structural input; dispatches to either [[jacobian]] or [[derivative]]
Expand Down Expand Up @@ -143,6 +146,15 @@
(str "Selectors " selectors
" not allowed for non-structural input " input)))))))

(defn multi [op f]
(-> (fn
([] 0)
([x] ((op f) x))
([x & more]
((multi op (fn [xs] (apply f xs)))
(matrix/seq-> (cons x more)))))
(f/with-arity (f/arity f) {:from ::multi})))

(defn- multivariate
"Slightly wider version of [[euclidean]]. Accepts:
Expand All @@ -161,15 +173,13 @@
Single-argument functions don't transform their arguments."
([f] (multivariate f []))
([f selectors]
(let [d #(euclidean % selectors)
df (d f)
df* (d (fn [args] (apply f args)))]
(-> (fn
([] 0)
([x] (df x))
([x & more]
(df* (matrix/seq-> (cons x more)))))
(f/with-arity (f/arity f) {:from ::multivariate})))))
(let [d #(euclidean % selectors)]
(multi d f))))

(defn gradient
([f] (gradient f []))
([f selectors]
(multi #(tape/gradient % selectors) f)))

;; ## Generic [[g/partial-derivative]] Installation
;;
Expand Down Expand Up @@ -197,19 +207,42 @@
;; implementation for the components. I vote to back out this `::s/structure`
;; installation.

(def ^:dynamic *mode* d/FORWARD-MODE)

(doseq [t [::v/function ::s/structure]]
(defmethod g/partial-derivative [t v/seqtype] [f selectors]
(multivariate f selectors))
(if (= *mode* d/FORWARD-MODE)
(multivariate f selectors)
(gradient f selectors)))

(defmethod g/partial-derivative [t nil] [f _]
(multivariate f [])))
(if (= *mode* d/FORWARD-MODE)
(multivariate f [])
(gradient f []))))

Check warning on line 221 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L221

Added line #L221 was not covered by tests

;; ## Operators
;;
;; This section exposes various differential operators as [[o/Operator]]
;; instances.

(def D
(def ^{:arglists '([f])}
D-forward
(o/make-operator
(fn [x]
(binding [*mode* d/FORWARD-MODE]
(g/partial-derivative x [])))
g/derivative-symbol))

(def ^{:arglists '([f])}
D-reverse
(o/make-operator
(fn [x]
(binding [*mode* d/REVERSE-MODE]
(g/partial-derivative x [])))
g/derivative-symbol))

(def ^{:arglists '([f])}
D
"Derivative operator. Takes some function `f` and returns a function whose value
at some point can multiply an increment in the arguments to produce the best
linear estimate of the increment in the function value.
Expand All @@ -222,8 +255,7 @@
The related [[emmy.env/Grad]] returns a function that produces a structure of
the opposite orientation as [[D]]. Both of these functions use forward-mode
automatic differentiation."
(o/make-operator #(g/partial-derivative % [])
g/derivative-symbol))
D-forward)

(defn D-as-matrix [F]
(fn [s]
Expand All @@ -232,13 +264,28 @@
((D F) s)
s)))

(defn partial
(defn partial-forward
[& selectors]
(o/make-operator
(fn [x]
(binding [*mode* d/FORWARD-MODE]
(g/partial-derivative x selectors)))
`(~'partial ~@selectors)))

(defn partial-reverse
[& selectors]
(o/make-operator
(fn [x]
(binding [*mode* d/REVERSE-MODE]
(g/partial-derivative x selectors)))
`(~'partial ~@selectors)))

(def ^{:arglists '([& selectors])}
partial
"Returns an operator that, when applied to a function `f`, produces a function
that computes the partial derivative of `f` at the (zero-based) slot index
provided via `selectors`."
[& selectors]
(o/make-operator #(g/partial-derivative % selectors)
`(~'partial ~@selectors)))
partial-forward)

;; ## Derivative Utilities
;;
Expand Down
7 changes: 4 additions & 3 deletions test/emmy/calculus/derivative_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -1616,9 +1616,6 @@
"symbolic-taylor-series keeps the arguments symbolic, even when they
are numbers."))))

;; TODO enable when we add our gradient impl in the next PR.

#_
(deftest mixed-mode-tests
(testing "multiple input, vector output"
(let [f (fn [a b c d e f]
Expand Down Expand Up @@ -1686,3 +1683,7 @@

(deftest forward-mode-tests
(all-tests d/D d/partial))

(deftest reverse-mode-tests
(all-tests d/D-reverse
d/partial-reverse))

0 comments on commit c41c215

Please sign in to comment.