Skip to content

Commit

Permalink
dual vs differential works
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Apr 9, 2024
1 parent 1049fde commit 05f0773
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 28 deletions.
15 changes: 14 additions & 1 deletion src/emmy/abstract/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,24 @@
([acc] acc)
([acc [x path _]]
(let [dx (d/tangent-part x tag)]
(if (g/zero? dx)
(if (g/numeric-zero? dx)
acc
(let [partial (literal-partial f path)]
(d/d:+ acc (d/d:* (literal-apply partial primal-s)
dx))))))))

(defn- dual-forward-mode-fold [f primal-s tag]
(fn
([] 0)
([tangent] (tape/->Dual tag (apply f primal-s) tangent))

Check warning on line 269 in src/emmy/abstract/function.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/abstract/function.cljc#L267-L269

Added lines #L267 - L269 were not covered by tests
([tangent [x path _]]
(let [dx (tape/dual-tangent x tag)]
(if (g/numeric-zero? dx)
tangent
(let [partial (literal-partial f path)]
(g/+ tangent (g/* (literal-apply partial primal-s)
dx))))))))

Check warning on line 276 in src/emmy/abstract/function.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/abstract/function.cljc#L271-L276

Added lines #L271 - L276 were not covered by tests

(defn- reverse-mode-fold [f primal-s tag]
(fn
([] [])
Expand All @@ -282,6 +294,7 @@
derivatives for each differential argument in the input structure."
[f s tag dx]
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
(tape/dual? dx) dual-forward-mode-fold
(d/differential? dx) forward-mode-fold
:else (u/illegal "No tape or differential inputs."))
primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)]
Expand Down
9 changes: 7 additions & 2 deletions src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,14 @@
[f]
(fn [x]
(let [tag (d/fresh-tag)
lifted (d/bundle-element x 1 tag)]
lifted
(d/bundle-element x 1 tag)

;; TODO this is the only change we need to get this working. Also implement comparable.
#_(emmy.tape/->Dual tag x 1)]
(-> (d/with-active-tag tag f [lifted])
(d/extract-tangent tag)))))
(d/extract-tangent tag))
)))

;; The result of applying the derivative `(D f)` of a multivariable function `f`
;; to a sequence of `args` is a structure of the same shape as `args` with all
Expand Down
1 change: 1 addition & 0 deletions src/emmy/differential.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
;; derivatives of entire programs in a similar way by building them out of the
;; derivatives of the smaller pieces of those programs.
;;

;; ### Forward-Mode Automatic Differentiation
;;
;; For many scientific computing applications, it's valuable be able to generate
Expand Down
161 changes: 136 additions & 25 deletions src/emmy/tape.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,44 @@
;; allow [[TapeCell]] instances to work in any place that real numbers or
;; symbolic argument work, make `::tape` derive from `::v/scalar`:

(derive ::dual ::v/scalar)
(derive ::tape ::v/scalar)

;; Here's the [[TapeCell]] type with the fields described above.

(declare compare)

(deftype TapeCell [tag id primal dual in->partial]
(deftype Dual [tag primal tangent]
v/IKind
(kind [_] ::dual)

d/IPerturbed
;; NOTE the reason we need this is for the arguments to literal function.
;; Those need to tell if there is some tape coming in.
(perturbed? [_] true)
(replace-tag [_ old new]
(Dual. (if (= old tag) new tag)
(d/replace-tag primal old new)
(d/replace-tag tangent old new)))

Check warning on line 142 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L140-L142

Added lines #L140 - L142 were not covered by tests
(extract-tangent [_ t]
(if (= t tag)
tangent

Check warning on line 145 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L144-L145

Added lines #L144 - L145 were not covered by tests
0))

Object
;; TODO revisit all of this
#?(:cljs (valueOf [_] (.valueOf primal)))
(toString [_]
(str "#emmy.tape.Dual"
{:tag tag
:primal primal
:tangent tangent}))

Check warning on line 155 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L152-L155

Added lines #L152 - L155 were not covered by tests

;; TODO add comparable block from below
)


(deftype TapeCell [tag id primal in->partial]
v/IKind
(kind [_] ::tape)

Expand All @@ -147,15 +178,13 @@
(TapeCell. (if (= old tag) new tag)
id
(d/replace-tag primal old new)
(d/replace-tag dual old new)
(mapv (fn [[node partial]]
[(d/replace-tag node old new)
(d/replace-tag partial old new)])
in->partial)))

;; This implementation is called if a tape ever makes it out of
;; forward-mode-differentiated function. TODO if we like the dual number thing
;; this should return `dual`.
;; forward-mode-differentiated function.
(extract-tangent [_ _] 0)

Object
Expand Down Expand Up @@ -207,6 +236,12 @@

;; ## Non-generic API

(defn dual?
"Returns true if the supplied object is an instance of [[Dual]], false
otherwise."
[x]
(instance? Dual x))

(defn tape?
"Returns true if the supplied object is an instance of [[TapeCell]], false
otherwise."
Expand All @@ -230,11 +265,9 @@
where `<partial>` is the partial derivative of the output with respect to each
input (defaults to `[]`)."
([tag primal]
(->TapeCell tag (fresh-id) primal 1 []))
(->TapeCell tag (fresh-id) primal []))
([tag primal partials]
(->TapeCell tag (fresh-id) primal 1 partials))
([tag primal dual partials]
(->TapeCell tag (fresh-id) primal dual partials)))
(->TapeCell tag (fresh-id) primal partials)))

;; TODO making [[tapify]] extensible is the key to differentiating things like
;; quaternion-valued functions. Forward-mode handles this differently, since we
Expand Down Expand Up @@ -264,6 +297,10 @@
[^TapeCell tape]
(.-tag tape))

(defn dual-tag
[^Dual dual]
(.-tag dual))

Check warning on line 302 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L302

Added line #L302 was not covered by tests

(defn tape-primal
"Given a [[TapeCell]], returns the `primal` field of the supplied [[TapeCell]]
object. For all other types, acts as identity.
Expand All @@ -279,6 +316,26 @@
(.-primal ^TapeCell x)
x)))

(defn dual-primal
([x]
(if (dual? x)
(.-primal ^Dual x)
x))

Check warning on line 323 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L321-L323

Added lines #L321 - L323 were not covered by tests
([x tag]
(if (and (dual? x) (= tag (dual-tag x)))
(.-primal ^Dual x)
x)))

Check warning on line 327 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L325-L327

Added lines #L325 - L327 were not covered by tests

(defn dual-tangent
([x]
(if (dual? x)
(.-tangent ^Dual x)

Check warning on line 332 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L331-L332

Added lines #L331 - L332 were not covered by tests
0))
([x tag]
(if (and (dual? x) (= tag (dual-tag x)))
(.-tangent ^Dual x)

Check warning on line 336 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L335-L336

Added lines #L335 - L336 were not covered by tests
0)))

(defn tape-id
"Returns the `-id` field of the supplied [[TapeCell]] object. Errors if any
other type is supplied.
Expand Down Expand Up @@ -314,7 +371,6 @@
{:tag (.-tag t)
:id (.-id t)
:primal (.-primal t)
:dual (.-dual t)
:in->partial (.-in->partial t)})

;; More permissive accessors...
Expand All @@ -324,6 +380,7 @@
non-[[TapeCell]] instance."
[x]
(cond (tape? x) (tape-tag x)
(dual? x) (dual-tag x)
(d/differential? x) (d/max-order-tag x)
:else nil))

Expand All @@ -334,6 +391,7 @@
(primal-of v (tag-of v)))
([v tag]
(cond (tape? v) (tape-primal v tag)
(dual? v) (dual-primal v tag)
(d/differential? v) (d/primal-part v tag)
:else v)))

Expand All @@ -344,6 +402,7 @@
Given a non-[[TapeCell]], acts as identity."
([v]
(cond (tape? v) (recur (tape-primal v))
(dual? v) (recur (dual-primal v))
(d/differential? v) (recur (d/primal-part v))
:else v)))

Expand Down Expand Up @@ -417,6 +476,7 @@
false)))

(defn compare
;; TODO bad docstring
"Comparator that compares [[Differential]] instances with each other or
non-differentials using only the [[finite-part]] of each instance. Matches the
response of [[equiv]].
Expand Down Expand Up @@ -741,6 +801,20 @@
selectors)
(matrix/seq-> (cons x more)))))))

(defn fwd
([f]
(fn
([] 0)

Check warning on line 807 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L806-L807

Added lines #L806 - L807 were not covered by tests
([x]
(let [tag (d/fresh-tag)
inputs (->Dual tag x 1)
output (d/with-active-tag tag f [inputs])]
(d/extract-tangent output tag)))

Check warning on line 812 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L809-L812

Added lines #L809 - L812 were not covered by tests
([x & more]
((fwd (fn [args]
(apply f args)))
(matrix/seq-> (cons x more)))))))

Check warning on line 816 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L814-L816

Added lines #L814 - L816 were not covered by tests

;; starting to work on the one that returns a pair of primal and fn.

#_
Expand Down Expand Up @@ -831,15 +905,22 @@
(fn call [x]
(cond (tape? x)
(let [primal (tape-primal x)
dual (.-dual ^TapeCell x)
partial (df:dx primal)]
(make (tape-tag x)
(call primal)
(if (g/numeric-zero? dual)
dual
(g/* partial dual))
[[x partial]]))

(dual? x)
(let [px (dual-primal x)
tx (dual-tangent x)
partial (df:dx px)]
(->Dual (dual-tag x)
(call px)
(if (g/numeric-zero? tx)
tx
(g/* partial tx))))

Check warning on line 921 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L914-L921

Added lines #L914 - L921 were not covered by tests


(d/differential? x)
(let [[px tx] (d/primal-tangent-pair x)
fx (call px)]
Expand Down Expand Up @@ -884,28 +965,35 @@
b
(d/d:+* b (df:dy xe ye) dy))))

(operate-dual [tag]
(let [xe (dual-primal x tag)
ye (dual-primal y tag)
dx (dual-tangent x tag)
dy (dual-tangent y tag)]
(->Dual tag
(call xe ye)
(g/+ (if (g/numeric-zero? dx)
dx
(g/* (df:dx xe ye) dx))
(if (g/numeric-zero? dy)
dy
(g/* (df:dy xe ye) dy))))))

Check warning on line 980 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L969-L980

Added lines #L969 - L980 were not covered by tests

(operate-reverse [tag]
(let [primal-x (tape-primal x tag)
primal-y (tape-primal y tag)
dx (delay (df:dx primal-x primal-y))
dy (delay (df:dy primal-x primal-y))
partial-x (if (and (tape? x) (= tag (tape-tag x)))
[[x @dx]]
[[x (df:dx primal-x primal-y)]]
[])
partial-y (if (and (tape? y) (= tag (tape-tag y)))
[[y @dy]]
[[y (df:dy primal-x primal-y)]]
[])]
(make tag
(call primal-x primal-y)
(g/+ (if (and (tape? x) (= tag (tape-tag x)))
(g/* @dx (.-dual ^TapeCell x))
0)
(if (and (tape? y) (= tag (tape-tag y)))
(g/* @dy (.-dual ^TapeCell y))
0))
(into partial-x partial-y))))]
(if-let [[tag dx] (tag+perturbation x y)]
(cond (tape? dx) (operate-reverse tag)
(dual? dx) (operate-dual tag)
(d/differential? dx) (operate-forward tag)
:else
(u/illegal "Non-tape or differential perturbation!"))
Expand Down Expand Up @@ -958,7 +1046,8 @@
([generic-op]
(defunary generic-op (lift-1 generic-op)))
([generic-op differential-op]
(defmethod generic-op [::tape] [a] (differential-op a))))
(defmethod generic-op [::tape] [a] (differential-op a))
(defmethod generic-op [::dual] [a] (differential-op a))))

(defn- defbinary
"Given:
Expand All @@ -977,6 +1066,13 @@
[::d/differential ::tape]
[::tape ::v/scalar]
[::v/scalar ::tape]


[::dual ::dual]
[::tape ::dual]
[::dual ::tape]
[::dual ::v/scalar]
[::v/scalar ::dual]
#_#_#_
[::d/differential ::d/differential]
[::d/differential ::v/scalar]
Expand Down Expand Up @@ -1108,8 +1204,23 @@
(TapeCell. (.-tag t)
(.-id t)
(g/simplify (.-primal t))
(.-dual t)
(mapv (fn [[node partial]]
[(g/simplify node)
(g/simplify partial)])
(.-in->partial t))))

;; DUAL:

(defmethod g/zero-like [::dual] [_] 0)
(defmethod g/one-like [::dual] [_] 1)
(defmethod g/identity-like [::dual] [_] 1)
(defmethod g/freeze [::dual] [t]
`[~'Dual
~(dual-tag t)
~(g/freeze (dual-primal t))
~(g/freeze (dual-tangent t))])

Check warning on line 1221 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L1219-L1221

Added lines #L1219 - L1221 were not covered by tests

(defmethod g/simplify [::dual] [^Dual t]
(Dual. (.-tag t)
(g/simplify (.-primal t))
(g/simplify (.-tangent t))))

Check warning on line 1226 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L1224-L1226

Added lines #L1224 - L1226 were not covered by tests

0 comments on commit 05f0773

Please sign in to comment.