Skip to content

Commit

Permalink
first batch
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Apr 8, 2024
1 parent 13dda79 commit 5287511
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 37 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ jobs:
CLOVERAGE_VERSION=1.2.4 clojure -M:test:coverage --codecov || :
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v3
with:
fail_ci_if_error: true
file: ./target/coverage/codecov.json
token: ${{ secrets.CODECOV_TOKEN }}
22 changes: 22 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,28 @@

## [unreleased]

- #159:

- Fixes `Differential`'s implementation of `emmy.value/numerical?` to always
return `false`. The reason is that `numerical?` is used by `g/*` and friends
to decide on simplifications like `(* <dx-with-1> x) => x`, which would lose
the structure of `dx-with-1`. By returning false we avoid these
simplifications.

- Converts a number of `emmy.value/numerical?` calls to `emmy.value/scalar?`.
The `numerical?` protocol method is used only in generic functions like
`g/*` for deciding whether or not to apply numerical simplifications, like
`(* x 1) => x`.

Guarding on `v/scalar?` instead allows us to let in numbers, tapes and
differentials, since these latter two are meant to WRAP numbers, but should
not be subject to numerical simplifications.


- Adds `emmy.structure/fold-chain` for performing a tree-like fold on
structures, saving us work over the alternate pattern of `s/mapr`, `flatten`
and `reduce`.

- #155

- Replace the implementation of arbitrary-precision rational arithmetic
Expand Down
14 changes: 12 additions & 2 deletions src/emmy/abstract/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
The namespace also contains an implementation of a small language for
declaring the input and output types of [[literal-function]] instances."
(:refer-clojure :exclude [name])
(:require [emmy.abstract.number :as an]
(:require #?(:clj [clojure.pprint :as pprint])
[emmy.abstract.number :as an]
[emmy.differential :as d]
[emmy.function :as f]
[emmy.generic :as g]
Expand Down Expand Up @@ -139,6 +140,13 @@
(defmethod print-method Function [^Function f ^java.io.Writer w]
(.write w (.toString f))))

#?(:clj
;; NOTE that this override only works in Clojure. In cljs, `simple-dispatch`
;; isn't extensible.
(defmethod pprint/simple-dispatch Function [f]
(pprint/simple-dispatch
(.-f-name ^Function f))))

(derive Function ::function)

