Skip to content

Commit

Permalink
delete differential
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Apr 9, 2024
1 parent d3a16ba commit 24a1466
Show file tree
Hide file tree
Showing 22 changed files with 796 additions and 2,819 deletions.
2 changes: 1 addition & 1 deletion dev/emmy/notebook.clj
Original file line number Diff line number Diff line change
Expand Up @@ -425,14 +425,14 @@ clojure -Sdeps '{:deps {io.github.mentat-collective/emmy {:git/sha \"%s\"}}}' \\
;; - [emmy.special.factorial](src/emmy/special/factorial.html)
;; - [emmy.sr.boost](src/emmy/sr/boost.html)
;; - [emmy.sr.frames](src/emmy/sr/frames.html)
;; - [emmy.tape](src/emmy/tape.html)
;; - [emmy.util](src/emmy/util.html)
;; - [emmy.util.aggregate](src/emmy/util/aggregate.html)
;; - [emmy.util.def](src/emmy/util/def.html)
;; - [emmy.util.logic](src/emmy/util/logic.html)
;; - [emmy.util.permute](src/emmy/util/permute.html)
;; - [emmy.util.stopwatch](src/emmy/util/stopwatch.html)
;; - [emmy.util.stream](src/emmy/util/stream.html)
;; - [emmy.util.vector-set](src/emmy/util/vector_set.html)

;; ## Who is using Emmy?

Expand Down
25 changes: 12 additions & 13 deletions src/emmy/abstract/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
(:refer-clojure :exclude [name])
(:require #?(:clj [clojure.pprint :as pprint])
[emmy.abstract.number :as an]
[emmy.differential :as d]
[emmy.function :as f]
[emmy.generic :as g]
[emmy.matrix :as m]
[emmy.numsymb :as sym]
[emmy.polynomial]
[emmy.structure :as s]
[emmy.tape :as tape]
[emmy.tape :as ad]
[emmy.util :as u]
[emmy.value :as v])
#?(:clj
Expand Down Expand Up @@ -251,12 +250,13 @@
(->Function
fexp (f/arity f) (domain-types f) (range-type f))))

