diff --git a/lib/basic.ml b/lib/basic.ml index 92a792e..38272e8 100644 --- a/lib/basic.ml +++ b/lib/basic.ml @@ -410,8 +410,10 @@ module EGraph = struct | _ -> Iter.empty end | p -> - Vector.to_iter (Id.Map.find classes eid) - |> concat_map (fun enode -> enode_matches p enode env) in + match Id.Map.find_opt classes eid with + | Some v -> Vector.to_iter v |> concat_map (fun enode -> enode_matches p enode env) + | None -> Iter.empty + in (fun f -> Id.Map.iter (Fun.curry f) classes) |> concat_map (fun (eid, _) -> Iter.map (fun s -> (eid, s)) (match_in pattern eid StringMap.empty)) diff --git a/lib/generic.ml b/lib/generic.ml index ce32854..8c54ede 100644 --- a/lib/generic.ml +++ b/lib/generic.ml @@ -421,38 +421,32 @@ struct (* ** Matching *) let ematch eg (classes: (Id.t L.shape, 'a) Vector.t Id.Map.t) pattern = + let concat_map f l = Iter.concat (Iter.map f l) in let rec enode_matches p enode env = match[@warning "-8"] p with | Query.Q (f, _) when not @@ L.equal_op f (L.op enode) -> - None + Iter.empty | Q (_, args) -> (fun f -> List.iter2 (Fun.curry f) args (L.children enode)) - |> Iter.fold_while (fun env (qvar, trm) -> - match env with - | None -> None, `Stop - | Some env -> - match match_in qvar trm env with - | Some _ as res -> res, `Continue - | None -> None, `Stop - ) (Some env) + |> Iter.fold (fun envs (qvar, trm) -> + concat_map (fun env' -> match_in qvar trm env') envs) (Iter.singleton env) and match_in p eid env = let eid = find eg eid in match p with | V id -> begin match StringMap.find_opt id env with - | None -> Some (StringMap.add id eid env) - | Some eid' when Id.eq_id eid eid' -> Some env - | _ -> None + | None -> Iter.singleton (StringMap.add id eid env) + | Some eid' when Id.eq_id eid eid' -> Iter.singleton env + | _ -> Iter.empty end | p -> - Option.bind (Id.Map.find_opt classes eid) - (fun v -> Vector.to_iter v |> Iter.find_map (fun enode -> enode_matches p enode env)) in + match Id.Map.find_opt classes eid with + | Some v -> Vector.to_iter v |> concat_map (fun enode -> enode_matches p enode env) + | None -> Iter.empty + in (fun f -> Id.Map.iter (Fun.curry f) classes) - |> Iter.filter_map (fun (eid,_) -> - match match_in pattern eid StringMap.empty with - | Some env -> Some (eid, env) - | _ -> None - ) + |> concat_map (fun (eid, _) -> + Iter.map (fun s -> (eid, s)) (match_in pattern eid StringMap.empty)) let find_matches eg = let eclasses = eclasses eg in diff --git a/test/test_math.ml b/test/test_math.ml index 9028b66..3c376b7 100644 --- a/test/test_math.ml +++ b/test/test_math.ml @@ -416,7 +416,7 @@ let () = check_proves_equal ~node_limit:(`Bounded 100_000) ~fuel:(`Bounded 35) rules [%s (d x (1 + (2. * x)))] [%s 2. ]; "dx/dy of xy + 1 is y", `Quick, - check_extract ~node_limit:(`Unbounded) ~fuel:(`Bounded 100) rules + check_extract ~node_limit:(`Unbounded) ~fuel:(`Bounded 15) rules [%s (d x (1. + (y * x)))] [%s y ]; "dx/dy of ln x is 1 / x", `Quick, check_proves_equal ~node_limit:(`Bounded 100_000) ~fuel:(`Bounded 35) rules diff --git a/test/test_prop.ml b/test/test_prop.ml index 347cb7f..f61791b 100644 --- a/test/test_prop.ml +++ b/test/test_prop.ml @@ -251,6 +251,27 @@ let proves_cached ?(match_limit=1_000) ?(ban_length=5) ?node_limit ?fuel start g let () = Alcotest.run "prop" [ + "ematch tests", [ + "check matches after merging", `Quick, + (fun () -> let graph = EGraph.init () in + let n1 = EGraph.add_node graph (L.of_sexp [%s (x && z)]) in + let n2 = EGraph.add_node graph (L.of_sexp [%s (y && z)]) in + EGraph.merge graph n1 n2; + EGraph.rebuild graph; + let query = qf [%s "?a" && z] in + let matches = EGraph.find_matches (EGraph.freeze graph) query |> Iter.length in + Alcotest.(check int) "2 matches" 2 matches); + + "check matches after saturating", `Quick, + fun () -> let graph = EGraph.init () in + let scheduler = Ego.Generic.Scheduler.Backoff.with_params ~match_limit:1000 ~ban_length:5 in + let _ = EGraph.add_node graph (L.of_sexp [%s (x && y)]) in + let query = [%s "?a" && "?b"] @-> [%s "?b" && "?a"] in + ignore @@ EGraph.run_until_saturation ~scheduler graph [query]; + let q = qf [%s "?a" && "?b"] in + let matches = EGraph.find_matches (EGraph.freeze graph) q |> Iter.length in + Alcotest.(check int) "2 matches" 2 matches + ]; "proving contrapositive", [ "proves idempotent", `Quick, proves [%s (x => y)] [[%s (x => y)]]; "proves negation", `Quick, proves [%s (x => y)] [[%s (x => y)];