Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dt] #272 rebased #274

Merged
merged 6 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 19 additions & 20 deletions src/discrimination_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ let mkListTailVariable = encode ~k:kOther ~data:3 ~arity:0
let mkListHead = encode ~k:kOther ~data:4 ~arity:0
let mkListEnd = encode ~k:kOther ~data:5 ~arity:0
let mkPathEnd = encode ~k:kOther ~data:6 ~arity:0
let mkListTailVariableUnif = encode ~k:kOther ~data:7 ~arity:0


let isVariable x = x == mkVariable
Expand All @@ -63,6 +64,7 @@ let isOutput x = x == mkOutputMode
let isListHead x = x == mkListHead
let isListEnd x = x == mkListEnd
let isListTailVariable x = x == mkListTailVariable
let isListTailVariableUnif x = x == mkListTailVariableUnif
let isPathEnd x = x == mkPathEnd

type cell = int
Expand All @@ -82,26 +84,21 @@ let pp_cell fmt n =
else if isListHead n then "ListHead"
else if isListEnd n then "ListEnd"
else if isPathEnd n then "PathEnd"
else "Other")
else if isListTailVariableUnif n then "ListTailVariableUnif"
else if isLam n then "Other"
else failwith "Invalid path construct...")
else if k == kPrimitive then Format.fprintf fmt "Primitive"
else Format.fprintf fmt "%o" k

let show_cell n = Format.asprintf "%a" pp_cell n
module Path : sig
type t
val pp : Format.formatter -> t -> unit
val get : t -> int -> cell
type builder
val make : int -> cell -> builder
val emit : builder -> cell -> unit
val stop : builder -> t
val of_list : cell list -> t

end = struct

module Path = struct
type t = cell array [@@deriving show]
let get a i = a.(i)


type builder = { mutable pos : int; mutable path : cell array }
let get_builder_pos {pos} = pos
let make size e = { pos = 0; path = Array.make size e }
let rec emit p e =
let len = Array.length p.path in
Expand Down Expand Up @@ -191,7 +188,7 @@ module Trie = struct
let t' = match other with None -> empty | Some x -> x in
let t'' = ins ~pos:(pos+1) t' in
Node { t with other = Some t'' }
| Node ({ listTailVariable } as t) when isListTailVariable x ->
| Node ({ listTailVariable } as t) when isListTailVariable x || isListTailVariableUnif x ->
let t' = match listTailVariable with None -> empty | Some x -> x in
let t'' = ins ~pos:(pos+1) t' in
Node { t with listTailVariable = Some t'' }
Expand Down Expand Up @@ -229,7 +226,7 @@ end

let update_par_count n k =
if isListHead k then n + 1 else
if isListEnd k || isListTailVariable k then n - 1 else n
if isListEnd k || isListTailVariable k || isListTailVariableUnif k then n - 1 else n

let skip ~pos path (*hd tl*) : int =
let rec aux_list acc p =
Expand Down Expand Up @@ -285,17 +282,18 @@ let skip_listTailVariable ~pos path : int =
In the example it is no needed to index the goal path to depth 100, but rather considering
the maximal depth of the first argument, which 4 << 100
*)
type 'a t = {t: 'a Trie.t; max_size : int; max_depths : int array }
type 'a t = {t: 'a Trie.t; max_size : int; max_depths : int array; max_list_length: int }

let pp pp_a fmt { t } : unit = Trie.pp (fun fmt data -> pp_a fmt data) fmt t
let show pp_a { t } : string = Trie.show (fun fmt data -> pp_a fmt data) t

let index { t; max_size; max_depths } path data =
let index { t; max_size; max_depths; max_list_length = mll } ~max_list_length path data =
let t, m = Trie.add path data t in
{ t; max_size = max max_size m; max_depths }
{ t; max_size = max max_size m; max_depths; max_list_length = max max_list_length mll }

let max_path { max_size } = max_size
let max_depths { max_depths } = max_depths
let max_list_length { max_list_length } = max_list_length

