Skip to content

Commit

Permalink
fix tape
Browse files Browse the repository at this point in the history
  • Loading branch information
sritchie committed Aug 13, 2024
1 parent c41c215 commit e7f8091
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
20 changes: 12 additions & 8 deletions src/emmy/tape.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -512,16 +512,20 @@
(u/illegal
(str "Selectors " selectors
" not allowed for non-structural input " x)))

(let [tag (d/fresh-tag)
inputs (if (empty? selectors)
(tapify x tag)
(update-in x selectors tapify tag))
output (d/with-active-tag tag f [inputs])
input (if-let [piece (get-in x selectors)]
(if (empty? selectors)
(tapify piece tag)
(assoc-in x selectors (tapify piece tag)))
;; The call to `get-in` will return nil if the
;; `selectors` don't index correctly into the supplied
;; `input`, triggering this exception.
(u/illegal
(str "Bad selectors " selectors " for structure " x)))
output (d/with-active-tag tag f [input])
completed (d/extract-tangent output tag d/REVERSE-MODE)]
(if (empty? selectors)
(interpret inputs completed tag)
(interpret (get-in inputs selectors) completed tag)))))))
(-> (get-in input selectors)
(interpret completed tag)))))))

(defmethod g/zero-like [::tape] [_] 0)
(defmethod g/one-like [::tape] [_] 1)
Expand Down
37 changes: 27 additions & 10 deletions test/emmy/calculus/derivative_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -446,13 +446,27 @@
(testing "space"
(let [g (af/literal-function 'g [0 0] 0)
h (af/literal-function 'h [0 0] 0)]
(is (= '(+ (((partial 0) g) x y) (((partial 0) h) x y))
(simplify (((partial 0) (+ g h)) 'x 'y))))
(is (= '(+ (* (((partial 0) g) x y) (h x y)) (* (((partial 0) h) x y) (g x y)))
(simplify (((partial 0) (* g h)) 'x 'y))))
(is (= '(+ (* (((partial 0) g) x y) (h x y) (expt (g x y) (+ (h x y) -1)))
(* (((partial 0) h) x y) (log (g x y)) (expt (g x y) (h x y))))
(simplify (((partial 0) (g/expt g h)) 'x 'y))))))
(is (zero?
(simplify
(g/- (g/+ (((partial 0) g) 'x 'y)
(((partial 0) h) 'x 'y))
(((partial 0) (+ g h)) 'x 'y)))))
(is (zero?
(simplify
(g/-
(g/+ (g/* (((partial 0) g) 'x 'y) (h 'x 'y))
(g/* (((partial 0) h) 'x 'y) (g 'x 'y)))
(((partial 0) (* g h)) 'x 'y)))))
(is (zero?
(simplify
(g/-
(g/+ (g/* (((partial 0) g) 'x 'y)
(h 'x 'y)
(g/expt (g 'x 'y) (+ (h 'x 'y) -1)))
(g/* (((partial 0) h) 'x 'y)
(g/log (g 'x 'y))
(g/expt (g 'x 'y) (h 'x 'y))))
(((partial 0) (g/expt g h)) 'x 'y)))))))

(testing "operators"
(is (= '(down 1 1 1 1 1 1 1 1 1 1)
Expand Down Expand Up @@ -485,9 +499,12 @@
f3 (fn [x y] (* (tan x) (log y)))
f4 (fn [x y] (* (tan x) (sin y)))
f5 (fn [x y] (/ (tan x) (sin y)))]
(is (= '(down (* (log y) (cos x))
(/ (sin x) y))
(simplify ((D f2) 'x 'y))))
(is (= '(down 0 0)
(simplify
(g/- (s/down
(g/* (g/log 'y) (g/cos 'x))
(g// (g/sin 'x) 'y))
((D f2) 'x 'y)))))
(is (= '(down (/ (log y) (expt (cos x) 2))
(/ (tan x) y))
(simplify ((D f3) 'x 'y))))
Expand Down

0 comments on commit e7f8091

Please sign in to comment.