From 5ca092416b3e63da06ecc0fdc4db361a66674e15 Mon Sep 17 00:00:00 2001 From: Sam Ritchie Date: Tue, 9 Apr 2024 13:27:34 -0600 Subject: [PATCH] dual numbers working --- src/emmy/calculus/derivative.cljc | 23 +++++++++++++-------- src/emmy/series.cljc | 26 +++++++++++++----------- src/emmy/tape.cljc | 27 ++++++++++++++++++++++--- test/emmy/calculus/derivative_test.cljc | 7 +++---- 4 files changed, 56 insertions(+), 27 deletions(-) diff --git a/src/emmy/calculus/derivative.cljc b/src/emmy/calculus/derivative.cljc index f3a507dc..206767f6 100644 --- a/src/emmy/calculus/derivative.cljc +++ b/src/emmy/calculus/derivative.cljc @@ -256,10 +256,10 @@ (fn [x] (let [tag (d/fresh-tag) lifted - (d/bundle-element x 1 tag) - - ;; TODO this is the only change we need to get this working. Also implement comparable. - #_(emmy.tape/->Dual tag x 1)] + ;; TODO if we want to do this, we can kill differentials and move the + ;; tests over... + #_(d/bundle-element x 1 tag) + (emmy.tape/->Dual tag x 1)] (-> (d/with-active-tag tag f [lifted]) (d/extract-tangent tag)) ))) @@ -548,9 +548,16 @@ (letfn [(process-term [term] (g/simplify (s/mapr (fn rec [x] - (if (d/differential? x) - (d/map-coefficients rec x) - (-> (g/simplify x) - (x/substitute replace-m)))) + (cond (d/differential? x) (d/map-coefficients rec x) + (tape/dual? x) + (tape/->Dual (tape/dual-tag x) + (rec (tape/dual-primal x)) + (rec (tape/dual-tangent x))) + + (tape/tape? x) + (u/illegal "TODO implement this using fmap style.") + + :else (-> (g/simplify x) + (x/substitute replace-m)))) term)))] (series/fmap process-term series))))) diff --git a/src/emmy/series.cljc b/src/emmy/series.cljc index 2d98fcd0..6c26c6ed 100644 --- a/src/emmy/series.cljc +++ b/src/emmy/series.cljc @@ -870,16 +870,18 @@ (defmethod g/exact? [::series] [_] false) (defmethod g/exact? [::power-series] [_] false) (defmethod g/freeze [::power-series] [^PowerSeries s] - (let [prefix (->> (g/simplify (take 4 (.-xs s))) - (g/freeze) - (filter (complement g/zero?)) - (map-indexed - (fn [n a] - (if (g/one? a) - `(~'expt ~'_ ~n) - `(~'* ~a (~'expt ~'_ ~n))))))] - `(~'+ ~@prefix ~'...))) + (let [prefix (->> (g/simplify (take 4 (.-xs s))) + (g/freeze) + (into [] (comp + (map-indexed + (fn [n a] + (cond (g/zero? a) [] + (g/one? a) [(list 'expt '_ n)] + :else [(list '* a (list 'expt '_ n))]))) + cat)))] + `(~'+ ~@prefix ~'...))) + (defmethod g/freeze [::series] [^Series s] - (let [prefix (g/freeze - (g/simplify (take 4 (.-xs s))))] - `(~'+ ~@prefix ~'...))) + (let [prefix (g/freeze + (g/simplify (take 4 (.-xs s))))] + `(~'+ ~@prefix ~'...))) diff --git a/src/emmy/tape.cljc b/src/emmy/tape.cljc index cf077481..e613c894 100644 --- a/src/emmy/tape.cljc +++ b/src/emmy/tape.cljc @@ -126,7 +126,7 @@ ;; Here's the [[TapeCell]] type with the fields described above. -(declare compare) +(declare compare compare-dual dual-primal) (deftype Dual [tag primal tangent] v/IKind @@ -147,6 +147,7 @@ Object ;; TODO revisit all of this + #?(:clj (equals [_ b] (v/= primal (dual-primal b)))) #?(:cljs (valueOf [_] (.valueOf primal))) (toString [_] (str "#emmy.tape.Dual" @@ -154,8 +155,22 @@ :primal primal :tangent tangent})) - ;; TODO add comparable block from below - ) + #?@(:clj + ;; The motivation for this override is subtle. To participate in control + ;; flow operations, like comparison with both [[TapeCell]] and + ;; non-[[TapeCell]] instances, [[TapeCell]] instances should compare using + ;; ONLY their primal terms. This means that comparison will ignore any + ;; difference in `in->partial`. + [Comparable + (compareTo [a b] (compare-dual a b))] + + :cljs + [IComparable + (-compare [a b] (compare-dual a b)) + + IPrintWithWriter + (-pr-writer [x writer _] + (write-all writer (.toString x)))])) (deftype TapeCell [tag id primal in->partial] @@ -487,6 +502,12 @@ (tape-primal a) (tape-primal b))) +(defn compare-dual + [a b] + (v/compare + (dual-primal a) + (dual-primal b))) + ;; ## Reverse-pass support (defn inner-tag diff --git a/test/emmy/calculus/derivative_test.cljc b/test/emmy/calculus/derivative_test.cljc index 4921f9e9..b2b79cbf 100644 --- a/test/emmy/calculus/derivative_test.cljc +++ b/test/emmy/calculus/derivative_test.cljc @@ -1569,8 +1569,7 @@ (is (v/= [0 1 0 0] (take 4 ((D (fn [y] (d/symbolic-taylor-series - (fn [x] (g/* x y)) - 0))) + (fn [x] (g/* x y))))) 'a))) "proper function when symbolic-taylor-series is used INSIDE of a call to `D`; this shows that it can do proper symbolic replacement inside of @@ -1578,8 +1577,8 @@ (testing "compare, one stays symbolic:" (letfn [(f [[a b]] - (* (sin (* 3 a)) - (cos (* 4 b))))] + (* (sin (* 3 a)) + (cos (* 4 b))))] (is (ish? [-0.020532965943782493 (s/down 0.4321318251769156 -0.558472974950351)]