(* the equivalent of skip, but on the index, thus the list of trees
that are rooted just after the term represented by the tree root
Expand Down Expand Up @@ -341,8 +339,8 @@ let rec retrieve ~pos ~add_result mode path tree : unit =
else if isInput hd || isOutput hd then
(* next argument, we update the mode *)
retrieve ~pos:(pos+1) ~add_result hd path tree
else if isListTailVariable hd then
let sub_tries = skip_to_listEnd mode tree in
else if isListTailVariable hd || isListTailVariableUnif hd then
let sub_tries = skip_to_listEnd (if isListTailVariableUnif hd then mkOutputMode else mode) tree in
List.iter (retrieve ~pos:(pos+1) ~add_result mode path) sub_tries
else begin
(* Here the constructor can be Constant, Primitive, Variable, Other, ListHead, ListEnd *)
Expand Down Expand Up @@ -389,7 +387,7 @@ and on_all_children ~pos ~add_result mode path map =

let empty_dt args_depth : 'a t =
let max_depths = Array.make (List.length args_depth) 0 in
{t = Trie.empty; max_depths; max_size = 0}
{t = Trie.empty; max_depths; max_size = 0; max_list_length=0}

let retrieve ~pos ~add_result path index =
let mode = Path.get path pos in
Expand Down Expand Up @@ -420,6 +418,7 @@ module Internal = struct
let isListHead = isListHead
let isListEnd = isListEnd
let isListTailVariable = isListTailVariable
let isListTailVariableUnif = isListTailVariableUnif
let isPathEnd = isPathEnd

end
6 changes: 5 additions & 1 deletion src/discrimination_tree.mli
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module Path : sig
val pp : Format.formatter -> t -> unit
val get : t -> int -> cell
type builder
val get_builder_pos : builder -> int
val make : int -> cell -> builder
val emit : builder -> cell -> unit
val stop : builder -> t
Expand All @@ -17,6 +18,7 @@ val mkLam : cell
val mkInputMode : cell
val mkOutputMode : cell
val mkListTailVariable : cell
val mkListTailVariableUnif : cell
val mkListHead : cell
val mkListEnd : cell
val mkPathEnd : cell
Expand All @@ -31,9 +33,10 @@ type 'a t

@note: in the elpi runtime, there are no two rule having the same [~time]
*)
val index : 'a t -> Path.t -> 'a -> 'a t
val index : 'a t -> max_list_length:int -> Path.t -> 'a -> 'a t

val max_path : 'a t -> int
val max_list_length : 'a t -> int
val max_depths : 'a t -> int array

val empty_dt : 'b list -> 'a t
Expand Down Expand Up @@ -85,5 +88,6 @@ module Internal: sig
val isListHead : cell -> bool
val isListEnd : cell -> bool
val isListTailVariable : cell -> bool
val isListTailVariableUnif : cell -> bool
val isPathEnd : cell -> bool
end
56 changes: 45 additions & 11 deletions src/runtime.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2507,8 +2507,32 @@ let hash_arg_list is_goal hd ~depth args mode spec =
let hash_clause_arg_list = hash_arg_list false
let hash_goal_arg_list = hash_arg_list true

