Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge dual number + tape implementations? #161

Closed
wants to merge 14 commits into from
3 changes: 2 additions & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ jobs:
CLOVERAGE_VERSION=1.2.4 clojure -M:test:coverage --codecov || :

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: true
file: ./target/coverage/codecov.json
token: ${{ secrets.CODECOV_TOKEN }}
22 changes: 22 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@

## [unreleased]

- #159:

- Fixes `Differential`'s implementation of `emmy.value/numerical?` to always
return `false`. The reason is that `numerical?` is used by `g/*` and friends
to decide on simplifications like `(* <dx-with-1> x) => x`, which would lose
the structure of `dx-with-1`. By returning false we avoid these
simplifications.

- Converts a number of `emmy.value/numerical?` calls to `emmy.value/scalar?`.
The `numerical?` protocol method is used only in generic functions like
`g/*` for deciding whether or not to apply numerical simplifications, like
`(* x 1) => x`.

Guarding on `v/scalar?` instead allows us to let in numbers, tapes and
differentials, since these latter two are meant to WRAP numbers, but should
not be subject to numerical simplifications.


- Adds `emmy.structure/fold-chain` for performing a tree-like fold on
structures, saving us work over the alternate pattern of `s/mapr`, `flatten`
and `reduce`.

- #155

- Replace the implementation of arbitrary-precision rational arithmetic
Expand Down
88 changes: 67 additions & 21 deletions src/emmy/abstract/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
The namespace also contains an implementation of a small language for
declaring the input and output types of [[literal-function]] instances."
(:refer-clojure :exclude [name])
(:require [emmy.abstract.number :as an]
(:require #?(:clj [clojure.pprint :as pprint])
[emmy.abstract.number :as an]
[emmy.differential :as d]
[emmy.function :as f]
[emmy.generic :as g]
[emmy.matrix :as m]
[emmy.numsymb :as sym]
[emmy.polynomial]
[emmy.structure :as s]
[emmy.tape :as tape]
[emmy.util :as u]
[emmy.value :as v])
#?(:clj
Expand Down Expand Up @@ -139,6 +141,13 @@
(defmethod print-method Function [^Function f ^java.io.Writer w]
(.write w (.toString f))))

#?(:clj
;; NOTE that this override only works in Clojure. In cljs, `simple-dispatch`
;; isn't extensible.
(defmethod pprint/simple-dispatch Function [f]
(pprint/simple-dispatch
(.-f-name ^Function f))))

(derive Function ::function)

(defn literal-function?
Expand Down Expand Up @@ -242,35 +251,65 @@
(->Function
fexp (f/arity f) (domain-types f) (range-type f))))

(defn- forward-mode-fold [f primal-s tag]
(fn
([] (apply f primal-s))
([acc] acc)
([acc [x path _]]
(let [dx (d/tangent-part x tag)]
(if (g/numeric-zero? dx)

Check warning on line 260 in src/emmy/abstract/function.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/abstract/function.cljc#L260

Added line #L260 was not covered by tests
acc
(let [partial (literal-partial f path)]
(d/d:+ acc (d/d:* (literal-apply partial primal-s)
dx))))))))

