Skip to content

Commit

Permalink
dual numbers working
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Apr 9, 2024
1 parent 05f0773 commit 5ca0924
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 27 deletions.
23 changes: 15 additions & 8 deletions src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,10 @@
(fn [x]
(let [tag (d/fresh-tag)
lifted
(d/bundle-element x 1 tag)

;; TODO this is the only change we need to get this working. Also implement comparable.
#_(emmy.tape/->Dual tag x 1)]
;; TODO if we want to do this, we can kill differentials and move the
;; tests over...
#_(d/bundle-element x 1 tag)
(emmy.tape/->Dual tag x 1)]
(-> (d/with-active-tag tag f [lifted])
(d/extract-tangent tag))
)))
Expand Down Expand Up @@ -548,9 +548,16 @@
(letfn [(process-term [term]
(g/simplify
(s/mapr (fn rec [x]
(if (d/differential? x)
(d/map-coefficients rec x)
(-> (g/simplify x)
(x/substitute replace-m))))
(cond (d/differential? x) (d/map-coefficients rec x)
(tape/dual? x)
(tape/->Dual (tape/dual-tag x)
(rec (tape/dual-primal x))
(rec (tape/dual-tangent x)))

(tape/tape? x)
(u/illegal "TODO implement this using fmap style.")

Check warning on line 558 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L558

Added line #L558 was not covered by tests

:else (-> (g/simplify x)
(x/substitute replace-m))))
term)))]
(series/fmap process-term series)))))
26 changes: 14 additions & 12 deletions src/emmy/series.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -870,16 +870,18 @@
(defmethod g/exact? [::series] [_] false)
(defmethod g/exact? [::power-series] [_] false)
(defmethod g/freeze [::power-series] [^PowerSeries s]
(let [prefix (->> (g/simplify (take 4 (.-xs s)))
(g/freeze)
(filter (complement g/zero?))
(map-indexed
(fn [n a]
(if (g/one? a)
`(~'expt ~'_ ~n)
`(~'* ~a (~'expt ~'_ ~n))))))]
`(~'+ ~@prefix ~'...)))
(let [prefix (->> (g/simplify (take 4 (.-xs s)))
(g/freeze)
(into [] (comp
(map-indexed
(fn [n a]
(cond (g/zero? a) []
(g/one? a) [(list 'expt '_ n)]
:else [(list '* a (list 'expt '_ n))])))
cat)))]
`(~'+ ~@prefix ~'...)))

(defmethod g/freeze [::series] [^Series s]
(let [prefix (g/freeze
(g/simplify (take 4 (.-xs s))))]
`(~'+ ~@prefix ~'...)))
(let [prefix (g/freeze
(g/simplify (take 4 (.-xs s))))]
`(~'+ ~@prefix ~'...)))
27 changes: 24 additions & 3 deletions src/emmy/tape.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@

;; Here's the [[TapeCell]] type with the fields described above.

(declare compare)
(declare compare compare-dual dual-primal)

(deftype Dual [tag primal tangent]
v/IKind
Expand All @@ -147,15 +147,30 @@

Object
;; TODO revisit all of this
#?(:clj (equals [_ b] (v/= primal (dual-primal b))))

Check warning on line 150 in src/emmy/tape.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/tape.cljc#L150

Added line #L150 was not covered by tests
#?(:cljs (valueOf [_] (.valueOf primal)))
(toString [_]
(str "#emmy.tape.Dual"
{:tag tag
:primal primal
:tangent tangent}))

;; TODO add comparable block from below
)
#?@(:clj
;; The motivation for this override is subtle. To participate in control
;; flow operations, like comparison with both [[TapeCell]] and
;; non-[[TapeCell]] instances, [[TapeCell]] instances should compare using
;; ONLY their primal terms. This means that comparison will ignore any
;; difference in `in->partial`.
[Comparable
(compareTo [a b] (compare-dual a b))]

:cljs
[IComparable
(-compare [a b] (compare-dual a b))

IPrintWithWriter
(-pr-writer [x writer _]
(write-all writer (.toString x)))]))


(deftype TapeCell [tag id primal in->partial]
Expand Down Expand Up @@ -487,6 +502,12 @@
(tape-primal a)
(tape-primal b)))

(defn compare-dual
[a b]
(v/compare
(dual-primal a)
(dual-primal b)))

;; ## Reverse-pass support

(defn inner-tag
Expand Down
7 changes: 3 additions & 4 deletions test/emmy/calculus/derivative_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -1569,17 +1569,16 @@
(is (v/= [0 1 0 0]
(take 4 ((D (fn [y]
(d/symbolic-taylor-series
(fn [x] (g/* x y))
0)))
(fn [x] (g/* x y)))))
'a)))
"proper function when symbolic-taylor-series is used INSIDE of a call to
`D`; this shows that it can do proper symbolic replacement inside of
differential instances.")

(testing "compare, one stays symbolic:"
(letfn [(f [[a b]]
(* (sin (* 3 a))
(cos (* 4 b))))]
(* (sin (* 3 a))
(cos (* 4 b))))]

(is (ish? [-0.020532965943782493
(s/down 0.4321318251769156 -0.558472974950351)]
Expand Down

0 comments on commit 5ca0924

Please sign in to comment.