(* returns the maximal length of any sub_list
this operation is done at compile time per each clause being index
for example the term (app['a','b',app['c','d','e'], 'f', app['a','b','c','d','e','f']])
has three lists L1 = ['a','b',app['c','d','e'], 'f', app['a','b','c','d','e','f']
L2 = ['c','d','e']
L3 = app['a','b','c','d','e','f']
and the longest one has length 6
*)
let rec max_list_length acc = function
| Nil -> acc
| Cons (a, (UVar (_, _, _) | AppUVar (_, _, _) | Arg (_, _) | AppArg (_, _) | Discard)) ->
let alen = max_list_length 0 a in
max (acc+2) alen
| Cons (a, b)->
let alen = max_list_length 0 a in
let blen = max_list_length (acc+1) b in
max blen alen
| App (_,x,xs) -> List.fold_left (fun acc x -> max acc (max_list_length 0 x)) acc (x::xs)
| Builtin (_, xs) -> List.fold_left (fun acc x -> max acc (max_list_length 0 x)) acc xs
| Lam t -> max_list_length acc t
| Discard | Const _ |CData _ | UVar (_, _, _) | AppUVar (_, _, _) | Arg (_, _) | AppArg (_, _) -> acc

let max_lists_length = List.fold_left (fun acc e -> max (max_list_length 0 e) acc) 0

(**
[arg_to_trie_path ~safe ~depth ~is_goal args arg_depths arg_modes]
[arg_to_trie_path ~safe ~depth ~is_goal args arg_depths arg_modes mp max_list_length]
returns the path represetation of a term to be used in indexing with trie.
args, args_depths and arg_modes are the lists of respectively the arguments, the
depths and the modes of the current term to be indexed.
Expand All @@ -2517,10 +2541,14 @@ let hash_goal_arg_list = hash_arg_list true
In the former case, each argument we add a special mkInputMode/mkOutputMode
node before each argument to be indexed. This special node is used during
instance retrival to know the mode of the current argument
- mp is the max_path size of any path in the index used to truncate the goal
- max_list_length is the length of the longest sublist in each term of args
*)
let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode mp : Discrimination_tree.Path.t =
let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode mp (max_list_length':int) : Discrimination_tree.Path.t =
let open Discrimination_tree in
let path = Path.make (max mp 8) mkPathEnd in

let path_length = mp + Array.length args_depths_ar + 1 in
let path = Path.make path_length mkPathEnd in

let current_ar_pos = ref 0 in
let current_user_depth = ref 0 in
Expand All @@ -2537,7 +2565,7 @@ let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode m
end
in

let rec list_to_trie_path ~safe ~depth ?(h=0) path_depth (len: int) (t: term) : unit =
let rec list_to_trie_path ~safe ~depth ~h path_depth (len: int) (t: term) : unit =
match deref_head ~depth t with
| Nil -> Path.emit path mkListEnd; update_current_min_depth path_depth
| Cons (a, b) ->
Expand All @@ -2549,10 +2577,10 @@ let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode m
(* has the node `app` with arity `1` as first*)
(* cell, then come the elment of the list *)
(* up to the 30^th elemebt *)
if h > 30 then (Path.emit path mkListEnd; update_current_min_depth path_depth)
if is_goal && h >= max_list_length' then (Path.emit path mkListTailVariableUnif; update_current_min_depth path_depth)
else
main ~safe ~depth a path_depth;
list_to_trie_path ~depth ~safe ~h:(h+1) path_depth (len+1) b
(main ~safe ~depth a path_depth;
list_to_trie_path ~depth ~safe ~h:(h+1) path_depth (len+1) b)

(* These cases can come from terms like `[_ | _]`, `[_ | A]` ... *)
| UVar _ | AppUVar _ | Arg _ | AppArg _ | Discard -> Path.emit path mkListTailVariable; update_current_min_depth path_depth
Expand All @@ -2578,6 +2606,8 @@ let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode m
(** gives the path representation of a term *)
and main ~safe ~depth t path_depth : unit =
if path_depth = 0 then update_current_min_depth path_depth
else if is_goal && Path.get_builder_pos path >= path_length then
(Path.emit path mkLam; update_current_min_depth path_depth)
else
let path_depth = path_depth - 1 in
match deref_head ~depth t with
Expand All @@ -2601,7 +2631,7 @@ let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode m
| Cons (x,xs) ->
Path.emit path mkListHead;
main ~safe ~depth x (path_depth + 1);
list_to_trie_path ~safe ~depth (path_depth + 1) 0 xs
list_to_trie_path ~safe ~depth ~h:1 (path_depth + 1) 0 xs

(** builds the sub-path of a sublist of arguments of the current clause *)
and make_sub_path arg_hd arg_tl arg_depth_hd arg_depth_tl mode_hd mode_tl =
Expand Down Expand Up @@ -2689,9 +2719,11 @@ let add_clause_to_snd_lvl_idx ~depth ~insert predicate clause = function
| IndexWithDiscriminationTree {mode; arg_depths; args_idx; } ->
let max_depths = Discrimination_tree.max_depths args_idx in
let max_path = Discrimination_tree.max_path args_idx in
let path = arg_to_trie_path ~depth ~safe:true ~is_goal:false clause.args arg_depths max_depths mode max_path in
let max_list_length = max_lists_length clause.args in
(* Format.printf "[%d] Going to index [%a]\n%!" max_list_length (pplist pp_term ",") clause.args; *)
let path = arg_to_trie_path ~depth ~safe:true ~is_goal:false clause.args arg_depths max_depths mode max_path max_list_length in
[%spy "dev:disc-tree:depth-path" ~rid pp_string "Inst: MaxDepths " (pplist pp_int "") (Array.to_list max_depths)];
let args_idx = Discrimination_tree.index args_idx path clause in
let args_idx = Discrimination_tree.index args_idx path clause ~max_list_length in
IndexWithDiscriminationTree {
mode; arg_depths;
args_idx = args_idx
Expand Down Expand Up @@ -2899,8 +2931,10 @@ let get_clauses ~depth predicate goal { index = { idx = m } } =
| IndexWithDiscriminationTree {arg_depths; mode; args_idx} ->
let max_depths = Discrimination_tree.max_depths args_idx in
let max_path = Discrimination_tree.max_path args_idx in
let (path: Discrimination_tree.Path.t) = arg_to_trie_path ~safe:false ~depth ~is_goal:true (trie_goal_args goal) arg_depths max_depths mode max_path in
let max_list_length = Discrimination_tree.max_list_length args_idx in
let path = arg_to_trie_path ~safe:false ~depth ~is_goal:true (trie_goal_args goal) arg_depths max_depths mode max_path max_list_length in
[%spy "dev:disc-tree:depth-path" ~rid pp_string "Goal: MaxDepths " (pplist pp_int ";") (Array.to_list max_depths)];
[%spy "dev:disc-tree:list-size-path" ~rid pp_string "Goal: MaxListSize " pp_int max_list_length];
(* Format.(printf "Goal: MaxDepth is %a\n" (pp_print_list ~pp_sep:(fun fmt _ -> pp_print_string fmt " ") pp_print_int) (Discrimination_tree.max_depths args_idx |> Array.to_list)); *)
[%spy "dev:disc-tree:path" ~rid
Discrimination_tree.Path.pp path
Expand Down
4 changes: 2 additions & 2 deletions src/test_discrimination_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ let () =
(* Format.printf "%a\n" pp_path pathGoal; *)
let pathInsts = List.map (fun (x,y) -> x @ [mkPathEnd], y) pathInsts in
let add_to_trie t (key,value) =
index t (Path.of_list key) value in
index t (Path.of_list key) ~max_list_length:1000 value in
let trie = List.fold_left add_to_trie (empty_dt []) pathInsts in
let retrived = retrieve (fun x y -> y - x) pathGoal trie in
let retrived_nb = Elpi.Internal.Bl.length retrived in
Expand Down Expand Up @@ -80,7 +80,7 @@ let () =
let () =
let get_length dt path = DT.retrieve compare path !dt |> Elpi.Internal.Bl.length in
let remove dt e = dt := DT.remove (fun x -> x = e) !dt in
let index dt path v = dt := DT.index !dt path v in
let index dt path v = dt := DT.index !dt path ~max_list_length:1000 v in

let constA = mkConstant ~safe:false ~data:~-1 ~arity:~-0 in (* a *)
let p1 = [mkListHead; constA; mkListTailVariable; constA] in
Expand Down
23 changes: 23 additions & 0 deletions tests/sources/dt_bug272.elpi
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
:index(1 1)
pred map2 i:list A, i:list B, i:(A -> B -> C -> o), o:list C.
map2 [] [] _ [].
map2 [A|As] [B|Bs] F [C|Cs] :- F A B C, map2 As Bs F Cs.

pred any_list i:int, o:list int.
any_list 0 [] :- !.
any_list N [N|L] :- any_list {calc (N - 1)} L.

pred any_pred i:int, i:int, o:int.
any_pred A B R :- R is A + B.

pred test i:int, o:list int.
test N R :- any_list N L, map2 L L any_pred R.

pred loop i:int, i:int.
loop N N :- !.
loop N M :-
test N _,
N1 is N + 1,
loop N1 M.

main :- loop 1 100.
7 changes: 7 additions & 0 deletions tests/suite/correctness_HO.ml
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,13 @@ let () = declare "chr_with_hypotheses"
~expectation:Success
()

let () = declare "dt_bug_272"
~source_elpi:"dt_bug272.elpi"
~description:"dt list truncation heuristic"
~typecheck:true
~expectation:Success
()

let () = declare "bug-256"
~source_elpi:"bug-256.elpi"
~description:"move/unif"
Expand Down
Loading