Skip to content

Commit

Permalink
Change default derivative impl to reverse mode (#170)
Browse files Browse the repository at this point in the history
- #170:

- changes `D` and `partial` to use reverse mode automatic
differentiation by
    default, and fixes all associated tests

- adds `emmy.generic/{zero?,one?,identity?}` implementations (all false)
to
`emmy.tape/Completed`, in case some collection type tries to simplify
these
    during reverse-mode AD
  • Loading branch information
sritchie authored Aug 13, 2024
1 parent 622ba21 commit cac1791
Show file tree
Hide file tree
Showing 19 changed files with 182 additions and 167 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

## [unreleased]

- #170:

- changes `D` and `partial` to use reverse mode automatic differentiation by
default, and fixes all associated tests

- adds `emmy.generic/{zero?,one?,identity?}` implementations (all false) to
`emmy.tape/Completed`, in case some collection type tries to simplify these
during reverse-mode AD

- #185:

- adds a dynamic variable `emmy.calculus.derivative/*mode*` that allows the
Expand Down
16 changes: 8 additions & 8 deletions src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@
;; implementation for the components. I vote to back out this `::s/structure`
;; installation.

(def ^:dynamic *mode* d/FORWARD-MODE)
(def ^:dynamic *mode* d/REVERSE-MODE)

(doseq [t [::v/function ::s/structure]]
(defmethod g/partial-derivative [t v/seqtype] [f selectors]
Expand Down Expand Up @@ -266,9 +266,9 @@
the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
of `f`."
(o/make-operator
(fn [x]
(fn [f]
(binding [*mode* d/FORWARD-MODE]
(g/partial-derivative x [])))
(g/partial-derivative f [])))
g/derivative-symbol))

(def ^{:arglists '([f])}
Expand All @@ -282,9 +282,9 @@
the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
of `f`."
(o/make-operator
(fn [x]
(fn [f]
(binding [*mode* d/REVERSE-MODE]
(g/partial-derivative x [])))
(g/partial-derivative f [])))
g/derivative-symbol))

(def ^{:arglists '([f])}
Expand All @@ -299,9 +299,9 @@
of `f`.
The related [[emmy.env/Grad]] returns a function that produces a structure of
the opposite orientation as [[D]]. Both of these functions use forward-mode
the opposite orientation as [[D]]. Both of these functions use reverse-mode
automatic differentiation."
D-forward)
D-reverse)

(defn D-as-matrix [F]
(fn [s]
Expand Down Expand Up @@ -337,7 +337,7 @@
"Returns an operator that, when applied to a function `f`, produces a function
that uses forward-mode automatic differentiation to compute the partial
derivative of `f` at the (zero-based) slot index provided via `selectors`."
partial-forward)
partial-reverse)

;; ## Derivative Utilities
;;
Expand Down
1 change: 0 additions & 1 deletion src/emmy/calculus/manifold.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,6 @@
(g/+ (g/square x)
(g/square y)
(g/square z)))]
(println "r is" r)
(when (g/zero? r)
(u/illegal-state "SphericalCylindrical singular"))
(-> rep
Expand Down
5 changes: 2 additions & 3 deletions src/emmy/calculus/vector_calculus.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@
The related [[emmy.env/D]] operator returns a function that produces a
structure of the opposite orientation as [[Grad]]. Both of these functions use
forward-mode automatic differentiation."
reverse-mode automatic differentiation."
(-> (fn [f]
(f/compose s/opposite
(g/partial-derivative f [])))
(f/compose s/opposite (d/D f)))
(o/make-operator 'Grad)))

(defn gradient
Expand Down
7 changes: 7 additions & 0 deletions src/emmy/dual.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,13 @@
(def ^:const REVERSE-MODE ::reverse)
(def ^:const REVERSE-EMPTY (->Completed {}))

;; These are here to handle cases where some collection type might see instances
;; of [[Completed]] and try and handle them during simplification.

(defmethod g/zero? [Completed] [_] false)
(defmethod g/one? [Completed] [_] false)
(defmethod g/identity? [Completed] [_] false)

;; `replace-tag` exists to handle subtle bugs that can arise in the case of
;; functional return values. See the "Amazing Bug" sections
;; in [[emmy.calculus.derivative-test]] for detailed examples on how this might
Expand Down
18 changes: 9 additions & 9 deletions src/emmy/mechanics/hamilton.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,9 @@
(g/zero?
(g/simplify
(g/determinant M))))
(do (println "determinant" (g/determinant M))
(throw
(ex-info "Legendre Transform Failure: determinant = 0"
{:F F :w w})))
(throw
(ex-info "Legendre Transform Failure: determinant = 0"
{:F F :w w}))
(let [v (g/solve-linear-left M (- w b))]
(- (* w v) (F v))))))]
(let [Dpg (D putative-G)]
Expand Down Expand Up @@ -442,11 +441,12 @@
"p.326"
[C]
(fn [s]
((- J-func
(f/compose (Phi ((D C) s))
J-func
(Phi* ((D C) s))))
(s/compatible-shape s))))
(let [s-syms (s/compatible-shape s)]
((- J-func
(f/compose (Phi ((D C) s))
J-func
(Phi* ((D C) s))))
s-syms))))

