Skip to content

Commit

Permalink
simplify, defunary works for differential now
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Apr 9, 2024
1 parent 62587b8 commit 3b8d3fc
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 47 deletions.
101 changes: 55 additions & 46 deletions src/emmy/tape.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -813,29 +813,38 @@
- a function `df:dx` that computes the derivative of `f` with respect to its
single argument
Returns a new unary function that operates on both the original type of `f`
and [[TapeCell]] instances.
Returns a new unary function that operates on both the original type of
`f`, [[TapeCell]] and [[emmy.differential/Differential]] instances.
If called without `df:dx`, `df:dx` defaults to `(f :dfdx)`; this will return
the derivative registered to a generic function defined
with [[emmy.util.def/defgeneric]].
NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]] instances. The
best way to accomplish this is by building `df:dx` out of already-lifted
functions, and declaring them by forward reference if you need to."
NOTE: `df:dx` has to ALREADY be able to handle [[TapeCell]]
and [[emmy.differential/Differential]] instances. The best way to accomplish
this is by building `df:dx` out of already-lifted functions, and declaring
them by forward reference if you need to."
([f]
(if-let [df:dx (f :dfdx)]
(lift-1 f df:dx)
(u/illegal
"No df:dx supplied for `f` or registered generically.")))
([f df:dx]
(fn call [x]
(if (tape? x)
(let [primal (tape-primal x)]
(make (tape-tag x)
(call primal)
[[x (df:dx primal)]]))
(f x)))))
(cond (tape? x)
(let [primal (tape-primal x)]
(make (tape-tag x)
(call primal)
[[x (df:dx primal)]]))

(d/differential? x)
(let [[px tx] (d/primal-tangent-pair x)
fx (call px)]
(if (g/numeric-zero? tx)
fx

Check warning on line 844 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L844

Added line #L844 was not covered by tests
(d/d:+* fx (df:dx px) tx)))

:else (f x)))))

