Skip to content

Commit

Permalink
Merge pull request #7 from tmoux/ematch
Browse files Browse the repository at this point in the history
Fix e-matching in generic.ml
  • Loading branch information
kiranandcode authored Jan 20, 2024
2 parents 58a1f74 + 4a6a337 commit 1d9a39a
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 22 deletions.
6 changes: 4 additions & 2 deletions lib/basic.ml
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,10 @@ module EGraph = struct
| _ -> Iter.empty
end
| p ->
Vector.to_iter (Id.Map.find classes eid)
|> concat_map (fun enode -> enode_matches p enode env) in
match Id.Map.find_opt classes eid with
| Some v -> Vector.to_iter v |> concat_map (fun enode -> enode_matches p enode env)
| None -> Iter.empty
in
(fun f -> Id.Map.iter (Fun.curry f) classes)
|> concat_map (fun (eid, _) ->
Iter.map (fun s -> (eid, s)) (match_in pattern eid StringMap.empty))
Expand Down
32 changes: 13 additions & 19 deletions lib/generic.ml
Original file line number Diff line number Diff line change
Expand Up @@ -421,38 +421,32 @@ struct

(* ** Matching *)
let ematch eg (classes: (Id.t L.shape, 'a) Vector.t Id.Map.t) pattern =
let concat_map f l = Iter.concat (Iter.map f l) in
let rec enode_matches p enode env =
match[@warning "-8"] p with
| Query.Q (f, _) when not @@ L.equal_op f (L.op enode) ->
None
Iter.empty
| Q (_, args) ->
(fun f -> List.iter2 (Fun.curry f) args (L.children enode))
|> Iter.fold_while (fun env (qvar, trm) ->
match env with
| None -> None, `Stop
| Some env ->
match match_in qvar trm env with
| Some _ as res -> res, `Continue
| None -> None, `Stop
) (Some env)
|> Iter.fold (fun envs (qvar, trm) ->
concat_map (fun env' -> match_in qvar trm env') envs) (Iter.singleton env)
and match_in p eid env =
let eid = find eg eid in
match p with
| V id -> begin
match StringMap.find_opt id env with
| None -> Some (StringMap.add id eid env)
| Some eid' when Id.eq_id eid eid' -> Some env
| _ -> None
| None -> Iter.singleton (StringMap.add id eid env)
| Some eid' when Id.eq_id eid eid' -> Iter.singleton env
| _ -> Iter.empty
end
| p ->
Option.bind (Id.Map.find_opt classes eid)
(fun v -> Vector.to_iter v |> Iter.find_map (fun enode -> enode_matches p enode env)) in
match Id.Map.find_opt classes eid with
| Some v -> Vector.to_iter v |> concat_map (fun enode -> enode_matches p enode env)
| None -> Iter.empty
in
(fun f -> Id.Map.iter (Fun.curry f) classes)
|> Iter.filter_map (fun (eid,_) ->
match match_in pattern eid StringMap.empty with
| Some env -> Some (eid, env)
| _ -> None
)
|> concat_map (fun (eid, _) ->
Iter.map (fun s -> (eid, s)) (match_in pattern eid StringMap.empty))

let find_matches eg =
let eclasses = eclasses eg in
Expand Down
2 changes: 1 addition & 1 deletion test/test_math.ml
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ let () =
check_proves_equal ~node_limit:(`Bounded 100_000) ~fuel:(`Bounded 35) rules
[%s (d x (1 + (2. * x)))] [%s 2. ];
"dx/dy of xy + 1 is y", `Quick,
check_extract ~node_limit:(`Unbounded) ~fuel:(`Bounded 100) rules
check_extract ~node_limit:(`Unbounded) ~fuel:(`Bounded 15) rules
[%s (d x (1. + (y * x)))] [%s y ];
"dx/dy of ln x is 1 / x", `Quick,
check_proves_equal ~node_limit:(`Bounded 100_000) ~fuel:(`Bounded 35) rules
Expand Down
21 changes: 21 additions & 0 deletions test/test_prop.ml
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,27 @@ let proves_cached ?(match_limit=1_000) ?(ban_length=5) ?node_limit ?fuel start g

let () =
Alcotest.run "prop" [
"ematch tests", [
"check matches after merging", `Quick,
(fun () -> let graph = EGraph.init () in
let n1 = EGraph.add_node graph (L.of_sexp [%s (x && z)]) in
let n2 = EGraph.add_node graph (L.of_sexp [%s (y && z)]) in
EGraph.merge graph n1 n2;
EGraph.rebuild graph;
let query = qf [%s "?a" && z] in
let matches = EGraph.find_matches (EGraph.freeze graph) query |> Iter.length in
Alcotest.(check int) "2 matches" 2 matches);

"check matches after saturating", `Quick,
fun () -> let graph = EGraph.init () in
let scheduler = Ego.Generic.Scheduler.Backoff.with_params ~match_limit:1000 ~ban_length:5 in
let _ = EGraph.add_node graph (L.of_sexp [%s (x && y)]) in
let query = [%s "?a" && "?b"] @-> [%s "?b" && "?a"] in
ignore @@ EGraph.run_until_saturation ~scheduler graph [query];
let q = qf [%s "?a" && "?b"] in
let matches = EGraph.find_matches (EGraph.freeze graph) q |> Iter.length in
Alcotest.(check int) "2 matches" 2 matches
];
"proving contrapositive", [
"proves idempotent", `Quick, proves [%s (x => y)] [[%s (x => y)]];
"proves negation", `Quick, proves [%s (x => y)] [[%s (x => y)];
Expand Down

0 comments on commit 1d9a39a

Please sign in to comment.