#_
(defn- forward-mode-fold [f primal-s tag]
(fn
([] (apply f primal-s))
([acc] acc)
([acc [x path _]]
(let [dx (d/tangent-part x tag)]
(let [dx (ad/tangent-part x tag)]
(if (g/numeric-zero? dx)
acc
(let [partial (literal-partial f path)]
Expand All @@ -266,9 +266,9 @@
(defn- dual-forward-mode-fold [f primal-s tag]
(fn
([] 0)
([tangent] (tape/->Dual tag (apply f primal-s) tangent))
([tangent] (ad/->Dual tag (apply f primal-s) tangent))
([tangent [x path _]]
(let [dx (tape/dual-tangent x tag)]
(let [dx (ad/dual-tangent x tag)]
(if (g/numeric-zero? dx)
tangent
(let [partial (literal-partial f path)]
Expand All @@ -279,9 +279,9 @@
(fn
([] [])
([partials]
(tape/make tag (apply f primal-s) partials))
(ad/make tag (apply f primal-s) partials))
([partials [entry path _]]
(if (and (tape/tape? entry) (= tag (tape/tape-tag entry)))
(if (and (ad/tape? entry) (= tag (ad/tape-tag entry)))
(let [partial (literal-partial f path)]
(conj partials [entry (literal-apply partial primal-s)]))
partials))))
Expand All @@ -293,11 +293,10 @@
expanded `((D f) xs)` by applying the chain rule and summing the partial
derivatives for each differential argument in the input structure."
[f s tag dx]
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
(tape/dual? dx) dual-forward-mode-fold
(d/differential? dx) forward-mode-fold
:else (u/illegal "No tape or differential inputs."))
primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)]
(let [fold-fn (cond (ad/tape? dx) reverse-mode-fold
(ad/dual? dx) dual-forward-mode-fold
:else (u/illegal "No tape or differential inputs."))
primal-s (s/mapr (fn [x] (ad/primal-of x tag)) s)]
(s/fold-chain (fold-fn f primal-s tag) s)))

(defn- check-argument-type
Expand Down Expand Up @@ -332,7 +331,7 @@
(if-let [[tag dx] (s/fold-chain
(fn
([] [])
([acc] (apply tape/tag+perturbation acc))
([acc] (apply ad/tag+perturbation acc))
([acc [d]] (conj acc d)))
s)]
(literal-derivative f s tag dx)
Expand Down
142 changes: 88 additions & 54 deletions src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,30 @@
"This namespace implements a number of differential operators like [[D]], and
the machinery to apply [[D]] to various structures."
(:refer-clojure :exclude [partial])
(:require [emmy.differential :as d]
[emmy.expression :as x]
(:require [emmy.expression :as x]
[emmy.function :as f]
[emmy.generic :as g]
[emmy.matrix :as matrix]
[emmy.operator :as o]
[emmy.series :as series]
[emmy.structure :as s]
[emmy.tape :as tape]
[emmy.tape :as ad]
[emmy.util :as u]
[emmy.value :as v])
#?(:clj
(:import (clojure.lang Fn MultiFn))))

;; ## IPerturbed Implementation for Functions
;;
;; The following section, along with [[emmy.collection]]
;; and [[emmy.differential]], rounds out the implementations
;; of [[d/IPerturbed]] for native Clojure(script) data types. The function
;; implementation is subtle, as described by [Manzyuk et al.
;; 2019](https://arxiv.org/pdf/1211.4892.pdf).
;; The following section, along with [[emmy.collection]] and [[emmy.tape]],
;; rounds out the implementations of [[emmy.tape/IPerturbed]] for native
;; Clojure(script) data types. The function implementation is subtle, as
;; described by [Manzyuk et al. 2019](https://arxiv.org/pdf/1211.4892.pdf).
;; ([[emmy.calculus.derivative-test]], in the "Amazing Bug" sections,
;; describes the pitfalls at length.)
;;
;; [[emmy.differential]] describes how each in-progress perturbed variable
;; in a derivative is assigned a "tag" that accumulates the variable's partial
;; [[emmy.tape]] describes how each in-progress perturbed variable in a
;; derivative is assigned a "tag" that accumulates the variable's partial
;; derivative.
;;
;; How do we interpret the case where `((D f) x)` produces a _function_?
Expand Down Expand Up @@ -70,12 +68,12 @@
;; `extract-tangent` operation with the returned function.
;;
;; The returned function needs to capture an internal reference to the
;; original [[d/Differential]] input. This is true for any Functor-shaped return
;; original [[emmy.tape/Dual]] input. This is true for any Functor-shaped return
;; value, like a structure or Map. However! There is a subtlety present with
;; functions that's not present with vectors or other containers.
;;
;; The difference with functions is that they take _inputs_. If you contrive a
;; situation where you can feed the original captured [[d/Differential]] into
;; situation where you can feed the original captured [[emmy.tape/Dual]] into
;; the returned function, this can trigger "perturbation confusion", where two
;; different layers try to extract the tangent corresponding to the SAME tag,
;; and one is left with nothing.
Expand Down Expand Up @@ -115,8 +113,8 @@
;;
;; - it extracts the originally-injected tag when someone eventually calls the
;; function
;; - if some caller passes a new [[d/Differential]] instance into the function,
;; any tags in that [[d/Differential]] will survive on their way back out...
;; - if some caller passes a new [[emmy.tape/Dual]] instance into the function,
;; any tags in that [[emmy.tape/Dual]] will survive on their way back out...
;; even if they happen to contain the originally-injected tag.
;;
;; We do this by:
Expand All @@ -127,7 +125,7 @@
;; `tag`, as requested (note now that the only instances of `tag` that can
;; appear in the result come from variables captured in the function's
;; closure)
;; - remapping `fresh` back to `tag` inside the remaining [[d/Differential]]
;; - remapping `fresh` back to `tag` inside the remaining [[emmy.tape/Dual]]
;; instance.
;;
;; This last step ensures that any tangent tagged with `tag` in the input can
Expand All @@ -151,13 +149,13 @@
occur."
[f tag]
(-> (fn [& args]
(if (d/tag-active? tag)
(let [fresh (d/fresh-tag)]
(-> (d/with-active-tag tag f (map #(d/replace-tag % tag fresh) args))
(d/extract-tangent tag)
(d/replace-tag fresh tag)))
(-> (d/with-active-tag tag f args)
(d/extract-tangent tag))))
(if (ad/tag-active? tag)
(let [fresh (ad/fresh-tag)]
(-> (ad/with-active-tag tag f (map #(ad/replace-tag % tag fresh) args))
(ad/extract-tangent tag)
(ad/replace-tag fresh tag)))
(-> (ad/with-active-tag tag f args)
(ad/extract-tangent tag))))
(f/with-arity (f/arity f))))

;; NOTE: that the tag-remapping that the docstring for `extract-tag-fn`
Expand All @@ -184,14 +182,14 @@
no tag-rerouting."
[f old new]
(-> (fn [& args]
(if (d/tag-active? old)
(let [fresh (d/fresh-tag)
args (map #(d/replace-tag % old fresh) args)]
(if (ad/tag-active? old)
(let [fresh (ad/fresh-tag)
args (map #(ad/replace-tag % old fresh) args)]
(-> (apply f args)
(d/replace-tag old new)
(d/replace-tag fresh old)))
(ad/replace-tag old new)
(ad/replace-tag fresh old)))
(-> (apply f args)
(d/replace-tag old new))))
(ad/replace-tag old new))))
(f/with-arity (f/arity f))))

;; ## Protocol Implementation
Expand All @@ -200,8 +198,7 @@
;; ClojureScript, [[MetaFn]] instances. Metadata in the original function is
;; preserved through tag replacement and extraction.

(extend-protocol d/IPerturbed

(extend-protocol ad/IPerturbed
MultiFn
(perturbed? [_] false)
(replace-tag [f old new] (replace-tag-fn f old new))
Expand Down Expand Up @@ -235,9 +232,53 @@

;; ## Single and Multivariable Calculus
;;
;; These functions put together the pieces laid out
;; in [[emmy.differential]] and declare an interface for taking
;; derivatives.
;; These functions put together the pieces laid out in [[emmy.tape]] and declare
;; an interface for taking derivatives.

;; TODO document and file.

(defn gradient
  "Given some differentiable function `f`, returns a function whose value at some
  point can multiply an increment in the arguments to produce the best linear
  estimate of the increment in the function value.

  For univariate functions, [[gradient]] computes a derivative. For
  vector-valued functions, [[gradient]] computes
  the [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
  of `f`.

  `selectors` (default `[]`) is a path into a structural input; when it is
  non-empty, only the entry at that path is perturbed and the result is taken
  with respect to that entry alone. Supplying selectors together with a
  non-structural input is an error.

  The returned function has three arities:

  - zero arguments: returns `0`
  - one argument `x`: returns the gradient of `f` at `x`
  - many arguments: the arguments are packed into a single structure
    with [[emmy.matrix/seq->]], and `f` is applied to the unpacked entries

  For numerical differentiation, see [[emmy.numerical.derivative/D-numeric]].

  NOTE: `f` must be built out of generic operations that know how to
  handle [[emmy.tape/TapeCell]] inputs in addition to any types that a
  normal `(f x)` call would present. This restriction does _not_ apply to
  operations like putting `x` into a container or destructuring; just primitive
  function calls."
  ([f] (gradient f []))
  ([f selectors]
   (fn
     ([] 0)
     ([x]
      (let [whole-input? (empty? selectors)]
        ;; A selector path only makes sense when there is a structure to
        ;; descend into.
        (when-not (or whole-input? (s/structure? x))
          (u/illegal
           (str "Selectors " selectors
                " not allowed for non-structural input " x)))
        (let [tag      (ad/fresh-tag)
              ;; Either tapify the whole input, or only the selected entry.
              tapified (if whole-input?
                         (ad/tapify x tag)
                         (update-in x selectors ad/tapify tag))
              result   (ad/with-active-tag tag f [tapified])
              ;; TODO there is an implicit sensitivity here for each run.
              partials (ad/->partials result tag)
              ;; Interpret against exactly the portion of the input that was
              ;; tapified above.
              target   (if whole-input?
                         tapified
                         (get-in tapified selectors))]
          (ad/interpret target partials tag))))
     ([x & more]
      ;; Multiple arguments: differentiate a function of one packed structure.
      (let [packed-f (fn [args] (apply f args))]
        ((gradient packed-f selectors)
         (matrix/seq-> (cons x more))))))))

(defn derivative
"Returns a single-argument function that, when called with an argument `x`,
Expand All @@ -248,21 +289,15 @@
see [[emmy.numerical.derivative/D-numeric]].
`f` must be built out of generic operations that know how to
handle [[emmy.differential/Differential]] inputs in addition to any types that
a normal `(f x)` call would present. This restriction does _not_ apply to
operations like putting `x` into a container or destructuring; just primitive
function calls."
handle [[emmy.tape/Dual]] inputs in addition to any types that a normal `(f
x)` call would present. This restriction does _not_ apply to operations like
putting `x` into a container or destructuring; just primitive function calls."
[f]
(fn [x]
(let [tag (d/fresh-tag)
lifted
;; TODO if we want to do this, we can kill differentials and move the
;; tests over...
#_(d/bundle-element x 1 tag)
(emmy.tape/->Dual tag x 1)]
(-> (d/with-active-tag tag f [lifted])
(d/extract-tangent tag))
)))
(let [tag (ad/fresh-tag)
lifted (ad/->Dual tag x 1)]
(-> (ad/with-active-tag tag f [lifted])
(ad/extract-tangent tag)))))

;; The result of applying the derivative `(D f)` of a multivariable function `f`
;; to a sequence of `args` is a structure of the same shape as `args` with all
Expand Down Expand Up @@ -437,11 +472,11 @@

;; Install `g/partial-derivative` implementations for each dispatch key in the
;; vector below (`::v/function` and `::s/structure`).
(doseq [t [::v/function ::s/structure]]
;; Selectors supplied as a `v/seqtype`: forwarded unchanged to `multivariate`.
(defmethod g/partial-derivative [t v/seqtype] [f selectors]
;; The `#_` reader macro discards the form that follows it, so the `gradient`
;; calls below are never evaluated; only the `multivariate` call runs.
#_(tape/gradient f selectors)
#_(ad/gradient f selectors)
(multivariate f selectors))

;; `nil` selectors: the second argument is ignored and an empty selector
;; vector is passed to `multivariate` instead.
(defmethod g/partial-derivative [t nil] [f _]
#_(tape/gradient f [])
#_(ad/gradient f [])
(multivariate f [])))

;; ## Operators
Expand Down Expand Up @@ -548,13 +583,12 @@
(letfn [(process-term [term]
(g/simplify
(s/mapr (fn rec [x]
(cond (d/differential? x) (d/map-coefficients rec x)
(tape/dual? x)
(tape/->Dual (tape/dual-tag x)
(rec (tape/dual-primal x))
(rec (tape/dual-tangent x)))
(cond (ad/dual? x)
(ad/->Dual (ad/dual-tag x)
(rec (ad/dual-primal x))
(rec (ad/dual-tangent x)))

(tape/tape? x)
(ad/tape? x)
(u/illegal "TODO implement this using fmap style.")

:else (-> (g/simplify x)
Expand Down
Loading

0 comments on commit 24a1466

Please sign in to comment.