(defn literal-function?
Expand Down Expand Up @@ -267,10 +275,11 @@
the exemplar expected."
[f provided expected indexes]
(cond (number? expected)
(when-not (v/numerical? provided)
(when-not (v/scalar? provided)
(u/illegal (str "expected numerical quantity in argument " indexes
" of function call " f
" but got " provided)))

(s/structure? expected)
(do (when-not (and (or (s/structure? provided) (sequential? provided))
(= (s/orientation provided) (s/orientation expected))
Expand All @@ -279,6 +288,7 @@
" but got " provided)))
(doseq [[provided expected sub-index] (map list provided expected (range))]
(check-argument-type f provided expected (conj indexes sub-index))))

(keyword? expected) ;; a keyword has to match the argument's kind
(when-not (= (v/kind provided) expected)
(u/illegal (str "expected argument of type " expected " but got " (v/kind provided)
Expand Down
17 changes: 7 additions & 10 deletions src/emmy/differential.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -565,15 +565,12 @@
(declare compare equiv finite-term from-terms)

(deftype Differential [terms]
;; A [[Differential]] as implemented can act as a chain-rule accounting device
;; for all sorts of types, not just numbers. A [[Differential]] is
;; only [[v/numerical?]] if its coefficients are numerical (or if `terms` is
;; empty, interpreted as a [[Differential]] equal to `0`.)
;; A [[Differential]] has to respond `false` to all [[emmy.value/numerical?]]
;; inquiries; if we didn't do this, then [[emmy.generic/*]] and friends would
;; attempt to apply shortcuts like `(* x <dx-with-1>) => x`, stripping off
;; the [[Differential]] identity of the result and ruining the derivative.
v/Numerical
(numerical? [_]
(or (empty? terms)
(v/numerical?
(coefficient (nth terms 0)))))
(numerical? [_] false)

IPerturbed
(perturbed? [_] true)
Expand Down Expand Up @@ -835,10 +832,10 @@
`tag` defaults to a side-effecting call to [[fresh-tag]]; you can retrieve
this unknown tag by calling [[max-order-tag]]."
([primal]
{:pre [(v/numerical? primal)]}
{:pre [(v/scalar? primal)]}
(bundle-element primal 1 (fresh-tag)))
([primal tag]
{:pre [(v/numerical? primal)]}
{:pre [(v/scalar? primal)]}
(bundle-element primal 1 tag))
([primal tangent tag]
(let [term (make-term (uv/make [tag]) 1)]
Expand Down
8 changes: 4 additions & 4 deletions src/emmy/function.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -403,12 +403,12 @@
(with-arity [:exactly 1])))

(defn coerce-to-fn
"Given a [[value/numerical?]] input `x`, returns a function of arity `arity`
"Given an [[emmy.value/scalar?]] input `x`, returns a function of arity `arity`
that always returns `x` no matter what input it receives.
For non-numerical `x`, returns `x`."
([x arity]
(if (v/numerical? x)
(if (v/scalar? x)
(-> (constantly x)
(with-arity arity))
x)))
Expand All @@ -426,8 +426,8 @@
```"
[op]
(letfn [(h [f g]
(let [f-arity (if (v/numerical? f) (arity g) (arity f))
g-arity (if (v/numerical? g) f-arity (arity g))
(let [f-arity (if (v/scalar? f) (arity g) (arity f))
g-arity (if (v/scalar? g) f-arity (arity g))
f1 (coerce-to-fn f f-arity)
g1 (coerce-to-fn g g-arity)
arity (joint-arity [f-arity g-arity])]
Expand Down
4 changes: 2 additions & 2 deletions src/emmy/matrix.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@
(s->m ls ms rs)))
([ls ms rs]
(when *careful-conversion*
(assert (v/numerical? (g/* ls (g/* ms rs)))))
(assert (v/scalar? (g/* ls (g/* ms rs)))))
(let [ndowns (s/dimension ls)
nups (s/dimension rs)]
(generate ndowns nups
Expand Down Expand Up @@ -720,7 +720,7 @@
(s/unflatten (nth-col m j) col-shape))
(s/compatible-shape rs))]
(when *careful-conversion*
(assert (v/numerical? (g/* ls (g/* ms rs)))
(assert (v/scalar? (g/* ls (g/* ms rs)))
(str "product is not numerical: " ls ms rs)))
ms))

Expand Down
6 changes: 3 additions & 3 deletions src/emmy/mechanics/hamilton.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@
(and (s/up? s)
(= (count s) 3)
(let [[t q v] s]
(and (v/numerical? t)
(or (and (v/numerical? q)
(v/numerical? v))
(and (v/scalar? t)
(or (and (v/scalar? q)
(v/scalar? v))
(and (s/up? q)
(s/down? v)
(= (s/dimension q)
Expand Down
8 changes: 4 additions & 4 deletions src/emmy/numerical/derivative.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@
;; First, a function that will print nicely rendered infix versions
;; of (simplified) symbolic expressions:

(->clerk
(->clerk-only
(defn- show [e]
(nextjournal.clerk/tex
(->TeX
(g/simplify e)))))

;; And a function to play with:

(->clerk
(->clerk-only
(def func
(af/literal-function 'f)))

Expand All @@ -70,7 +70,7 @@

;; Here's the taylor series expansions of $f(x + h)$:

(->clerk
(->clerk-only
(def fx+h
(-> ((d/taylor-series func 'x) 'h)
(series/sum 4))))
Expand Down Expand Up @@ -110,7 +110,7 @@

;; We could also expand $f(x - h)$:

(->clerk
(->clerk-only
(def fx-h
(-> ((d/taylor-series func 'x) (g/negate 'h))
(series/sum 4))))
Expand Down
65 changes: 61 additions & 4 deletions src/emmy/structure.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,63 @@
[f & structures]
(sum:r:l f structures))

(defn fold-chain
"Returns the result of accumulating all non-structural entries in `s` using the
supplied fold function `f` into the optional accumulator `init` (defaults
to `(f)`).
`f` must be a 2-argument fn of type `(accumulator, [x chain orientations]) =>
accumulator` responsible for merging some value `x` into the ongoing
accumulation. The second argument is a 3-vector containing
- the entry in the structure
- a vector of its 'access chain', i.e., the path you'd pass
to [[clojure.core/get-in]] to access the entry
- a vector of orientations associated with each index in the access chain
`f` should return a new instance of the accumulator.
Additional arities allow you to supply
- `init`, the initial (empty) accumulator (defaults to `(f)`)
- `present`, a function that will be applied to the final, aggregated
result (defaults to `f`)
For example:
```clojure
(fold-chain
(fn ([] [])
([acc] acc)
([acc [s chain orientations]]
(conj acc {:s s
:chain chain
:orientations orientations})))
(s/down (s/up 1 2) (s/up 3 4)))
[{:s 1, :chain [0 0], :orientations [::s/down ::s/up]}
{:s 2, :chain [0 1], :orientations [::s/down ::s/up]}
{:s 3, :chain [1 0], :orientations [::s/down ::s/up]}
{:s 4, :chain [1 1], :orientations [::s/down ::s/up]}]
```"
([f s] (fold-chain f (f) f s))
([f init s] (fold-chain f init f s))
([f init present s]
(letfn [(walk [acc s chain orientations]
(if (structure? s)
(let [o (orientation s)]
(reduce
(fn [acc i]
(walk acc
(s:nth s i)
(conj chain i)
(conj orientations o)))
acc
(range (count s))))
(f acc [s chain orientations])))]
(present
(walk init s [] [])))))

(defn- map:l
"Returns a new structure generated by mapping `f` across the same-indexed
entries of all supplied structures, one level deep."
Expand Down Expand Up @@ -737,10 +794,10 @@
```clojure
(dorun (map-chain println (s/down (s/up 1 2) (s/up 3 4))))
1 [0 0] [:s/down :s/up]
2 [0 1] [:s/down :s/up]
3 [1 0] [:s/down :s/up]
4 [1 1] [:s/down :s/up]
1 [0 0] [::s/down ::s/up]
2 [0 1] [::s/down ::s/up]
3 [1 0] [::s/down ::s/up]
4 [1 1] [::s/down ::s/up]
```"
[f s]
(letfn [(walk [s chain orientations]
Expand Down
2 changes: 1 addition & 1 deletion src/emmy/value.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@
number
(-equiv [this other]
(cond (core/number? other) (identical? this other)
(numerical? other) (= this (.valueOf other))
(scalar? other) (= this (.valueOf other))
:else false))

goog.math.Integer
Expand Down
12 changes: 8 additions & 4 deletions test/emmy/differential_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
(defmethod g/zero? [#?(:clj String :cljs js/String)] [_] false)
(testing "v/numerical? special cases"
(is (not (v/numerical? (d/from-terms {[] "face"}))))
(is (v/numerical? (d/->Differential []))
(is (v/scalar? (d/->Differential []))
"An empty term list is interpreted as a 0-valued [[Differential]]."))

(checking "native comparison operators work with differential" 100
Expand Down Expand Up @@ -112,9 +112,13 @@
(zero? compare-bit) (is (and (<= l r) (= l r) (>= l r)))
:else (is (> l r))))))))

(checking "v/numerical?" 100 [diff (sg/differential sg/real)]
(is (v/numerical? diff)
"True for all differentials populated by v/numerical? things"))
(checking "v/numerical?, v/scalar?" 100 [diff (sg/differential sg/real)]
(is (v/scalar? diff)
"True for all differentials")

(is (not (v/numerical? diff))
"False for all differentials, even wrapping numerical things...
we don't want g/* and friends to simplify us away."))

(testing "value protocol implementation"
(let [zero (d/->Differential [])
Expand Down
18 changes: 16 additions & 2 deletions test/emmy/structure_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -510,10 +510,10 @@
about the entries.")))

(deftest mapper-tests
(testing "sumr"
(testing "sumr, fold-chain"
(with-comparator (v/within 1e-7)
(checking "sumr sums all entries when passed a single structure" 100
[s (-> (gen/fmap #(g/modulo % 100) sg/real)
[s (-> (sg/reasonable-real 100)
(sg/structure 3))]
(is (ish? (reduce g/+ (flatten s))
(s/sumr identity s)))))
Expand Down Expand Up @@ -542,6 +542,20 @@
(s/down 'g 'h)))))
"sumr uses g/+, so symbols etc can be added too."))

(testing "fold-chain"
(is (= [{:s 1, :chain [0 0], :orientations [::s/down ::s/up]}
{:s 2, :chain [0 1], :orientations [::s/down ::s/up]}
{:s 3, :chain [1 0], :orientations [::s/down ::s/up]}
{:s 4, :chain [1 1], :orientations [::s/down ::s/up]}]
(s/fold-chain
(fn ([] [])
([acc] acc)
([acc [s chain orientations]]
(conj acc {:s s
:chain chain
:orientations orientations})))
(s/down (s/up 1 2) (s/up 3 4))))))

(testing "mapr"
(is (= (s/up (s/down 1 4 9)
(s/down 16 25 36)
Expand Down

0 comments on commit 5287511

Please sign in to comment.