(defn lift-2
"Given:
Expand All @@ -845,13 +854,13 @@
single argument
- a function `df:dy`, similar to `df:dx` for the second arg
Returns a new binary function that operates on both the original type of `f`
and [[TapeCell]] instances.
Returns a new binary function that operates on both the original type of
`f`, [[TapeCell]] and [[emmy.differential/Differential]] instances.
NOTE: `df:dx` and `df:dy` have to ALREADY be able to handle [[TapeCell]]
instances. The best way to accomplish this is by building `df:dx` and `df:dy`
out of already-lifted functions, and declaring them by forward reference if
you need to."
and [[emmy.differential/Differential]] instances. The best way to accomplish
this is by building `df:dx` and `df:dy` out of already-lifted functions, and
declaring them by forward reference if you need to."
([f]
(let [df:dx (f :dfdx)
df:dy (f :dfdy)]
Expand Down Expand Up @@ -920,8 +929,12 @@

;; ## Generic Method Installation
;;
;; Armed with [[lift-1]] and [[lift-2]], we can install [[TapeCell]] into
;; the Emmy generic arithmetic system.
;; Armed with [[lift-1]] and [[lift-2]], we can install [[TapeCell]]
;; and [[emmy.differential/Differential]] into the Emmy generic arithmetic
;; system.
;;
;; Any function built out of these components will work with
;; the [[emmy.calculus.derivative/D]] operator.

(defn- defunary
"Given:
Expand All @@ -930,8 +943,8 @@
- optionally, a corresponding single-arity lifted function
`differential-op` (defaults to `(lift-1 generic-op)`)
installs an appropriate unary implementation of `generic-op` for `::tape`
instances."
installs an appropriate unary implementation of `generic-op` for
perturbations."
([generic-op]
(defunary generic-op (lift-1 generic-op)))
([generic-op differential-op]
Expand All @@ -944,27 +957,29 @@
- optionally, a corresponding 2-arity lifted function
`differential-op` (defaults to `(lift-2 generic-op)`)
installs an appropriate binary implementation of `generic-op` between `:tape`
and `::v/scalar` instances."
installs an appropriate binary implementation of `generic-op` between
perturbations and `::v/scalar` instances."
([generic-op]
(defbinary generic-op (lift-2 generic-op)))
([generic-op differential-op]
(doseq [signature [[::tape ::tape]
[::v/scalar ::tape]
[::tape ::v/scalar]

;; TODO does nested work if we don't have these overrides?
[::tape ::d/differential]
[::d/differential ::tape]]]
[::d/differential ::tape]
[::tape ::v/scalar]
[::v/scalar ::tape]
#_#_#_
[::d/differential ::d/differential]
[::d/differential ::v/scalar]
[::v/scalar ::d/differential]]]
(defmethod generic-op signature [a b] (differential-op a b)))))

(defn ^:no-doc by-primal
"Given some unary or binary function `f`, returns an augmented `f` that acts on
the primal entries of any [[TapeCell]] arguments encounted, irrespective of
tag.
Given a [[TapeCell]] with a [[TapeCell]] in its [[primal-part]], the returned
`f` will recursively descend until it hits a non-[[TapeCell]]."
Given a perturbation with a perturbation in its [[primal-part]], the returned
`f` will recursively descend until it hits a non-perturbation."
[f]
(fn
([x] (f (deep-primal x)))
Expand All @@ -988,9 +1003,9 @@

(defunary g/abs
(fn [x]
(let [f (deep-primal x)
func (cond (< f 0) (lift-1 (fn [x] x) (fn [_] -1))
(> f 0) (lift-1 (fn [x] x) (fn [_] 1))
(let [f (deep-primal x)
func (cond (< f 0) (lift-1 g/negate (fn [_] -1))
(> f 0) (lift-1 identity (fn [_] 1))
(= f 0) (u/illegal "Derivative of g/abs undefined at zero")
:else (u/illegal (str "error! derivative of g/abs at" x)))]
(func x))))
Expand Down Expand Up @@ -1079,17 +1094,11 @@
[(g/freeze node) (g/freeze partial)])
(tape-partials t))])

;; `simplify` explicitly does nothing because we don't want to lose our identity
;; for the topological sort. TODO this can change now... add this impl

(defmethod g/simplify [::tape] [t] t)

#_(defn foo [c d f]
(let [b (g/* c d)
e (g/+ b f)
a (g/* e b)]
a))

#_(g/simplify
(g/- ((emmy.env/D foo) 'a 'b 'c 'd)
((gradient foo) 'c 'd 'f)))
(defmethod g/simplify [::tape] [^TapeCell t]
(TapeCell. (.-tag t)
(.-id t)
(g/simplify (.-primal t))
(mapv (fn [[node partial]]
[(g/simplify node)
(g/simplify partial)])
(.-in->partial t))))
3 changes: 2 additions & 1 deletion test/emmy/differential_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
[emmy.generic :as g]
[emmy.numerical.derivative :refer [D-numeric]]
[emmy.simplify :refer [hermetic-simplify-fixture]]
[emmy.tape :as tape]
[emmy.value :as v]
[same.core :refer [ish? with-comparator]]))

Expand Down Expand Up @@ -557,7 +558,7 @@
(is (g/one? ((derivative g/fractional-part) x)))))))

(testing "lift-n"
(let [* (d/lift-n g/* (fn [_] 1) (fn [_ y] y) (fn [x _] x))
(let [* (tape/lift-n g/* (fn [_] 1) (fn [_ y] y) (fn [x _] x))
Df7 (derivative
(fn x**7 [x] (* x x x x x x x)))
Df1 (derivative *)
Expand Down
11 changes: 11 additions & 0 deletions test/emmy/tape_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,14 @@
(g/simplify
((t/gradient (D f)) 'a 'b 'c 'd 'e 'f)))
"reverse-over-forward"))))


#_(defn foo [c d f]
(let [b (g/* c d)
e (g/+ b f)
a (g/* e b)]
a))

#_(g/simplify
(g/- ((emmy.env/D foo) 'a 'b 'c 'd)
((gradient foo) 'c 'd 'f)))

0 comments on commit 3b8d3fc

Please sign in to comment.