diff --git a/src/emmy/abstract/function.cljc b/src/emmy/abstract/function.cljc index cd87ea1f..3faeed46 100644 --- a/src/emmy/abstract/function.cljc +++ b/src/emmy/abstract/function.cljc @@ -257,12 +257,24 @@ ([acc] acc) ([acc [x path _]] (let [dx (d/tangent-part x tag)] - (if (g/zero? dx) + (if (g/numeric-zero? dx) acc (let [partial (literal-partial f path)] (d/d:+ acc (d/d:* (literal-apply partial primal-s) dx)))))))) +(defn- dual-forward-mode-fold [f primal-s tag] + (fn + ([] 0) + ([tangent] (tape/->Dual tag (apply f primal-s) tangent)) + ([tangent [x path _]] + (let [dx (tape/dual-tangent x tag)] + (if (g/numeric-zero? dx) + tangent + (let [partial (literal-partial f path)] + (g/+ tangent (g/* (literal-apply partial primal-s) + dx)))))))) + (defn- reverse-mode-fold [f primal-s tag] (fn ([] []) @@ -282,6 +294,7 @@ derivatives for each differential argument in the input structure." [f s tag dx] (let [fold-fn (cond (tape/tape? dx) reverse-mode-fold + (tape/dual? dx) dual-forward-mode-fold (d/differential? dx) forward-mode-fold :else (u/illegal "No tape or differential inputs.")) primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)] diff --git a/src/emmy/calculus/derivative.cljc b/src/emmy/calculus/derivative.cljc index 8d7f957d..f3a507dc 100644 --- a/src/emmy/calculus/derivative.cljc +++ b/src/emmy/calculus/derivative.cljc @@ -255,9 +255,14 @@ [f] (fn [x] (let [tag (d/fresh-tag) - lifted (d/bundle-element x 1 tag)] + lifted + (d/bundle-element x 1 tag) + + ;; TODO this is the only change we need to get this working. Also implement comparable. + #_(emmy.tape/->Dual tag x 1)] (-> (d/with-active-tag tag f [lifted]) - (d/extract-tangent tag))))) + (d/extract-tangent tag)) + ))) ;; The result of applying the derivative `(D f)` of a multivariable function `f` ;; to a sequence of `args` is a structure of the same shape as `args` with all diff --git a/src/emmy/differential.cljc b/src/emmy/differential.cljc index 809f0650..e4e1988d 100644 --- a/src/emmy/differential.cljc +++ b/src/emmy/differential.cljc @@ -34,6 +34,7 @@ ;; derivatives of entire programs in a similar way by building them out of the ;; derivatives of the smaller pieces of those programs. ;; + ;; ### Forward-Mode Automatic Differentiation ;; ;; For many scientific computing applications, it's valuable be able to generate diff --git a/src/emmy/tape.cljc b/src/emmy/tape.cljc index e1a0a1d7..cf077481 100644 --- a/src/emmy/tape.cljc +++ b/src/emmy/tape.cljc @@ -121,13 +121,44 @@ ;; allow [[TapeCell]] instances to work in any place that real numbers or ;; symbolic argument work, make `::tape` derive from `::v/scalar`: +(derive ::dual ::v/scalar) (derive ::tape ::v/scalar) ;; Here's the [[TapeCell]] type with the fields described above. (declare compare) -(deftype TapeCell [tag id primal dual in->partial] +(deftype Dual [tag primal tangent] + v/IKind + (kind [_] ::dual) + + d/IPerturbed + ;; NOTE the reason we need this is for the arguments to literal function. + ;; Those need to tell if there is some tape coming in. + (perturbed? [_] true) + (replace-tag [_ old new] + (Dual. (if (= old tag) new tag) + (d/replace-tag primal old new) + (d/replace-tag tangent old new))) + (extract-tangent [_ t] + (if (= t tag) + tangent + 0)) + + Object + ;; TODO revisit all of this + #?(:cljs (valueOf [_] (.valueOf primal))) + (toString [_] + (str "#emmy.tape.Dual" + {:tag tag + :primal primal + :tangent tangent})) + + ;; TODO add comparable block from below + ) + + +(deftype TapeCell [tag id primal in->partial] v/IKind (kind [_] ::tape) @@ -147,15 +178,13 @@ (TapeCell. (if (= old tag) new tag) id (d/replace-tag primal old new) - (d/replace-tag dual old new) (mapv (fn [[node partial]] [(d/replace-tag node old new) (d/replace-tag partial old new)]) in->partial))) ;; This implementation is called if a tape ever makes it out of - ;; forward-mode-differentiated function. TODO if we like the dual number thing - ;; this should return `dual`. + ;; forward-mode-differentiated function. (extract-tangent [_ _] 0) Object @@ -207,6 +236,12 @@ ;; ## Non-generic API +(defn dual? + "Returns true if the supplied object is an instance of [[Dual]], false + otherwise." + [x] + (instance? Dual x)) + (defn tape? "Returns true if the supplied object is an instance of [[TapeCell]], false otherwise." @@ -230,11 +265,9 @@ where `` is the partial derivative of the output with respect to each input (defaults to `[]`)." ([tag primal] - (->TapeCell tag (fresh-id) primal 1 [])) + (->TapeCell tag (fresh-id) primal [])) ([tag primal partials] - (->TapeCell tag (fresh-id) primal 1 partials)) - ([tag primal dual partials] - (->TapeCell tag (fresh-id) primal dual partials))) + (->TapeCell tag (fresh-id) primal partials))) ;; TODO making [[tapify]] extensible is the key to differentiating things like ;; quaternion-valued functions. Forward-mode handles this differently, since we @@ -264,6 +297,10 @@ [^TapeCell tape] (.-tag tape)) +(defn dual-tag + [^Dual dual] + (.-tag dual)) + (defn tape-primal "Given a [[TapeCell]], returns the `primal` field of the supplied [[TapeCell]] object. For all other types, acts as identity. @@ -279,6 +316,26 @@ (.-primal ^TapeCell x) x))) +(defn dual-primal + ([x] + (if (dual? x) + (.-primal ^Dual x) + x)) + ([x tag] + (if (and (dual? x) (= tag (dual-tag x))) + (.-primal ^Dual x) + x))) + +(defn dual-tangent + ([x] + (if (dual? x) + (.-tangent ^Dual x) + 0)) + ([x tag] + (if (and (dual? x) (= tag (dual-tag x))) + (.-tangent ^Dual x) + 0))) + (defn tape-id "Returns the `-id` field of the supplied [[TapeCell]] object. Errors if any other type is supplied. @@ -314,7 +371,6 @@ {:tag (.-tag t) :id (.-id t) :primal (.-primal t) - :dual (.-dual t) :in->partial (.-in->partial t)}) ;; More permissive accessors... @@ -324,6 +380,7 @@ non-[[TapeCell]] instance." [x] (cond (tape? x) (tape-tag x) + (dual? x) (dual-tag x) (d/differential? x) (d/max-order-tag x) :else nil)) @@ -334,6 +391,7 @@ (primal-of v (tag-of v))) ([v tag] (cond (tape? v) (tape-primal v tag) + (dual? v) (dual-primal v tag) (d/differential? v) (d/primal-part v tag) :else v))) @@ -344,6 +402,7 @@ Given a non-[[TapeCell]], acts as identity." ([v] (cond (tape? v) (recur (tape-primal v)) + (dual? v) (recur (dual-primal v)) (d/differential? v) (recur (d/primal-part v)) :else v))) @@ -417,6 +476,7 @@ false))) (defn compare + ;; TODO bad docstring "Comparator that compares [[Differential]] instances with each other or non-differentials using only the [[finite-part]] of each instance. Matches the response of [[equiv]]. @@ -741,6 +801,20 @@ selectors) (matrix/seq-> (cons x more))))))) +(defn fwd + ([f] + (fn + ([] 0) + ([x] + (let [tag (d/fresh-tag) + inputs (->Dual tag x 1) + output (d/with-active-tag tag f [inputs])] + (d/extract-tangent output tag))) + ([x & more] + ((fwd (fn [args] + (apply f args))) + (matrix/seq-> (cons x more))))))) + ;; starting to work on the one that returns a pair of primal and fn. #_ @@ -831,15 +905,22 @@ (fn call [x] (cond (tape? x) (let [primal (tape-primal x) - dual (.-dual ^TapeCell x) partial (df:dx primal)] (make (tape-tag x) (call primal) - (if (g/numeric-zero? dual) - dual - (g/* partial dual)) [[x partial]])) + (dual? x) + (let [px (dual-primal x) + tx (dual-tangent x) + partial (df:dx px)] + (->Dual (dual-tag x) + (call px) + (if (g/numeric-zero? tx) + tx + (g/* partial tx)))) + + (d/differential? x) (let [[px tx] (d/primal-tangent-pair x) fx (call px)] @@ -884,28 +965,35 @@ b (d/d:+* b (df:dy xe ye) dy)))) + (operate-dual [tag] + (let [xe (dual-primal x tag) + ye (dual-primal y tag) + dx (dual-tangent x tag) + dy (dual-tangent y tag)] + (->Dual tag + (call xe ye) + (g/+ (if (g/numeric-zero? dx) + dx + (g/* (df:dx xe ye) dx)) + (if (g/numeric-zero? dy) + dy + (g/* (df:dy xe ye) dy)))))) + (operate-reverse [tag] (let [primal-x (tape-primal x tag) primal-y (tape-primal y tag) - dx (delay (df:dx primal-x primal-y)) - dy (delay (df:dy primal-x primal-y)) partial-x (if (and (tape? x) (= tag (tape-tag x))) - [[x @dx]] + [[x (df:dx primal-x primal-y)]] []) partial-y (if (and (tape? y) (= tag (tape-tag y))) - [[y @dy]] + [[y (df:dy primal-x primal-y)]] [])] (make tag (call primal-x primal-y) - (g/+ (if (and (tape? x) (= tag (tape-tag x))) - (g/* @dx (.-dual ^TapeCell x)) - 0) - (if (and (tape? y) (= tag (tape-tag y))) - (g/* @dy (.-dual ^TapeCell y)) - 0)) (into partial-x partial-y))))] (if-let [[tag dx] (tag+perturbation x y)] (cond (tape? dx) (operate-reverse tag) + (dual? dx) (operate-dual tag) (d/differential? dx) (operate-forward tag) :else (u/illegal "Non-tape or differential perturbation!")) @@ -958,7 +1046,8 @@ ([generic-op] (defunary generic-op (lift-1 generic-op))) ([generic-op differential-op] - (defmethod generic-op [::tape] [a] (differential-op a)))) + (defmethod generic-op [::tape] [a] (differential-op a)) + (defmethod generic-op [::dual] [a] (differential-op a)))) (defn- defbinary "Given: @@ -977,6 +1066,13 @@ [::d/differential ::tape] [::tape ::v/scalar] [::v/scalar ::tape] + + + [::dual ::dual] + [::tape ::dual] + [::dual ::tape] + [::dual ::v/scalar] + [::v/scalar ::dual] #_#_#_ [::d/differential ::d/differential] [::d/differential ::v/scalar] @@ -1108,8 +1204,23 @@ (TapeCell. (.-tag t) (.-id t) (g/simplify (.-primal t)) - (.-dual t) (mapv (fn [[node partial]] [(g/simplify node) (g/simplify partial)]) (.-in->partial t)))) + +;; DUAL: + +(defmethod g/zero-like [::dual] [_] 0) +(defmethod g/one-like [::dual] [_] 1) +(defmethod g/identity-like [::dual] [_] 1) +(defmethod g/freeze [::dual] [t] + `[~'Dual + ~(dual-tag t) + ~(g/freeze (dual-primal t)) + ~(g/freeze (dual-tangent t))]) + +(defmethod g/simplify [::dual] [^Dual t] + (Dual. (.-tag t) + (g/simplify (.-primal t)) + (g/simplify (.-tangent t))))