;; Time-Varying code
;;
Expand Down
4 changes: 2 additions & 2 deletions src/emmy/numbers.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@
(g/infinite? a) 0
:else (g// (g/sin a) a)))

(defmethod g/sin [::v/real] [a] (Math/sin a))
(defmethod g/cos [::v/real] [a] (Math/cos a))
(defmethod g/sin [::v/real] [a] (if (g/zero? a) 0 (Math/sin a)))
(defmethod g/cos [::v/real] [a] (if (g/zero? a) 1 (Math/cos a)))
(defmethod g/tan [::v/real] [a] (Math/tan a))

(defmethod g/cosh [::v/real] [a] (Math/cosh a))
Expand Down
11 changes: 7 additions & 4 deletions test/emmy/abstract/number_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,13 @@
values like (cos 0) can be exactly evaluated.")))

(checking "inexact numbers" 100 [x sg/native-integral]
(is (= (+ 1 (Math/cos x))
(g/+ 1 (g/cos x)))
"You get a floating-point inexact result by calling generic fns
on a number directly, by comparison."))
(if (g/zero? x)
(is (= 2 (g/+ 1 (g/cos x)))
"special-cased exact output at 0")
(is (= (+ 1 (Math/cos x))
(g/+ 1 (g/cos x)))
"You get a floating-point inexact result by calling generic fns
on a number directly, by comparison.")))

(testing "literal-number properly prints wrapped sequences, even if they're lazy."
(is (= "(* 2 x)"
Expand Down
72 changes: 36 additions & 36 deletions test/emmy/calculus/covariant_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,24 @@
(-> (simplify expr)
(x/substitute '(up x0 y0 z0) 'p)))]

(is (= '(+ (* (Y↑0 p) (((partial 0) w_0) p) (X↑0 p))
(* (Y↑0 p) (w_0 p) (((partial 0) X↑0) p))
(* (Y↑0 p) (w_1 p) (((partial 0) X↑1) p))
(* (Y↑0 p) (w_2 p) (((partial 0) X↑2) p))
(* (Y↑0 p) (((partial 1) w_0) p) (X↑1 p))
(* (Y↑0 p) (((partial 2) w_0) p) (X↑2 p))
(* (X↑0 p) (Y↑1 p) (((partial 0) w_1) p))
(* (X↑0 p) (Y↑2 p) (((partial 0) w_2) p))
(* (w_0 p) (Y↑1 p) (((partial 1) X↑0) p))
(* (w_0 p) (Y↑2 p) (((partial 2) X↑0) p))
(* (w_1 p) (Y↑1 p) (((partial 1) X↑1) p))
(* (w_1 p) (Y↑2 p) (((partial 2) X↑1) p))
(is (= '(+ (* (w_2 p) (Y↑2 p) (((partial 2) X↑2) p))
(* (w_2 p) (Y↑1 p) (((partial 1) X↑2) p))
(* (w_2 p) (Y↑2 p) (((partial 2) X↑2) p))
(* (X↑1 p) (Y↑1 p) (((partial 1) w_1) p))
(* (X↑1 p) (Y↑2 p) (((partial 1) w_2) p))
(* (X↑2 p) (Y↑1 p) (((partial 2) w_1) p))
(* (X↑2 p) (Y↑2 p) (((partial 2) w_2) p)))
(* (w_2 p) (Y↑0 p) (((partial 0) X↑2) p))
(* (Y↑2 p) (((partial 0) w_2) p) (X↑0 p))
(* (Y↑2 p) (w_1 p) (((partial 2) X↑1) p))
(* (Y↑2 p) (w_0 p) (((partial 2) X↑0) p))
(* (Y↑2 p) (((partial 1) w_2) p) (X↑1 p))
(* (Y↑2 p) (((partial 2) w_2) p) (X↑2 p))
(* (Y↑1 p) (X↑0 p) (((partial 0) w_1) p))
(* (Y↑1 p) (w_1 p) (((partial 1) X↑1) p))
(* (Y↑1 p) (w_0 p) (((partial 1) X↑0) p))
(* (Y↑1 p) (X↑1 p) (((partial 1) w_1) p))
(* (Y↑1 p) (X↑2 p) (((partial 2) w_1) p))
(* (Y↑0 p) (X↑0 p) (((partial 0) w_0) p))
(* (Y↑0 p) (w_1 p) (((partial 0) X↑1) p))
(* (Y↑0 p) (w_0 p) (((partial 0) X↑0) p))
(* (Y↑0 p) (X↑1 p) (((partial 1) w_0) p))
(* (Y↑0 p) (X↑2 p) (((partial 2) w_0) p)))
(present
((((g/Lie-derivative X) w) Y) R3-rect-point))))

Expand All @@ -92,14 +92,14 @@
(-> (simplify expr)
(x/substitute '(up x0 y0) 'p)))]

(is (= '(+ (* -1 (Y↑0 p) (((partial 0) f) p) (((partial 0) X↑0) p))
(* -1 (Y↑0 p) (((partial 1) f) p) (((partial 0) X↑1) p))
(* (((partial 0) f) p) (X↑1 p) (((partial 1) Y↑0) p))
(* (((partial 0) f) p) (((partial 0) Y↑0) p) (X↑0 p))
(* -1 (((partial 0) f) p) (Y↑1 p) (((partial 1) X0) p))
(* (X↑1 p) (((partial 1) f) p) (((partial 1) Y↑1) p))
(* (((partial 1) f) p) (X↑0 p) (((partial 0) Y↑1) p))
(* -1 (((partial 1) f) p) (Y↑1 p) (((partial 1) X↑1) p)))
(is (= '(+ (* (((partial 1) f) p) (((partial 0) Y↑1) p) (X↑0 p))
(* -1 (((partial 1) f) p) (Y↑1 p) (((partial 1) X↑1) p))
(* -1 (((partial 1) f) p) (Y↑0 p) (((partial 0) X↑1) p))
(* (((partial 1) f) p) (((partial 1) Y↑1) p) (X↑1 p))
(* (X↑0 p) (((partial 0) f) p) (((partial 0) Y0) p))
(* -1 (Y↑1 p) (((partial 0) f) p) (((partial 1) X↑0) p))
(* -1 (Y↑0 p) (((partial 0) f) p) (((partial 0) X↑0) p))
(* (X↑1 p) (((partial 0) f) p) (((partial 1) Y↑0) p)))
(present
((((g/Lie-derivative X) Y) f) R2-rect-point))))

Expand All @@ -114,7 +114,7 @@

;; we only need linear terms in phi_t(x)

;; Perhaps
;; Perhaps

;; phi_t(x) = (I + t v(I))(x)

Expand Down Expand Up @@ -186,14 +186,14 @@
present (fn [expr]
(-> (simplify expr)
(x/substitute '(up q_x q_y) 'p)))]
(is (= '(+ (* -1 (Y↑0 p) (((partial 0) f) p) (((partial 0) X↑0) p))
(* -1 (Y↑0 p) (((partial 1) f) p) (((partial 0) X↑1) p))
(* (((partial 0) f) p) (((partial 0) Y↑0) p) (X↑0 p))
(* -1 (((partial 0) f) p) (Y↑1 p) (((partial 1) X↑0) p))
(* (((partial 0) f) p) (((partial 1) Y↑0) p) (X↑1 p))
(* (((partial 1) f) p) (X↑0 p) (((partial 0) Y↑1) p))
(is (= '(+ (* (((partial 1) f) p) (((partial 0) Y↑1) p) (X↑0 p))
(* -1 (((partial 1) f) p) (Y↑1 p) (((partial 1) X↑1) p))
(* (((partial 1) f) p) (X↑1 p) (((partial 1) Y↑1) p)))
(* -1 (((partial 1) f) p) (Y↑0 p) (((partial 0) X↑1) p))
(* (((partial 1) f) p) (((partial 1) Y↑1) p) (X↑1 p))
(* (X↑0 p) (((partial 0) f) p) (((partial 0) Y↑0) p))
(* -1 (Y↑1 p) (((partial 0) f) p) (((partial 1) X↑0) p))
(* -1 (Y↑0 p) (((partial 0) f) p) (((partial 0) X↑0) p))
(* (X↑1 p) (((partial 0) f) p) (((partial 1) Y↑0) p)))
(present
((D (fn [t]
(- ((Y f) ((phiX t) m_0))
Expand All @@ -207,7 +207,7 @@
((Y (compose f (phiX t))) m_0))))
0)))))

(is (= 0 (simplify
(is (= 0 (simplify
(- result-via-Lie
((D (fn [t]
(- ((Y f) ((phiX t) m_0))
Expand Down Expand Up @@ -750,7 +750,7 @@
(is (= '(down
(+ (* -1 (cos (theta t)) (sin (theta t)) (expt ((D phi) t) 2))
(((expt D 2) theta) t))
(+ (* 2 (cos (theta t)) ((D theta) t) (sin (theta t)) ((D phi) t))
(+ (* 2 (cos (theta t)) (sin (theta t)) ((D phi) t) ((D theta) t))
(* (expt (sin (theta t)) 2) (((expt D 2) phi) t))))
(g/freeze
(simplify
Expand Down Expand Up @@ -813,7 +813,7 @@
;; So \Gammar_{\theta \theta} = -r, \Gamma\theta_{\theta \theta} = 0
;;
;; These are correct Christoffel symbols...
)))))
)))))

(defn CD
"Computation of Covariant derivatives by difference quotient. [[CD]] is parallel
Expand Down
40 changes: 20 additions & 20 deletions test/emmy/calculus/form_field_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -331,18 +331,18 @@
((((g/square ff/d) (m/literal-scalar-field 'f R3-rect)) X Y)
R3-cyl-point))))

(is (= '(+ (* (X↑0 p) (Y↑1 p) (((partial 0) w_1) p))
(* -1 (X↑0 p) (Y↑1 p) (((partial 1) w_0) p))
(* (X↑0 p) (Y↑2 p) (((partial 0) w_2) p))
(is (= '(+ (* (X↑0 p) (Y↑2 p) (((partial 0) w_2) p))
(* -1 (X↑0 p) (Y↑2 p) (((partial 2) w_0) p))
(* -1 (X↑1 p) (((partial 0) w_1) p) (Y↑0 p))
(* (X↑1 p) (((partial 1) w_0) p) (Y↑0 p))
(* (X↑0 p) (Y↑1 p) (((partial 0) w_1) p))
(* -1 (X↑0 p) (Y↑1 p) (((partial 1) w_0) p))
(* (X↑1 p) (Y↑2 p) (((partial 1) w_2) p))
(* -1 (X↑1 p) (Y↑2 p) (((partial 2) w_1) p))
(* -1 (X↑2 p) (Y↑1 p) (((partial 1) w_2) p))
(* (X↑2 p) (Y↑1 p) (((partial 2) w_1) p))
(* -1 (X↑1 p) (((partial 0) w_1) p) (Y↑0 p))
(* (X↑1 p) (((partial 1) w_0) p) (Y↑0 p))
(* -1 (X↑2 p) (((partial 0) w_2) p) (Y↑0 p))
(* (X↑2 p) (((partial 2) w_0) p) (Y↑0 p)))
(* (X↑2 p) (((partial 2) w_0) p) (Y↑0 p))
(* -1 (X↑2 p) (Y↑1 p) (((partial 1) w_2) p))
(* (X↑2 p) (Y↑1 p) (((partial 2) w_1) p)))
(-> (((ff/d w) X Y) R3-rect-point)
(simplify)
(x/substitute '(up x0 y0 z0) 'p))))
Expand All @@ -353,24 +353,24 @@
(ff/wedge dy dz))
(* (m/literal-scalar-field 'omega_2 R3-rect)
(ff/wedge dz dx)))]
(is (= '(+ (* (X↑0 p) (Y↑1 p) (Z↑2 p) (((partial 0) omega_1) p))
(* (X↑0 p) (Y↑1 p) (Z↑2 p) (((partial 1) omega_2) p))
(* (X↑0 p) (Y↑1 p) (Z↑2 p) (((partial 2) omega_0) p))
(* -1 (X↑0 p) (Y↑2 p) (((partial 0) omega_1) p) (Z↑1 p))
(* -1 (X↑0 p) (Y↑2 p) (((partial 1) omega_2) p) (Z↑1 p))
(* -1 (X↑0 p) (Y↑2 p) (((partial 2) omega_0) p) (Z↑1 p))
(is (= '(+ (* -1 (X↑0 p) (Y↑2 p) (Z↑1 p) (((partial 0) omega_1) p))
(* -1 (X↑0 p) (Y↑2 p) (Z↑1 p) (((partial 1) omega_2) p))
(* -1 (X↑0 p) (Y↑2 p) (Z↑1 p) (((partial 2) omega_0) p))
(* (X↑0 p) (Y↑1 p) (((partial 0) omega_1) p) (Z↑2 p))
(* (X↑0 p) (Y↑1 p) (((partial 1) omega_2) p) (Z↑2 p))
(* (X↑0 p) (Y↑1 p) (((partial 2) omega_0) p) (Z↑2 p))
(* (X↑1 p) (Y↑2 p) (((partial 0) omega_1) p) (Z↑0 p))
(* (X↑1 p) (Y↑2 p) (((partial 1) omega_2) p) (Z↑0 p))
(* (X↑1 p) (Y↑2 p) (((partial 2) omega_0) p) (Z↑0 p))
(* -1 (X↑1 p) (Y↑0 p) (Z↑2 p) (((partial 0) omega_1) p))
(* -1 (X↑1 p) (Y↑0 p) (Z↑2 p) (((partial 1) omega_2) p))
(* -1 (X↑1 p) (Y↑0 p) (Z↑2 p) (((partial 2) omega_0) p))
(* -1 (X↑1 p) (Y↑0 p) (((partial 0) omega_1) p) (Z↑2 p))
(* -1 (X↑1 p) (Y↑0 p) (((partial 1) omega_2) p) (Z↑2 p))
(* -1 (X↑1 p) (Y↑0 p) (((partial 2) omega_0) p) (Z↑2 p))
(* -1 (X↑2 p) (Y↑1 p) (((partial 0) omega_1) p) (Z↑0 p))
(* -1 (X↑2 p) (Y↑1 p) (((partial 1) omega_2) p) (Z↑0 p))
(* -1 (X↑2 p) (Y↑1 p) (((partial 2) omega_0) p) (Z↑0 p))
(* (X↑2 p) (Y↑0 p) (((partial 0) omega_1) p) (Z↑1 p))
(* (X↑2 p) (Y↑0 p) (((partial 1) omega_2) p) (Z↑1 p))
(* (X↑2 p) (Y↑0 p) (((partial 2) omega_0) p) (Z↑1 p)))
(* (X↑2 p) (Y↑0 p) (Z↑1 p) (((partial 0) omega_1) p))
(* (X↑2 p) (Y↑0 p) (Z↑1 p) (((partial 1) omega_2) p))
(* (X↑2 p) (Y↑0 p) (Z↑1 p) (((partial 2) omega_0) p)))
(-> (((ff/d omega) X Y Z) R3-rect-point)
(simplify)
(x/substitute '(up x0 y0 z0) 'p))))
Expand Down
20 changes: 10 additions & 10 deletions test/emmy/calculus/map_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ and the differentials of coordinate functions."
μ (m/literal-manifold-map 'μ R1-rect R2-rect)
f (man/literal-manifold-function 'f R2-rect)]

(is (= '(+ (* (((partial 0) f) (up (μ↑0 τ) (μ↑1 τ)))
((D μ↑0) τ))
(* (((partial 1) f) (up (μ↑0 τ) (μ↑1 τ)))
((D μ↑1) τ)))
(is (= '(+ (* (((partial 1) f) (up (μ↑0 τ) (μ↑1 τ)))
((D μ↑1) τ))
(* (((partial 0) f) (up (μ↑0 τ) (μ↑1 τ)))
((D μ↑0) τ)))
(simplify
((((m/differential μ) d:dt) f)
((man/point R1-rect) 'τ)))))
Expand All @@ -120,10 +120,10 @@ and the differentials of coordinate functions."

;; "However, if we kludge the correct argument it gives the expected
;; answer."
(is (= '(/ (+ (* ((D μ↑0) τ) (e1↑1 (up x0 y0)))
(* -1 ((D μ↑1) τ) (e1↑0 (up x0 y0))))
(+ (* (e1↑1 (up x0 y0)) (e0↑0 (up x0 y0)))
(* -1 (e1↑0 (up x0 y0)) (e0↑1 (up x0 y0)))))
(is (= '(/ (+ (* -1 ((D μ↑1) τ) (e1↑0 (up x0 y0)))
(* ((D μ↑0) τ) (e1↑1 (up x0 y0))))
(+ (* -1 (e1↑0 (up x0 y0)) (e0↑1 (up x0 y0)))
(* (e1↑1 (up x0 y0)) (e0↑0 (up x0 y0)))))
(simplify
(((nth edual 0)
(vf/procedure->vector-field
Expand All @@ -142,8 +142,8 @@ and the differentials of coordinate functions."
(man/chart R1-rect))
f (f/compose (af/literal-function 'f '(-> (UP Real Real) Real))
(man/chart S2-spherical))]
(is (= '(+ (* (((partial 0) f) (up (θ τ) (φ τ))) ((D θ) τ))
(* (((partial 1) f) (up (θ τ) (φ τ))) ((D φ) τ)))
(is (= '(+ (* (((partial 1) f) (up (θ τ) (φ τ))) ((D φ) τ))
(* (((partial 0) f) (up (θ τ) (φ τ))) ((D θ) τ)))
(simplify ((((m/differential μ) d:dt) f)
((man/point R1-rect) 'τ)))))

Expand Down
Loading

0 comments on commit cac1791

Please sign in to comment.