(defn- dual-forward-mode-fold [f primal-s tag]
(fn
([] 0)
([tangent] (tape/->Dual tag (apply f primal-s) tangent))
([tangent [x path _]]
(let [dx (tape/dual-tangent x tag)]
(if (g/numeric-zero? dx)
tangent
(let [partial (literal-partial f path)]
(g/+ tangent (g/* (literal-apply partial primal-s)
dx))))))))

(defn- reverse-mode-fold [f primal-s tag]
(fn
([] [])
([partials]
(tape/make tag (apply f primal-s) partials))
([partials [entry path _]]
(if (and (tape/tape? entry) (= tag (tape/tape-tag entry)))
(let [partial (literal-partial f path)]
(conj partials [entry (literal-apply partial primal-s)]))
partials))))

(defn- literal-derivative
"Takes a literal function `f` and a sequence of arguments `xs`, and generates an
"TODO fix this, not true anymore...

Takes a literal function `f` and a sequence of arguments `xs`, and generates an
expanded `((D f) xs)` by applying the chain rule and summing the partial
derivatives for each differential argument in the input structure."
[f xs]
(let [v (m/seq-> xs)
flat-v (flatten v)
tag (apply d/max-order-tag flat-v)
ve (s/mapr #(d/primal-part % tag) v)
partials (s/map-chain
(fn [x path _]
(let [dx (d/tangent-part x tag)]
(if (g/zero? dx)
0
(d/d:* (literal-apply
(literal-partial f path) ve)
dx))))
v)]
(apply d/d:+ (apply f ve) (flatten partials))))
[f s tag dx]
(let [fold-fn (cond (tape/tape? dx) reverse-mode-fold
(tape/dual? dx) dual-forward-mode-fold
(d/differential? dx) forward-mode-fold
:else (u/illegal "No tape or differential inputs."))
primal-s (s/mapr (fn [x] (tape/primal-of x tag)) s)]
(s/fold-chain (fold-fn f primal-s tag) s)))

(defn- check-argument-type
"Check that the argument provided at index i has the same type as
the exemplar expected."
[f provided expected indexes]
(cond (number? expected)
(when-not (v/numerical? provided)
(when-not (v/scalar? provided)
(u/illegal (str "expected numerical quantity in argument " indexes
" of function call " f
" but got " provided)))

(s/structure? expected)
(do (when-not (and (or (s/structure? provided) (sequential? provided))
(= (s/orientation provided) (s/orientation expected))
Expand All @@ -279,6 +318,7 @@
" but got " provided)))
(doseq [[provided expected sub-index] (map list provided expected (range))]
(check-argument-type f provided expected (conj indexes sub-index))))

(keyword? expected) ;; a keyword has to match the argument's kind
(when-not (= (v/kind provided) expected)
(u/illegal (str "expected argument of type " expected " but got " (v/kind provided)
Expand All @@ -288,9 +328,15 @@

(defn- literal-apply [f xs]
(check-argument-type f xs (domain-types f) [0])
(if (some d/perturbed? xs)
(literal-derivative f xs)
(an/literal-number `(~(name f) ~@(map g/freeze xs)))))
(let [s (m/seq-> xs)]
(if-let [[tag dx] (s/fold-chain
(fn
([] [])
([acc] (apply tape/tag+perturbation acc))
([acc [d]] (conj acc d)))
s)]
(literal-derivative f s tag dx)
(an/literal-number `(~(name f) ~@(map g/freeze xs))))))

;; ## Specific Generics
;;
Expand Down
3 changes: 0 additions & 3 deletions src/emmy/abstract/number.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
(:import (clojure.lang Symbol))))

(extend-type Symbol
v/Numerical
(numerical? [_] true)

v/IKind
(kind [_] Symbol))

Expand Down
37 changes: 31 additions & 6 deletions src/emmy/calculus/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
[emmy.operator :as o]
[emmy.series :as series]
[emmy.structure :as s]
[emmy.tape :as tape]
[emmy.util :as u]
[emmy.value :as v])
#?(:clj
Expand Down Expand Up @@ -254,9 +255,14 @@
[f]
(fn [x]
(let [tag (d/fresh-tag)
lifted (d/bundle-element x 1 tag)]
lifted
;; TODO if we want to do this, we can kill differentials and move the
;; tests over...
#_(d/bundle-element x 1 tag)
(emmy.tape/->Dual tag x 1)]
(-> (d/with-active-tag tag f [lifted])
(d/extract-tangent tag)))))
(d/extract-tangent tag))
)))

;; The result of applying the derivative `(D f)` of a multivariable function `f`
;; to a sequence of `args` is a structure of the same shape as `args` with all
Expand Down Expand Up @@ -431,16 +437,20 @@

(doseq [t [::v/function ::s/structure]]
(defmethod g/partial-derivative [t v/seqtype] [f selectors]
#_(tape/gradient f selectors)
(multivariate f selectors))

(defmethod g/partial-derivative [t nil] [f _]
#_(tape/gradient f [])
(multivariate f [])))

;; ## Operators
;;
;; This section exposes various differential operators as [[o/Operator]]
;; instances.

;; TODO note that this will now go to reverse mode.

(def D
"Derivative operator. Takes some function `f` and returns a function whose value
at some point can multiply an increment in the arguments to produce the best
Expand Down Expand Up @@ -472,6 +482,14 @@
(o/make-operator #(g/partial-derivative % selectors)
`(~'partial ~@selectors)))

(def D-fwd
(o/make-operator #(multivariate % [])
g/derivative-symbol))

(defn partial-fwd [& selectors]
(o/make-operator #(multivariate % selectors)
`(~'partial ~@selectors)))

;; ## Derivative Utilities
;;
;; Functions that make use of the differential operators defined above in
Expand Down Expand Up @@ -530,9 +548,16 @@
(letfn [(process-term [term]
(g/simplify
(s/mapr (fn rec [x]
(if (d/differential? x)
(d/map-coefficients rec x)
(-> (g/simplify x)
(x/substitute replace-m))))
(cond (d/differential? x) (d/map-coefficients rec x)
(tape/dual? x)
(tape/->Dual (tape/dual-tag x)
(rec (tape/dual-primal x))
(rec (tape/dual-tangent x)))

(tape/tape? x)
(u/illegal "TODO implement this using fmap style.")

Check warning on line 558 in src/emmy/calculus/derivative.cljc

View check run for this annotation

Codecov / codecov/patch

src/emmy/calculus/derivative.cljc#L558

Added line #L558 was not covered by tests

:else (-> (g/simplify x)
(x/substitute replace-m))))
term)))]
(series/fmap process-term series)))))
3 changes: 0 additions & 3 deletions src/emmy/complex/impl.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@

v/INumericTower

v/Numerical
(numerical? [_] true)

#?@(:clj [Object
(equals [a b] (equal? a b))
(toString [a] (->string a))]
Expand Down
20 changes: 7 additions & 13 deletions src/emmy/differential.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
;; derivatives of entire programs in a similar way by building them out of the
;; derivatives of the smaller pieces of those programs.
;;

;; ### Forward-Mode Automatic Differentiation
;;
;; For many scientific computing applications, it's valuable be able to generate
Expand Down Expand Up @@ -335,7 +336,10 @@

(defprotocol IPerturbed
(perturbed? [this]
"Returns true if the supplied object has some known non-zero tangent to be
"TODO clear this up, this is part of combining the interfaces for tape and
differential. This is here so that literal functions can handle tape or
differential instances..."
#_"Returns true if the supplied object has some known non-zero tangent to be
extracted via [[extract-tangent]], false otherwise. (Return `false` by
default if you can't detect a perturbation.)")

Expand Down Expand Up @@ -565,16 +569,6 @@
(declare compare equiv finite-term from-terms)

(deftype Differential [terms]
;; A [[Differential]] as implemented can act as a chain-rule accounting device
;; for all sorts of types, not just numbers. A [[Differential]] is
;; only [[v/numerical?]] if its coefficients are numerical (or if `terms` is
;; empty, interpreted as a [[Differential]] equal to `0`.)
v/Numerical
(numerical? [_]
(or (empty? terms)
(v/numerical?
(coefficient (nth terms 0)))))

IPerturbed
(perturbed? [_] true)

Expand Down Expand Up @@ -835,10 +829,10 @@
`tag` defaults to a side-effecting call to [[fresh-tag]]; you can retrieve
this unknown tag by calling [[max-order-tag]]."
([primal]
{:pre [(v/numerical? primal)]}
{:pre [(v/scalar? primal)]}
(bundle-element primal 1 (fresh-tag)))
([primal tag]
{:pre [(v/numerical? primal)]}
{:pre [(v/scalar? primal)]}
(bundle-element primal 1 tag))
([primal tangent tag]
(let [term (make-term (uv/make [tag]) 1)]
Expand Down
2 changes: 1 addition & 1 deletion src/emmy/env.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -581,4 +581,4 @@
[emmy.special.elliptic elliptic-f]
[emmy.special.factorial factorial]
[emmy.value = compare
numerical? kind kind-predicate principal-value])
kind kind-predicate principal-value])
3 changes: 0 additions & 3 deletions src/emmy/expression.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
;; other abstract structures referenced in [[abstract-types]].

(deftype Literal [type expression m]
v/Numerical
(numerical? [_] (= type ::numeric))

v/IKind
(kind [_] type)

Expand Down
8 changes: 4 additions & 4 deletions src/emmy/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -403,12 +403,12 @@
(with-arity [:exactly 1])))

(defn coerce-to-fn
"Given a [[value/numerical?]] input `x`, returns a function of arity `arity`
"Given an [[emmy.value/scalar?]] input `x`, returns a function of arity `arity`
that always returns `x` no matter what input it receives.

For non-numerical `x`, returns `x`."
([x arity]
(if (v/numerical? x)
(if (v/scalar? x)
(-> (constantly x)
(with-arity arity))
x)))
Expand All @@ -426,8 +426,8 @@
```"
[op]
(letfn [(h [f g]
(let [f-arity (if (v/numerical? f) (arity g) (arity f))
g-arity (if (v/numerical? g) f-arity (arity g))
(let [f-arity (if (v/scalar? f) (arity g) (arity f))
g-arity (if (v/scalar? g) f-arity (arity g))
f1 (coerce-to-fn f f-arity)
g1 (coerce-to-fn g g-arity)
arity (joint-arity [f-arity g-arity])]
Expand Down
9 changes: 7 additions & 2 deletions src/emmy/generic.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,13 @@
([] 1)
([x] x)
([x y]
(let [numx? (v/numerical? x)
numy? (v/numerical? y)]
;; TODO document the change for this in the PR.
;; (g/cos (g/+ 10 (g/* (emmy.env/literal-number 0) 10)))
;;
;; I think it is BETTER that we ignore wrapped numbers actually and get (cos
;; 10) vs the float.
(let [numx? (v/number? x)
numy? (v/number? y)]
(cond (and numx? (zero? x)) (zero-like y)
(and numy? (zero? y)) (zero-like x)
(and numx? (one? x)) y
Expand Down
Loading
Loading