diff --git a/src/lib/frontend/translate.ml b/src/lib/frontend/translate.ml index 262eff96d..f462db546 100644 --- a/src/lib/frontend/translate.ml +++ b/src/lib/frontend/translate.ml @@ -523,8 +523,8 @@ and handle_ty_app ?(update = false) ty_c l = variable. *) let rec apply_ty_substs tysubsts ty = match ty with - | Ty.Tvar { v; _ } -> - Ty.M.find v tysubsts + | Ty.Tvar tv -> + Ty.Subst.eval tysubsts tv | Text (tyl, hs) -> Ty.Text (List.map (apply_ty_substs tysubsts) tyl, hs) @@ -561,9 +561,9 @@ and handle_ty_app ?(update = false) ty_c l = List.fold_left2 ( fun acc tv ty -> match tv with - | Ty.Tvar { v; _ } -> Ty.M.add v ty acc + | Ty.Tvar tv -> Ty.Subst.update tv ty acc | _ -> assert false - ) Ty.M.empty args tyl + ) Ty.Subst.id args tyl in apply_ty_substs tysubsts ty @@ -1687,7 +1687,7 @@ let make_form name_base f loc ~decl_kind = in assert (Var.Map.is_empty (E.free_vars ff Var.Map.empty)); let ff = E.purify_form ff in - if Ty.Svty.is_empty (E.free_type_vars ff) then ff + if Ty.TvSet.is_empty (E.free_type_vars ff) then ff else E.mk_forall name_base loc Var.Map.empty [] ff ~toplevel:true ~decl_kind @@ -1943,7 +1943,7 @@ let make dloc_file acc stmt = assert (Var.Map.is_empty (E.free_vars ff Var.Map.empty)); let ff = E.purify_form ff in let e = - if Ty.Svty.is_empty (E.free_type_vars ff) then ff + if Ty.TvSet.is_empty (E.free_type_vars ff) then ff else E.mk_forall name_base loc Var.Map.empty [] ff ~toplevel:true ~decl_kind @@ -1964,7 +1964,7 @@ let make dloc_file acc stmt = assert (Var.Map.is_empty (E.free_vars ff Var.Map.empty)); let ff = E.purify_form ff in let e = - if Ty.Svty.is_empty (E.free_type_vars ff) then ff + if Ty.TvSet.is_empty (E.free_type_vars ff) then ff else E.mk_forall name_base loc Var.Map.empty [] ff ~toplevel:true ~decl_kind diff --git a/src/lib/reasoners/matching.ml b/src/lib/reasoners/matching.ml index 8e9dedd3d..01bde9e92 100644 --- a/src/lib/reasoners/matching.ml +++ b/src/lib/reasoners/matching.ml @@ -112,7 +112,7 @@ module Make (X : Arg) : S with type theory = X.t = struct if Options.get_debug_matching() >= 3 then let print fmt Matching_types.{ sbs; sty; _ } = Format.fprintf fmt ">>> sbs= %a | sty= %a@ " - (SubstE.pp E.print) sbs Ty.print_subst sty + (SubstE.pp E.print) sbs Ty.Subst.pp sty in print_dbg ~module_name:"Matching" ~function_name:"match_pats_modulo" @@ -124,7 +124,7 @@ module Make (X : Arg) : S with type theory = X.t = struct print_dbg ~module_name:"Matching" ~function_name:"match_one_pat" "match_pat: %a with subst: sbs= %a | sty= %a" - E.print pat0 (SubstE.pp E.print) sbs Ty.print_subst sty + E.print pat0 (SubstE.pp E.print) sbs Ty.Subst.pp sty let match_one_pat_against Matching_types.{ sbs; sty; _ } pat0 t = @@ -136,14 +136,14 @@ module Make (X : Arg) : S with type theory = X.t = struct E.print pat0 E.print t (SubstE.pp E.print) sbs - Ty.print_subst sty + Ty.Subst.pp sty let match_term Matching_types.{ sbs; sty; _ } t pat = if Options.get_debug_matching() >= 3 then print_dbg ~module_name:"Matching" ~function_name:"match_term" "I match %a against %a with subst: sbs=%a | sty= %a" - E.print pat E.print t (SubstE.pp E.print) sbs Ty.print_subst sty + E.print pat E.print t (SubstE.pp E.print) sbs Ty.Subst.pp sty let match_list Matching_types.{ sbs; sty; _ } pats xs = if Options.get_debug_matching() >= 3 then @@ -153,7 +153,7 @@ module Make (X : Arg) : S with type theory = X.t = struct E.print_list pats E.print_list xs (SubstE.pp E.print) sbs - Ty.print_subst sty + Ty.Subst.pp sty let match_class_of t cl = if Options.get_debug_matching() >= 3 then @@ -177,7 +177,7 @@ module Make (X : Arg) : S with type theory = X.t = struct (fun gsbt -> print_dbg ~header:false ">>> sbs = %a and sbty = %a@ " - (SubstE.pp E.print) gsbt.sbs Ty.print_subst gsbt.sty + (SubstE.pp E.print) gsbt.sbs Ty.Subst.pp gsbt.sty )res end @@ -492,7 +492,7 @@ module Make (X : Arg) : S with type theory = X.t = struct else let egs = { sbs = SubstE.empty; - sty = Ty.esubst; + sty = Ty.Subst.id; gen = 0; goal = false; s_term_orig = []; diff --git a/src/lib/structures/expr.ml b/src/lib/structures/expr.ml index 7945afddd..31d148f25 100644 --- a/src/lib/structures/expr.ml +++ b/src/lib/structures/expr.ml @@ -42,7 +42,7 @@ and term_view = { bind : bind_kind; tag: int; vars : (Ty.t * int) Var.Map.t; (* vars to types and nb of occurences *) - vty : Ty.Svty.t; + vty : Ty.TvSet.t; depth: int; nb_nodes : int; pure : bool; @@ -145,11 +145,11 @@ let hash t = t.tag let uid t = t.tag let compare_subst (s_t1, s_ty1) (s_t2, s_ty2) = - let c = Ty.compare_subst s_ty1 s_ty2 in + let c = Ty.Subst.compare s_ty1 s_ty2 in if c<>0 then c else Var.Map.compare compare s_t1 s_t2 let equal_subst (s_t1, s_ty1) (s_t2, s_ty2) = - Ty.equal_subst s_ty1 s_ty2 || Var.Map.equal equal s_t1 s_t2 + Ty.Subst.equal s_ty1 s_ty2 || Var.Map.equal equal s_t1 s_t2 let compare_let let1 let2 = let c = Var.compare let1.let_v let2.let_v in @@ -244,13 +244,6 @@ module Msbt : Map.S with type key = expr Var.Map.t = let compare a b = Var.Map.compare compare a b end) -module Msbty : Map.S with type key = Ty.t Ty.M.t = - Map.Make - (struct - type t = Ty.t Ty.M.t - let compare a b = Ty.M.compare Ty.compare a b - end) - module TSet : Set.S with type elt = expr = Set.Make (struct type t = expr let compare = compare end) @@ -333,10 +326,6 @@ module SmtPrinter = struct | `Forall -> Fmt.pf ppf "forall" | `Exists -> Fmt.pf ppf "exists" - (* This printer follows the convention used to print - type variables in the module [Ty]. *) - let pp_tyvar ppf v = Fmt.pf ppf "A%d" v - let rec pp_main bind ppf { user_trs; main; binders; _ } = if not @@ Var.Map.is_empty binders then Fmt.pf ppf "@[<2>(%a (%a)@, %a@, %a)@]" @@ -348,9 +337,9 @@ module SmtPrinter = struct pp_boxed ppf main and pp_quantified bind ppf q = - if q.toplevel && not @@ Ty.Svty.is_empty q.main.vty then + if q.toplevel && not @@ Ty.TvSet.is_empty q.main.vty then Fmt.pf ppf "@[<2>(par (%a)@, %a)@]" - Fmt.(box @@ iter ~sep:sp Ty.Svty.iter pp_tyvar) q.main.vty + Fmt.(box @@ iter ~sep:sp Ty.TvSet.iter DE.Ty.Var.print) q.main.vty (pp_main bind) q else pp_main bind ppf q @@ -802,7 +791,7 @@ let free_type_vars t = t.vty let is_ground t = Var.Map.is_empty (free_vars t Var.Map.empty) && - Ty.Svty.is_empty (free_type_vars t) + Ty.TvSet.is_empty (free_type_vars t) let size t = t.nb_nodes @@ -876,7 +865,7 @@ let free_vars_non_form s l ty = | _, e::r -> List.fold_left (fun s t -> merge_vars s t.vars) e.vars r let free_type_vars_non_form l ty = - List.fold_left (fun acc t -> Ty.Svty.union acc t.vty) (Ty.vty_of ty) l + List.fold_left (fun acc t -> Ty.TvSet.union acc t.vty) (Ty.vty_of ty) l let is_ite s = match s with | Sy.Op Sy.Tite -> true @@ -960,7 +949,7 @@ let vrai = let res = let nb_nodes = 0 in let vars = Var.Map.empty in - let vty = Ty.Svty.empty in + let vty = Ty.TvSet.empty in let faux = HC.make {f = Sy.False; xs = []; ty = Ty.Tbool; depth = -2; (*smallest depth*) @@ -1040,7 +1029,7 @@ let mk_or f1 f2 is_impl = let d = (max f1.depth f2.depth) in (* the +1 causes regression *) let nb_nodes = f1.nb_nodes + f2.nb_nodes + 1 in let vars = merge_vars f1.vars f2.vars in - let vty = Ty.Svty.union f1.vty f2.vty in + let vty = Ty.TvSet.union f1.vty f2.vty in let pos = HC.make {f=Sy.Form (Sy.F_Clause is_impl); xs=[f1; f2]; ty=Ty.Tbool; depth=d; tag= -42; vars; vty; nb_nodes; neg = None; @@ -1070,7 +1059,7 @@ let mk_iff f1 f2 = let d = (max f1.depth f2.depth) in (* the +1 causes regression *) let nb_nodes = f1.nb_nodes + f2.nb_nodes + 1 in let vars = merge_vars f1.vars f2.vars in - let vty = Ty.Svty.union f1.vty f2.vty in + let vty = Ty.TvSet.union f1.vty f2.vty in let pos = HC.make {f=Sy.Form Sy.F_Iff; xs=[f1; f2]; ty=Ty.Tbool; depth=d; tag= -42; vars; vty; nb_nodes; neg = None; @@ -1157,7 +1146,7 @@ let mk_forall_ter = lemma. Otherwise (if not toplevel), the free vtys of the lemma are those of lem.main *) let vty = - if new_q.toplevel then Ty.Svty.empty + if new_q.toplevel then Ty.TvSet.empty else free_type_vars new_q.main in let vars = @@ -1192,7 +1181,7 @@ let no_occur_check v e = not (Var.Map.mem v e.vars) let no_vtys l = - List.for_all (fun e -> Ty.Svty.is_empty e.vty) l + List.for_all (fun e -> Ty.TvSet.is_empty e.vty) l (** smart constructors for literals *) @@ -1356,12 +1345,12 @@ let no_capture_issue s_t binders = end let rec apply_subst_aux (s_t, s_ty) t = - if is_ground t || (Var.Map.is_empty s_t && Ty.M.is_empty s_ty) then t + if is_ground t || (Var.Map.is_empty s_t && Ty.Subst.is_id s_ty) then t else let { f; xs; ty; vars; vty; bind; _ } = t in let s_t = Var.Map.filter (fun v _ -> Var.Map.mem v vars) s_t in - let s_ty = Ty.M.filter (fun tvar _ -> Ty.Svty.mem tvar vty) s_ty in - if Var.Map.is_empty s_t && Ty.M.is_empty s_ty then t + let s_ty = Ty.Subst.restrict vty s_ty in + if Var.Map.is_empty s_t && Ty.Subst.is_id s_ty then t else let s = s_t, s_ty in let xs', same = My_list.apply (apply_subst_aux s) xs in @@ -1494,14 +1483,14 @@ and mk_let_aux ({ let_v; let_e; in_e; _ } as x) = let _, nb_occ = Var.Map.find let_v in_e.vars in if nb_occ = 1 && (let_e.pure (*1*) || Sy.equal (Sy.var let_v) in_e.f) || is_value_term let_e then (* inline in these situations *) - apply_subst_aux (Var.Map.singleton let_v let_e, Ty.esubst) in_e + apply_subst_aux (Var.Map.singleton let_v let_e, Ty.Subst.id) in_e else let ty = type_info in_e in let d = max let_e.depth in_e.depth in (* no + 1 ? *) let nb_nodes = let_e.nb_nodes + in_e.nb_nodes + 1 (* approx *) in (* do not include free vars in let_sko that have been simplified *) let vars = merge_vars let_e.vars (Var.Map.remove let_v in_e.vars) in - let vty = Ty.Svty.union let_e.vty in_e.vty in + let vty = Ty.TvSet.union let_e.vty in_e.vty in let pos = HC.make {f=Sy.Let; xs=[]; ty; depth=d; tag= -42; vars; vty; nb_nodes; neg = None; @@ -1524,7 +1513,7 @@ and mk_forall_bis (q : quantified) = let binders = (* ignore binders that are not used in f *) Var.Map.filter (fun v _ -> Var.Map.mem v q.main.vars) q.binders in - if Var.Map.is_empty binders && Ty.Svty.is_empty q.main.vty then q.main + if Var.Map.is_empty binders && Ty.TvSet.is_empty q.main.vty then q.main else let q = {q with binders} in (* Attempt to reduce the number of quantifiers. We try to find a @@ -1536,7 +1525,7 @@ and mk_forall_bis (q : quantified) = | None -> mk_forall_ter q | Some sbs -> - let subst = sbs, Ty.esubst in + let subst = sbs, Ty.Subst.id in let f = apply_subst_aux subst q.main in if is_ground f then f else @@ -1582,7 +1571,7 @@ and find_particular_subst = in fun binders trs f -> (* TODO: move the test for `trs` outside. *) - if not (Ty.Svty.is_empty f.vty) || has_hypotheses trs || + if not (Ty.TvSet.is_empty f.vty) || has_hypotheses trs || has_semantic_triggers trs then None @@ -1593,12 +1582,12 @@ and find_particular_subst = Var.Map.fold (fun v ty sbt -> try - let f = apply_subst_aux (sbt, Ty.esubst) f in + let f = apply_subst_aux (sbt, Ty.Subst.id) f in find_subst v (mk_term (Sy.var v) [] ty) f; sbt with Found (x, t) -> assert (not (Var.Map.mem x sbt)); - let one_sbt = Var.Map.singleton x t, Ty.esubst in + let one_sbt = Var.Map.singleton x t, Ty.Subst.id in let sbt = Var.Map.map (apply_subst_aux one_sbt) sbt in Var.Map.add x t sbt ) @@ -1609,15 +1598,20 @@ and find_particular_subst = let apply_subst, clear_subst_cache = - let (cache : t Msbty.t Msbt.t TMap.t ref) = ref TMap.empty in - let apply_subst ((sbt, sbty) as s) f = + let (cache : t Ty.Subst.Map.t Msbt.t TMap.t ref) = ref TMap.empty in + let apply_subst ((sbt, (sbty : Ty.subst)) as s) f = let ch = !cache in - try TMap.find f ch |> Msbt.find sbt |> Msbty.find sbty + try TMap.find f ch |> Msbt.find sbt |> Ty.Subst.Map.find sbty with Not_found -> let nf = apply_subst_aux s f in - let c_sbt = try TMap.find f ch with Not_found -> Msbt.empty in - let c_sbty = try Msbt.find sbt c_sbt with Not_found -> Msbty.empty in - cache := TMap.add f (Msbt.add sbt (Msbty.add sbty nf c_sbty) c_sbt) ch; + let c_sbt = + try TMap.find f ch with Not_found -> Msbt.empty + in + let c_sbty = + try Msbt.find sbt c_sbt with Not_found -> Ty.Subst.Map.empty + in + cache := + TMap.add f (Msbt.add sbt (Ty.Subst.Map.add sbty nf c_sbty) c_sbt) ch; nf in let clear_subst_cache () = @@ -1699,7 +1693,7 @@ let resolution_of_literal a binders free_vty acc = match lit_view a with | Pred(t, _) -> let cond = - Ty.Svty.subset free_vty (free_type_vars t) && + Ty.TvSet.subset free_vty (free_type_vars t) && let vars = free_vars t Var.Map.empty in Var.Map.for_all (fun v _ -> Var.Map.mem v vars) binders in @@ -1781,8 +1775,8 @@ let resolution_triggers ~is_back { kind; main = f; binders; _ } = )cand [] let free_type_vars_as_types e = - Ty.Svty.fold - (fun i z -> Ty.Set.add (Ty.Tvar {Ty.v=i; value = None}) z) + Ty.TvSet.fold + (fun tv z -> Ty.Set.add (Ty.Tvar tv) z) (free_type_vars e) Ty.Set.empty @@ -1806,7 +1800,7 @@ let mk_let let_v let_e in_e = let skolemize { main = f; binders; sko_v; sko_vty; _ } = let print fmt ty = - assert (Ty.Svty.is_empty (Ty.vty_of ty)); + assert (Ty.TvSet.is_empty (Ty.vty_of ty)); Format.fprintf fmt "<%a>" Ty.print ty in let pp_sep_nospace fmt () = Format.fprintf fmt "" in @@ -1837,11 +1831,11 @@ let skolemize { main = f; binders; sko_v; sko_vty; _ } = (fun x ty m -> let i = Var.uid x in let t = mk_term (mk_sym i "_sko") sko_v ty in - let t = apply_subst (grounding_sbt, Ty.esubst) t in + let t = apply_subst (grounding_sbt, Ty.Subst.id) t in Var.Map.add x t m ) binders Var.Map.empty in - let res = apply_subst_aux (sbt, Ty.esubst) f in + let res = apply_subst_aux (sbt, Ty.Subst.id) f in assert (is_ground res); res @@ -1876,16 +1870,16 @@ let rec elim_let = (fun v (ty, _) sbt -> Var.Map.add v (fresh_name ty) sbt) (free_vars sko Var.Map.empty) Var.Map.empty in - apply_subst (sbt, Ty.esubst) sko + apply_subst (sbt, Ty.Subst.id) sko in fun ~recursive ~conjs subst { let_v; let_e; in_e; let_sko; _ } -> assert (Var.Map.mem let_v (free_vars in_e Var.Map.empty)); (* usefull when let_sko still contains variables that are not in ie_e due to simplification *) - let let_sko = apply_subst (subst, Ty.esubst) let_sko in + let let_sko = apply_subst (subst, Ty.Subst.id) let_sko in let let_sko = ground_sko let_sko in assert (is_ground let_sko); - let let_e = apply_subst (subst, Ty.esubst) let_e in + let let_e = apply_subst (subst, Ty.Subst.id) let_e in if let_sko.nb_nodes >= let_e.nb_nodes && let_e.pure then let subst = Var.Map.add let_v let_e subst in elim_let_rec subst in_e ~recursive ~conjs @@ -1902,7 +1896,7 @@ and elim_let_rec subst in_e ~recursive ~conjs = match form_view in_e with | Let letin when recursive -> elim_let ~recursive ~conjs subst letin | _ -> - let f = apply_subst (subst, Ty.esubst) in_e in + let f = apply_subst (subst, Ty.Subst.id) in_e in List.fold_left (fun acc func -> func acc) f conjs @@ -1928,13 +1922,11 @@ let elim_iff f1 f2 ~with_conj = module Triggers = struct - module Svty = Ty.Svty - (* Set of patterns with their sets of free term and type variables. *) module STRS = Set.Make( struct - type t = expr * Var.Set.t * Svty.t + type t = expr * Var.Set.t * Ty.TvSet.t let compare (t1,_,_) (t2,_,_) = compare t1 t2 end) @@ -2105,7 +2097,7 @@ module Triggers = struct fun l -> unique (List.stable_sort cmp_trig_term_list l) [] - let vty_of_term acc t = Svty.union acc t.vty + let vty_of_term acc t = Ty.TvSet.union acc t.vty let not_pure t = not t.pure @@ -2126,11 +2118,11 @@ module Triggers = struct variables. *) not (List.exists not_pure l) && let s1 = List.fold_left (vars_of_term bv) Var.Set.empty l in - let s2 = List.fold_left vty_of_term Svty.empty l in + let s2 = List.fold_left vty_of_term Ty.TvSet.empty l in (* TODO: we can replace `Var.Set.subset bv s1` by `Var.Seq.equal bv s1`. By construction `s1` is a subset of `bv`. *) - Var.Set.subset bv s1 && Svty.subset vty s2 ) + Var.Set.subset bv s1 && Ty.TvSet.subset vty s2 ) trs (* unused @@ -2142,8 +2134,8 @@ module Triggers = struct if List.exists not_pure l then failwith "If-Then-Else are not allowed in (theory triggers)"; let s1 = List.fold_left (vars_of_term bv) SSet.empty l in - let s2 = List.fold_left vty_of_term Svty.empty l in - if not (Svty.subset vty s2) || not (SSet.subset bv s1) then + let s2 = List.fold_left vty_of_term Ty.TvSet.empty l in + if not (Ty.TvSet.subset vty s2) || not (SSet.subset bv s1) then failwith "Triggers of a theory should contain every quantified \ types and variables.") trs; @@ -2171,7 +2163,7 @@ module Triggers = struct module SLLT = Set.Make( struct - type t = expr list * Var.Set.t * Svty.t + type t = expr list * Var.Set.t * Ty.TvSet.t let compare (a, y1, _) (b, y2, _) = let c = try compare_lists a b compare; 0 with Util.Cmp c -> c in if c <> 0 then c else Var.Set.compare y1 y2 @@ -2194,7 +2186,7 @@ module Triggers = struct )t.vars Var.Map.empty in if Var.Map.is_empty sbt then t - else apply_subst (sbt, Ty.esubst) t + else apply_subst (sbt, Ty.Subst.id) t in fun bv ((t,vt,vty) as e) -> let s = Var.Set.diff vt bv in @@ -2222,14 +2214,14 @@ module Triggers = struct let llt, llt_ok = SLLT.fold (fun (l, bv2, vty2) (llt, llt_ok) -> - if Var.Set.subset bv1 bv2 && Svty.subset vty1 vty2 then + if Var.Set.subset bv1 bv2 && Ty.TvSet.subset vty1 vty2 then (* t doesn't bring new vars *) llt, llt_ok else let bv3 = Var.Set.union bv2 bv1 in - let vty3 = Svty.union vty2 vty1 in + let vty3 = Ty.TvSet.union vty2 vty1 in let e = t::l, bv3, vty3 in - if Var.Set.subset bv bv3 && Svty.subset vty vty3 then + if Var.Set.subset bv bv3 && Ty.TvSet.subset vty vty3 then (* The multi-trigger [e] cover all the free variables [bv] and [vty]. *) llt, SLLT.add e llt_ok @@ -2258,17 +2250,17 @@ module Triggers = struct List.exists (fun (_, bv',vty') -> (Var.Set.subset bv bv' && not(Var.Set.equal bv bv') - && Svty.subset vty vty') - || (Svty.subset vty vty' && not(Svty.equal vty vty') + && Ty.TvSet.subset vty vty') + || (Ty.TvSet.subset vty vty' && not(Ty.TvSet.equal vty vty') && Var.Set.subset bv bv') ) l in fun bv_a vty_a l -> let rec simpl_rec acc = function | [] -> acc | ((_, bv, vty) as e)::l -> if strict_subset bv vty l || strict_subset bv vty acc || - (Var.Set.subset bv_a bv && Svty.subset vty_a vty) || + (Var.Set.subset bv_a bv && Ty.TvSet.subset vty_a vty) || (Var.Set.equal (Var.Set.inter bv_a bv) Var.Set.empty && - Svty.equal (Svty.inter vty_a vty) Svty.empty) + Ty.TvSet.equal (Ty.TvSet.inter vty_a vty) Ty.TvSet.empty) then simpl_rec acc l else simpl_rec (e::acc) l in @@ -2294,7 +2286,7 @@ module Triggers = struct and [vtype]. *) let mono = List.filter (fun (_, bv_t, vty_t) -> - Var.Set.subset vterm bv_t && Svty.subset vtype vty_t) trs + Var.Set.subset vterm bv_t && Ty.TvSet.subset vtype vty_t) trs in let trs_v, trs_nv = List.partition (fun (t, _, _) -> is_var t) mono in let base = if menv.Util.triggers_var then trs_nv @ trs_v else trs_nv in @@ -2383,7 +2375,7 @@ module Triggers = struct Var.Map.exists (fun e _ -> Var.Set.mem e bv) bv_lf in let has_tyvar vty vty_lf = - Svty.exists (fun e -> Svty.mem e vty) vty_lf + Ty.TvSet.exists (fun e -> Ty.TvSet.mem e vty) vty_lf in let args_of e lets = match e.bind with @@ -2462,7 +2454,7 @@ module Triggers = struct let sbt = Var.Map.fold (fun v { let_e; _ } sbt -> - let let_e = apply_subst (sbt, Ty.esubst) let_e in + let let_e = apply_subst (sbt, Ty.Subst.id) let_e in if let_e.pure then Var.Map.add v let_e sbt else sbt [@ocaml.ppwarning "TODO: once 'let x = term in term' \ @@ -2471,7 +2463,7 @@ module Triggers = struct depending on the ordering of vars in lets"] ) lets Var.Map.empty in - let sbs = sbt, Ty.esubst in + let sbs = sbt, Ty.Subst.id in STRS.fold (fun (e, _, _) strs -> let e = apply_subst sbs e in @@ -2479,9 +2471,9 @@ module Triggers = struct )terms terms let check_user_triggers f toplevel binders trs0 ~decl_kind = - if Var.Map.is_empty binders && Ty.Svty.is_empty f.vty then trs0 + if Var.Map.is_empty binders && Ty.TvSet.is_empty f.vty then trs0 else - let vtype = if toplevel then f.vty else Ty.Svty.empty in + let vtype = if toplevel then f.vty else Ty.TvSet.empty in let vterm = Var.Map.fold (fun v _ s -> Var.Set.add v s) binders Var.Set.empty in @@ -2498,7 +2490,7 @@ module Triggers = struct filter_good_triggers (vterm, vtype) trs0 let make f binders decl_kind mconf = - if Var.Map.is_empty binders && Ty.Svty.is_empty f.vty then [] + if Var.Map.is_empty binders && Ty.TvSet.is_empty f.vty then [] else let vtype = f.vty in let vterm = @@ -2604,7 +2596,7 @@ let mk_forall name loc binders trs f ~toplevel ~decl_kind = user_trs = trs; main = f; sko_v; sko_vty; kind = decl_kind} let mk_exists name loc binders trs f ~toplevel ~decl_kind = - if not toplevel || Ty.Svty.is_empty f.vty then + if not toplevel || Ty.TvSet.is_empty f.vty then neg (mk_forall name loc binders trs (neg f) ~toplevel ~decl_kind) else (* If there are type variables in a toplevel exists: 1 - we add @@ -2623,7 +2615,7 @@ let rec compile_match mk_destr mker e cases accu = | [] -> accu | (Typed.Var x, p) :: _ -> - apply_subst ((Var.Map.singleton x e), Ty.esubst) p + apply_subst ((Var.Map.singleton x e), Ty.Subst.id) p | (Typed.Constr {name; args}, p) :: l -> let _then = diff --git a/src/lib/structures/expr.mli b/src/lib/structures/expr.mli index bdb726682..308ddb5a2 100644 --- a/src/lib/structures/expr.mli +++ b/src/lib/structures/expr.mli @@ -49,7 +49,7 @@ type term_view = private { (** Map of free term variables in the term to their type and number of occurrences. *) - vty : Ty.Svty.t; + vty : Ty.TvSet.t; (** Map of free type variables in the term. *) depth: int; @@ -222,7 +222,7 @@ val compare_let : letin -> letin -> int (** Some auxiliary functions *) val free_vars : t -> (Ty.t * int) Var.Map.t -> (Ty.t * int) Var.Map.t -val free_type_vars : t -> Ty.Svty.t +val free_type_vars : t -> Ty.TvSet.t val is_ground : t -> bool val size : t -> int val depth : t -> int diff --git a/src/lib/structures/ty.ml b/src/lib/structures/ty.ml index 31716e50a..75374c96f 100644 --- a/src/lib/structures/ty.ml +++ b/src/lib/structures/ty.ml @@ -27,6 +27,11 @@ module DE = Dolmen.Std.Expr +module TvSet = Set.Make (DE.Ty.Var) +module TvMap = Map.Make (DE.Ty.Var) + +type tvar = DE.ty_var + type t = | Tint | Treal @@ -38,8 +43,6 @@ type t = | Tadt of DE.ty_cst * t list | Trecord of trecord -and tvar = { v : int ; mutable value : t option } - and trecord = { mutable args : t list; name : DE.ty_cst; @@ -62,14 +65,12 @@ module Smtlib = struct | Text (args, name) | Trecord { args; name; _ } | Tadt (name, args) -> Fmt.(pf ppf "(@[%a %a@])" DE.Ty.Const.print name (list ~sep:sp pp) args) - | Tvar { v; value = None; _ } -> Fmt.pf ppf "A%d" v - | Tvar { value = Some t; _ } -> pp ppf t + | Tvar tv -> DE.Ty.Var.print ppf tv end let pp_smtlib = Smtlib.pp -exception TypeClash of t*t -exception Shorten of t +exception TypeClash of t * t type adt_constr = { constr : DE.term_cst ; @@ -96,7 +97,6 @@ let assoc_destrs hs cases = (*** pretty print ***) let print_generic body_of = - let h = Hashtbl.create 17 in let rec print = let open Format in fun body_of fmt -> function @@ -104,17 +104,7 @@ let print_generic body_of = | Treal -> fprintf fmt "real" | Tbool -> fprintf fmt "bool" | Tbitv n -> fprintf fmt "bitv[%d]" n - | Tvar{v=v ; value = None} -> fprintf fmt "'a_%d" v - | Tvar{v=v ; value = Some (Trecord { args = l; name = n; _ } as t) } -> - if Hashtbl.mem h v then - fprintf fmt "%a %a" print_list l DE.Ty.Const.print n - else - (Hashtbl.add h v (); - (*fprintf fmt "('a_%d->%a)" v print t *) - print body_of fmt t) - | Tvar{ value = Some t; _ } -> - (*fprintf fmt "('a_%d->%a)" v print t *) - print body_of fmt t + | Tvar tv -> fprintf fmt "'a_%a" DE.Ty.Var.print tv | Text(l, s) when l == [] -> fprintf fmt "%a" DE.Ty.Const.print s | Text(l,s) -> @@ -187,51 +177,11 @@ let print_generic body_of = let print_list = snd (print_generic None) let print = fst (print_generic None) None - -let fresh_var = - let cpt = ref (-1) in - fun () -> incr cpt; {v= !cpt ; value = None } - -let fresh_tvar () = Tvar (fresh_var ()) - -let rec shorten ty = - match ty with - | Tvar { value = None; _ } -> ty - | Tvar { value = Some (Tvar{ value = None; _ } as t'); _ } -> t' - | Tvar ({ value = Some (Tvar t2); _ } as t1) -> - t1.value <- t2.value; shorten ty - | Tvar { value = Some t'; _ } -> shorten t' - - | Text (l,s) -> - let l, same = My_list.apply shorten l in - if same then ty else Text(l,s) - - | Tfarray (t1,t2) -> - let t1' = shorten t1 in - let t2' = shorten t2 in - if t1 == t1' && t2 == t2' then ty - else Tfarray(t1', t2') - - | Trecord r -> - r.args <- List.map shorten r.args; - r.lbs <- List.map (fun (lb, ty) -> lb, shorten ty) r.lbs; - ty - - | Tadt (n, args) -> - let args' = List.map shorten args in - shorten_body n args; - (* should not rebuild the type if no changes are made *) - Tadt (n, args') - - | Tint | Treal | Tbool | Tbitv _ -> ty - -and shorten_body _ _ = - () - [@ocaml.ppwarning "TODO: should be implemented ?"] +let fresh_tvar () = Tvar (DE.Ty.Var.mk "A") let rec compare t1 t2 = - match shorten t1 , shorten t2 with - | Tvar{ v = v1; _ } , Tvar{ v = v2; _ } -> Int.compare v1 v2 + match t1, t2 with + | Tvar v1, Tvar v2 -> DE.Ty.Var.compare v1 v2 | Tvar _, _ -> -1 | _ , Tvar _ -> 1 | Text(l1, s1) , Text(l2, s2) -> let c = DE.Ty.Const.compare s1 s2 in @@ -282,8 +232,8 @@ and compare_list l1 l2 = match l1, l2 with let rec equal t1 t2 = t1 == t2 || - match shorten t1 , shorten t2 with - | Tvar{ v = v1; _ }, Tvar{ v = v2; _ } -> v1 = v2 + match t1, t2 with + | Tvar v1, Tvar v2 -> DE.Ty.Var.equal v1 v2 | Text(l1, s1), Text(l2, s2) -> (try DE.Ty.Const.equal s1 s2 && List.for_all2 equal l1 l2 with Invalid_argument _ -> false) @@ -311,18 +261,64 @@ let rec equal t1 t2 = | _ -> false -(*** matching with a substitution mechanism ***) -module M = Util.MI -type subst = t M.t +module Subst = struct + type subst = t TvMap.t + + let id = TvMap.empty + + let is_id sbt = TvMap.is_empty sbt + + let eval sbt tv = + match TvMap.find tv sbt with + | ty -> ty + | exception Not_found -> Tvar tv + + let update tv ty sbt = + match ty with + | Tvar tv' when DE.Ty.Var.equal tv tv' -> sbt + | _ -> TvMap.add tv ty sbt + + let try_bind tv ty sbt = + match ty with + | Tvar tv' when DE.Ty.Var.equal tv tv' -> sbt + | _ -> + TvMap.update tv + (fun ty_opt -> + match ty_opt with + | None -> Some ty + | Some t when equal t ty -> Some t + | Some t -> raise (TypeClash (t, ty))) + sbt + + let in_domain = TvMap.mem + + let restrict set sbt = + TvMap.filter (fun tv _ -> TvSet.mem tv set) sbt + + let compare = TvMap.compare compare + + let equal = TvMap.equal equal + + let pp = + let sep ppf () = Fmt.pf ppf " -> " in + Fmt.(box @@ braces + @@ iter_bindings ~sep:comma TvMap.iter + @@ pair ~sep DE.Ty.Var.print print) + + module Map = + Map.Make + (struct + type t = subst + let compare = compare + end) +end -let esubst = M.empty +type subst = Subst.subst +(*** matching with a substitution mechanism ***) let rec matching s pat t = - match pat , t with - | Tvar {v=n;value=None} , _ -> - (try if not (equal (M.find n s) t) then raise (TypeClash(pat,t)); s - with Not_found -> M.add n t s) - | Tvar { value = _; _ }, _ -> raise (Shorten pat) + match pat, t with + | Tvar tv, _ -> Subst.try_bind tv t s | Text (l1,s1) , Text (l2,s2) when DE.Ty.Const.equal s1 s2 -> List.fold_left2 matching s l1 l2 | Tfarray (ta1,ta2), Tfarray (tb1,tb2) -> @@ -341,8 +337,8 @@ let rec matching s pat t = let apply_subst = let rec apply_subst s ty = match ty with - | Tvar { v= n; _ } -> - (try M.find n s with Not_found -> ty) + | Tvar tv -> + Subst.eval s tv | Text (l,e) -> let l, same = My_list.apply (apply_subst s) l in @@ -370,23 +366,17 @@ let apply_subst = | Tint | Treal | Tbool | Tbitv _ -> ty in - fun s ty -> if M.is_empty s then ty else apply_subst s ty + fun s ty -> + if Subst.is_id s then ty else apply_subst s ty -(* Assume that [shorten] have been applied on [ty]. *) let rec fresh ty subst = match ty with - | Tvar { value = Some _; _ } -> - (* This case is eliminated by the normalization performed - in [shorten]. *) - assert false - - | Tvar { v= x; _ } -> - begin - try M.find x subst, subst - with Not_found -> - let nv = Tvar (fresh_var()) in - nv, M.add x nv subst - end + | Tvar tv -> + if Subst.in_domain tv subst then + Subst.eval subst tv, subst + else + let ntv = fresh_tvar () in + ntv, Subst.update tv ntv subst | Text (args, n) -> let args, subst = fresh_list args subst in Text (args, n), subst @@ -408,16 +398,15 @@ let rec fresh ty subst = Tadt (s, args), subst | t -> t, subst -(* Assume that [shorten] have been applied on [lty]. *) and fresh_list lty subst = List.fold_right (fun ty (lty, subst) -> let ty, subst = fresh ty subst in ty::lty, subst) lty ([], subst) -let fresh ty subst = fresh (shorten ty) subst +let fresh ty subst = fresh ty subst -let fresh_list lty subst = fresh_list (List.map shorten lty) subst +let fresh_list lty subst = fresh_list lty subst module Decls = struct @@ -439,7 +428,7 @@ module Decls = struct let fresh_type params cases = - let params, subst = fresh_list params esubst in + let params, subst = fresh_list params Subst.id in let _subst, cases = List.fold_left (fun (subst, cases) {constr; destrs} -> @@ -480,16 +469,13 @@ module Decls = struct try List.fold_left2 (fun sbt vty ty -> - let vty = shorten vty in match vty with - | Tvar { value = Some _ ; _ } -> assert false - | Tvar {v ; value = None} -> - if equal vty ty then sbt else M.add v ty sbt + | Tvar tv -> Subst.update tv ty sbt | _ -> Printer.print_err "vty = %a and ty = %a" print vty print ty; assert false - )M.empty params args + ) Subst.id params args with Invalid_argument _ -> assert false in let cases = @@ -564,7 +550,7 @@ let trecord ~record_constr lv name lbs = let rec hash t = match t with - | Tvar{ v; _ } -> v + | Tvar tv -> DE.Ty.Var.hash tv | Text(l,s) -> abs (List.fold_left (fun acc x-> acc*19 + hash x) (DE.Ty.Const.hash s) l) | Tfarray (t1,t2) -> 19 * (hash t1) + 23 * (hash t2) @@ -584,10 +570,6 @@ let rec hash t = | _ -> Hashtbl.hash t -let compare_subst = M.compare compare - -let equal_subst = M.equal equal - module Svty = Util.SI module Set = @@ -599,10 +581,9 @@ module Set = let vty_of t = let rec vty_of_rec acc t = - let t = shorten t in match t with - | Tvar { v = i ; value = None } -> Svty.add i acc - | Text(l,_) -> List.fold_left vty_of_rec acc l + | Tvar tv -> TvSet.add tv acc + | Text (l,_) -> List.fold_left vty_of_rec acc l | Tfarray (t1,t2) -> vty_of_rec (vty_of_rec acc t1) t2 | Trecord { args; lbs; _ } -> let acc = List.fold_left vty_of_rec acc args in @@ -610,16 +591,10 @@ let vty_of t = | Tadt(_, args) -> List.fold_left vty_of_rec acc args - | Tvar { value = Some _ ; _ } | Tint | Treal | Tbool | Tbitv _ -> acc in - vty_of_rec Svty.empty t - -let print_subst = - let sep ppf () = Fmt.pf ppf " -> " in - Fmt.(box @@ braces - @@ iter_bindings ~sep:comma M.iter (pair ~sep int print)) + vty_of_rec TvSet.empty t let print_full = fst (print_generic (Some type_body)) (Some type_body) diff --git a/src/lib/structures/ty.mli b/src/lib/structures/ty.mli index 4bb69218b..f7dfcdfc7 100644 --- a/src/lib/structures/ty.mli +++ b/src/lib/structures/ty.mli @@ -31,6 +31,12 @@ (** {2 Definition} *) +type tvar = Dolmen.Std.Expr.ty_var +(** Type of type variable. *) + +module TvSet : Set.S with type elt = tvar +module TvMap : Map.S with type key = tvar + type t = | Tint (** Integer numbers *) @@ -63,17 +69,6 @@ type t = | Trecord of trecord (** Record type. *) -and tvar = { - v : int; - (** Unique identifier *) - mutable value : t option; - (** Pointer to the current value of the type variable. *) -} -(** Type variables. - The [value] field is mutated during unification, - hence distinct types should have disjoints sets of - type variables (see function {!val:fresh}). *) - and trecord = { mutable args : t list; (** Arguments passed to the record constructor *) @@ -102,13 +97,9 @@ type adt_constr = for recursive ADTs *) type type_body = adt_constr list -module Svty : Set.S with type elt = int -(** Sets of type variables, indexed by their identifier. *) - module Set : Set.S with type elt = t (** Sets of types *) - val assoc_destrs : Dolmen.Std.Expr.term_cst -> adt_constr list -> @@ -144,7 +135,7 @@ val print_list : Format.formatter -> t list -> unit val print_full : Format.formatter -> t -> unit (** Print function including the record fields. *) -val vty_of : t -> Svty.t +val vty_of : t -> TvSet.t (** Returns the set of type variables that occur in a given type. *) @@ -153,10 +144,6 @@ val vty_of : t -> Svty.t val tunit : t (** The unit type. *) -val fresh_var : unit -> tvar -(** Generate a fresh type variable, guaranteed to be distinct - from any other previously generated by this function. *) - val fresh_tvar : unit -> t (** Wrap the {!val:fresh_var} function to return a type. *) @@ -193,24 +180,62 @@ val trecord : (** {2 Substitutions} *) -module M : Map.S with type key = int -(** Maps from type variables identifiers. *) +module Subst : sig + type subst + (** Type of substitution from type variables to types. + + A substitution is equal to the identity substitution but for a finite + number of type variables. + + The domain of the substitution is the set of types variable that + are not sent on itself. *) + + val id : subst + (** The identity substitution. *) + + val is_id : subst -> bool + (** Check if the substitution is the identify substitution. *) + + val eval : subst -> tvar -> t + (** [eval sbt tv] returns the value of the substitution for [tv]. *) + + val update : tvar -> t -> subst -> subst + (** [update tv ty sbt] replaces the value of [sbt] for [tv] by [ty]. + + If the previous value of [tv] in [sbt] is equal to [ty] for [equal], + the returned substitution is physically equal to [sbt]. *) + + val try_bind : tvar -> t -> subst -> subst + (** [try_bind tv ty sbt] tries to bind [tv] with [ty]. The function + succeeds if [tv] is not in the domain of [sbt] or the current + value of [tv] in [sbt] is equal to [ty]. + + If the current value of [tv] is not compatible with [ty], raises + the exception {!exception TypeClash}. + + If the previous value of [tv] in [sbt] is equal to [ty] for [equal], + the returned substitution is physically equal to [sbt]. *) + + val in_domain : tvar -> subst -> bool + (** [in_domain tv sbt] checks if [tv] is in the domain of [sbt]. *) + + val restrict : TvSet.t -> subst -> subst + (** [restrict set sbt] returns a substitution that is equal to [sbt] on + the set [set] and the identity otherwise. *) -type subst = t M.t -(** The type of substitution, i.e. maps - from type variables identifiers to types.*) + val compare : subst -> subst -> int + (** Comparison of substitutions. *) -val compare_subst : subst -> subst -> int -(** Comparison of substitutions. *) + val equal : subst -> subst -> bool + (** Equality of substitutions. *) -val equal_subst : subst -> subst -> bool -(** Equality of substitutions. *) + val pp : subst Fmt.t + (** Print function for substitutions. *) -val print_subst: Format.formatter -> subst -> unit -(** Print function for substitutions. *) + module Map : Map.S with type key = subst +end -val esubst : subst -(** The empty substitution, a.k.a. the identity. *) +type subst = Subst.subst val apply_subst : subst -> t -> t (** Substitution application. *)