Skip to content

Commit

Permalink
Merge pull request #274 from FissoreD/dt_272_rebased
Browse files Browse the repository at this point in the history
[dt] #272 rebased
  • Loading branch information
gares authored Oct 24, 2024
2 parents 5739f28 + 888ef24 commit 9419bc8
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 59 deletions.
49 changes: 24 additions & 25 deletions src/discrimination_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ let mkConstant ~safe ~data ~arity =
let mkPrimitive c = encode ~k:kPrimitive ~data:(CData.hash c lsl k_bits) ~arity:0

let mkVariable = encode ~k:kVariable ~data:0 ~arity:0
let mkLam = encode ~k:kOther ~data:0 ~arity:0
let mkAny = encode ~k:kOther ~data:0 ~arity:0

(* each argument starts with its mode *)
let mkInputMode = encode ~k:kOther ~data:1 ~arity:0
Expand All @@ -54,15 +54,17 @@ let mkListTailVariable = encode ~k:kOther ~data:3 ~arity:0
let mkListHead = encode ~k:kOther ~data:4 ~arity:0
let mkListEnd = encode ~k:kOther ~data:5 ~arity:0
let mkPathEnd = encode ~k:kOther ~data:6 ~arity:0
let mkListTailVariableUnif = encode ~k:kOther ~data:7 ~arity:0


let isVariable x = x == mkVariable
let isLam x = x == mkLam
let isAny x = x == mkAny
let isInput x = x == mkInputMode
let isOutput x = x == mkOutputMode
let isListHead x = x == mkListHead
let isListEnd x = x == mkListEnd
let isListTailVariable x = x == mkListTailVariable
let isListTailVariableUnif x = x == mkListTailVariableUnif
let isPathEnd x = x == mkPathEnd

type cell = int
Expand All @@ -82,26 +84,21 @@ let pp_cell fmt n =
else if isListHead n then "ListHead"
else if isListEnd n then "ListEnd"
else if isPathEnd n then "PathEnd"
else "Other")
else if isListTailVariableUnif n then "ListTailVariableUnif"
else if isAny n then "Other"
else failwith "Invalid path construct...")
else if k == kPrimitive then Format.fprintf fmt "Primitive"
else Format.fprintf fmt "%o" k

let show_cell n = Format.asprintf "%a" pp_cell n
module Path : sig
type t
val pp : Format.formatter -> t -> unit
val get : t -> int -> cell
type builder
val make : int -> cell -> builder
val emit : builder -> cell -> unit
val stop : builder -> t
val of_list : cell list -> t

end = struct

module Path = struct
type t = cell array [@@deriving show]
let get a i = a.(i)


type builder = { mutable pos : int; mutable path : cell array }
let get_builder_pos {pos} = pos
let make size e = { pos = 0; path = Array.make size e }
let rec emit p e =
let len = Array.length p.path in
Expand Down Expand Up @@ -187,11 +184,11 @@ module Trie = struct
let max = ref 0 in
let rec ins ~pos = let x = Path.get a pos in function
| Node ({ data } as t) when isPathEnd x -> max := pos; Node { t with data = v :: data }
| Node ({ other } as t) when isVariable x || isLam x ->
| Node ({ other } as t) when isVariable x || isAny x ->
let t' = match other with None -> empty | Some x -> x in
let t'' = ins ~pos:(pos+1) t' in
Node { t with other = Some t'' }
| Node ({ listTailVariable } as t) when isListTailVariable x ->
| Node ({ listTailVariable } as t) when isListTailVariable x || isListTailVariableUnif x ->
let t' = match listTailVariable with None -> empty | Some x -> x in
let t'' = ins ~pos:(pos+1) t' in
Node { t with listTailVariable = Some t'' }
Expand Down Expand Up @@ -229,7 +226,7 @@ end

let update_par_count n k =
if isListHead k then n + 1 else
if isListEnd k || isListTailVariable k then n - 1 else n
if isListEnd k || isListTailVariable k || isListTailVariableUnif k then n - 1 else n

let skip ~pos path (*hd tl*) : int =
let rec aux_list acc p =
Expand Down Expand Up @@ -285,17 +282,18 @@ let skip_listTailVariable ~pos path : int =
In the example it is no needed to index the goal path to depth 100, but rather considering
the maximal depth of the first argument, which 4 << 100
*)
type 'a t = {t: 'a Trie.t; max_size : int; max_depths : int array }
type 'a t = {t: 'a Trie.t; max_size : int; max_depths : int array; max_list_length: int }

let pp pp_a fmt { t } : unit = Trie.pp (fun fmt data -> pp_a fmt data) fmt t
let show pp_a { t } : string = Trie.show (fun fmt data -> pp_a fmt data) t

let index { t; max_size; max_depths } path data =
let index { t; max_size; max_depths; max_list_length = mll } ~max_list_length path data =
let t, m = Trie.add path data t in
{ t; max_size = max max_size m; max_depths }
{ t; max_size = max max_size m; max_depths; max_list_length = max max_list_length mll }

let max_path { max_size } = max_size
let max_depths { max_depths } = max_depths
let max_list_length { max_list_length } = max_list_length

(* the equivalent of skip, but on the index, thus the list of trees
that are rooted just after the term represented by the tree root
Expand Down Expand Up @@ -330,7 +328,7 @@ let skip_to_listEnd ~add_result mode (Trie.Node { other; map; listTailVariable }

let skip_to_listEnd mode t = call (skip_to_listEnd mode t)

let get_all_children v mode = isLam v || (isVariable v && isOutput mode)
let get_all_children v mode = isAny v || (isVariable v && isOutput mode)

let rec retrieve ~pos ~add_result mode path tree : unit =
let hd = Path.get path pos in
Expand All @@ -341,8 +339,8 @@ let rec retrieve ~pos ~add_result mode path tree : unit =
else if isInput hd || isOutput hd then
(* next argument, we update the mode *)
retrieve ~pos:(pos+1) ~add_result hd path tree
else if isListTailVariable hd then
let sub_tries = skip_to_listEnd mode tree in
else if isListTailVariable hd || isListTailVariableUnif hd then
let sub_tries = skip_to_listEnd (if isListTailVariableUnif hd then mkOutputMode else mode) tree in
List.iter (retrieve ~pos:(pos+1) ~add_result mode path) sub_tries
else begin
(* Here the constructor can be Constant, Primitive, Variable, Other, ListHead, ListEnd *)
Expand Down Expand Up @@ -389,7 +387,7 @@ and on_all_children ~pos ~add_result mode path map =

let empty_dt args_depth : 'a t =
let max_depths = Array.make (List.length args_depth) 0 in
{t = Trie.empty; max_depths; max_size = 0}
{t = Trie.empty; max_depths; max_size = 0; max_list_length=0}

let retrieve ~pos ~add_result path index =
let mode = Path.get path pos in
Expand All @@ -414,12 +412,13 @@ module Internal = struct
let data_of = data_of

let isVariable = isVariable
let isLam = isLam
let isAny = isAny
let isInput = isInput
let isOutput = isOutput
let isListHead = isListHead
let isListEnd = isListEnd
let isListTailVariable = isListTailVariable
let isListTailVariableUnif = isListTailVariableUnif
let isPathEnd = isPathEnd

end
75 changes: 58 additions & 17 deletions src/discrimination_tree.mli
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ module Path : sig
val pp : Format.formatter -> t -> unit
val get : t -> int -> cell
type builder
val get_builder_pos : builder -> int
val make : int -> cell -> builder
val emit : builder -> cell -> unit
val stop : builder -> t
Expand All @@ -12,42 +13,81 @@ end

val mkConstant : safe:bool -> data:int -> arity:int -> cell
val mkPrimitive : Elpi_util.Util.CData.t -> cell

(** This is for: unification variables, discard *)
val mkVariable : cell
val mkLam : cell

(** This is for:
- lambda-abstractions: DT does not perform HO unification, nor βη-redex unif
- too big terms: if the path of the goal is bigger then the max path in the
rules, each term is replaced with this constructor. Note that
we do not use mkVariable, since in input mode a variable
cannot be unified with non-flex terms. In DT, mkAny is
unified with anything
*)
val mkAny : cell
val mkInputMode : cell
val mkOutputMode : cell

(** This is for the last term of opened lists.
If the list ends is [1,2|X] == Cons (CData'1',Cons(CData'2',Arg (_, _)))
The corresponding path will be: ListHead, Primitive, Primitive,
ListTailVariable
ListTailVariable is considered as a variable, and if it appears in a goal in
input position, it cannot be unified with non-flex terms
*)
val mkListTailVariable : cell

(** This is used for capped lists.
If the length of the maximal list in the rules of a predicate is N, then any
(sub-)list in the goal longer then N encodes the first N elements in the path
and the last are replaced with ListTailVariableUnif.
The main difference with ListTailVariable is that DT will unify this symbol to
both flex and non-flex terms in the path of rules
*)
val mkListTailVariableUnif : cell
val mkListHead : cell
val mkListEnd : cell

(** This is padding used to fill the array in paths and indicate the retrieve
function to stop making unification.
Note that the array for the path has a length which is doubled each time the
terms in it are larger then the forseen length. Each time the array is
doubled, the new cells are filled with PathEnd.
*)
val mkPathEnd : cell

type 'a t

(** [index dt path value ~time] returns a new discrimination tree starting from [dt]
where [value] is added wrt its [path]. [~time] is used as a priority
marker between two rules.
A rule with a given [~time] has higher priority on other rules with lower [~time]
(** [index dt ~max_list_length path value] returns a new discrimination tree
starting from [dt] where [value] is added wrt its [path].
@note: in the elpi runtime, there are no two rule having the same [~time]
[max_list_length] is provided and used to update (if needed) the length of
the longest list in the received path.
*)
val index : 'a t -> Path.t -> 'a -> 'a t
val index : 'a t -> max_list_length:int -> Path.t -> 'a -> 'a t

val max_path : 'a t -> int
val max_list_length : 'a t -> int
val max_depths : 'a t -> int array

val empty_dt : 'b list -> 'a t

(** [retrive path dt] Retrives all values in a discrimination tree [dt] from a given path [p].
(** [retrive cmp path dt] Retrives all values in a discrimination tree [dt] from
a given path [p].
The retrival algorithm performs a light unification between [p] and the
nodes in the discrimination tree. This light unification takes care of
unification variables that can be either in the path or in the nodes of [dt]
The retrival algorithm performs a light unification between [p] and the nodes
in the discrimination tree. This light unification takes care of unification
variables that can be either in the path or in the nodes of [dt]
The returned list of values are sorted wrt to the order in which values are
added in the tree: given two rules r_1 and r_2 with same path, if r_1
has been added at time [t] and r_2 been added at time [t+1] then
r_2 will appear before r_1 in the final result
The returned list is sorted wrt [cmp] so that rules are given in the expected
order
*)
val retrieve : ('a -> 'a -> int) -> Path.t -> 'a t -> 'a Bl.scan

Expand Down Expand Up @@ -79,11 +119,12 @@ module Internal: sig
val data_of : cell -> int

val isVariable : cell -> bool
val isLam : cell -> bool
val isAny : cell -> bool
val isInput : cell -> bool
val isOutput : cell -> bool
val isListHead : cell -> bool
val isListEnd : cell -> bool
val isListTailVariable : cell -> bool
val isListTailVariableUnif : cell -> bool
val isPathEnd : cell -> bool
end
58 changes: 46 additions & 12 deletions src/runtime.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2507,8 +2507,32 @@ let hash_arg_list is_goal hd ~depth args mode spec =
let hash_clause_arg_list = hash_arg_list false
let hash_goal_arg_list = hash_arg_list true

(* returns the maximal length of any sub_list
this operation is done at compile time per each clause being index
for example the term (app['a','b',app['c','d','e'], 'f', app['a','b','c','d','e','f']])
has three lists L1 = ['a','b',app['c','d','e'], 'f', app['a','b','c','d','e','f']
L2 = ['c','d','e']
L3 = app['a','b','c','d','e','f']
and the longest one has length 6
*)
let rec max_list_length acc = function
| Nil -> acc
| Cons (a, (UVar (_, _, _) | AppUVar (_, _, _) | Arg (_, _) | AppArg (_, _) | Discard)) ->
let alen = max_list_length 0 a in
max (acc+2) alen
| Cons (a, b)->
let alen = max_list_length 0 a in
let blen = max_list_length (acc+1) b in
max blen alen
| App (_,x,xs) -> List.fold_left (fun acc x -> max acc (max_list_length 0 x)) acc (x::xs)
| Builtin (_, xs) -> List.fold_left (fun acc x -> max acc (max_list_length 0 x)) acc xs
| Lam t -> max_list_length acc t
| Discard | Const _ |CData _ | UVar (_, _, _) | AppUVar (_, _, _) | Arg (_, _) | AppArg (_, _) -> acc

let max_lists_length = List.fold_left (fun acc e -> max (max_list_length 0 e) acc) 0

(**
[arg_to_trie_path ~safe ~depth ~is_goal args arg_depths arg_modes]
[arg_to_trie_path ~safe ~depth ~is_goal args arg_depths arg_modes mp max_list_length]
returns the path represetation of a term to be used in indexing with trie.
args, args_depths and arg_modes are the lists of respectively the arguments, the
depths and the modes of the current term to be indexed.
Expand All @@ -2517,10 +2541,14 @@ let hash_goal_arg_list = hash_arg_list true
In the former case, each argument we add a special mkInputMode/mkOutputMode
node before each argument to be indexed. This special node is used during
instance retrival to know the mode of the current argument
- mp is the max_path size of any path in the index used to truncate the goal
- max_list_length is the length of the longest sublist in each term of args
*)
let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode mp : Discrimination_tree.Path.t =
let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode mp (max_list_length':int) : Discrimination_tree.Path.t =
let open Discrimination_tree in
let path = Path.make (max mp 8) mkPathEnd in

let path_length = mp + Array.length args_depths_ar + 1 in
let path = Path.make path_length mkPathEnd in

let current_ar_pos = ref 0 in
let current_user_depth = ref 0 in
Expand All @@ -2537,7 +2565,7 @@ let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode m
end
in

let rec list_to_trie_path ~safe ~depth ?(h=0) path_depth (len: int) (t: term) : unit =
let rec list_to_trie_path ~safe ~depth ~h path_depth (len: int) (t: term) : unit =
match deref_head ~depth t with
| Nil -> Path.emit path mkListEnd; update_current_min_depth path_depth
| Cons (a, b) ->
Expand All @@ -2549,10 +2577,10 @@ let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode m
(* has the node `app` with arity `1` as first*)
(* cell, then come the elment of the list *)
(* up to the 30^th elemebt *)
if h > 30 then (Path.emit path mkListEnd; update_current_min_depth path_depth)
if is_goal && h >= max_list_length' then (Path.emit path mkListTailVariableUnif; update_current_min_depth path_depth)
else
main ~safe ~depth a path_depth;
list_to_trie_path ~depth ~safe ~h:(h+1) path_depth (len+1) b
(main ~safe ~depth a path_depth;
list_to_trie_path ~depth ~safe ~h:(h+1) path_depth (len+1) b)

(* These cases can come from terms like `[_ | _]`, `[_ | A]` ... *)
| UVar _ | AppUVar _ | Arg _ | AppArg _ | Discard -> Path.emit path mkListTailVariable; update_current_min_depth path_depth
Expand All @@ -2578,6 +2606,8 @@ let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode m
(** gives the path representation of a term *)
and main ~safe ~depth t path_depth : unit =
if path_depth = 0 then update_current_min_depth path_depth
else if is_goal && Path.get_builder_pos path >= path_length then
(Path.emit path mkAny; update_current_min_depth path_depth)
else
let path_depth = path_depth - 1 in
match deref_head ~depth t with
Expand All @@ -2587,7 +2617,7 @@ let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode m
| CData d -> Path.emit path @@ mkPrimitive d; update_current_min_depth path_depth
| App (k,_,_) when k == Global_symbols.uvarc -> Path.emit path @@ mkVariable; update_current_min_depth path_depth
| App (k,a,_) when k == Global_symbols.asc -> main ~safe ~depth a (path_depth+1)
| Lam _ -> Path.emit path @@ mkLam; update_current_min_depth path_depth (* loose indexing to enable eta *)
| Lam _ -> Path.emit path mkAny; update_current_min_depth path_depth (* loose indexing to enable eta *)
| Arg _ | UVar _ | AppArg _ | AppUVar _ | Discard -> Path.emit path @@ mkVariable; update_current_min_depth path_depth
| Builtin (k,tl) ->
Path.emit path @@ mkConstant ~safe ~data:k ~arity:(if path_depth = 0 then 0 else List.length tl);
Expand All @@ -2601,7 +2631,7 @@ let arg_to_trie_path ~safe ~depth ~is_goal args arg_depths args_depths_ar mode m
| Cons (x,xs) ->
Path.emit path mkListHead;
main ~safe ~depth x (path_depth + 1);
list_to_trie_path ~safe ~depth (path_depth + 1) 0 xs
list_to_trie_path ~safe ~depth ~h:1 (path_depth + 1) 0 xs

(** builds the sub-path of a sublist of arguments of the current clause *)
and make_sub_path arg_hd arg_tl arg_depth_hd arg_depth_tl mode_hd mode_tl =
Expand Down Expand Up @@ -2689,9 +2719,11 @@ let add_clause_to_snd_lvl_idx ~depth ~insert predicate clause = function
| IndexWithDiscriminationTree {mode; arg_depths; args_idx; } ->
let max_depths = Discrimination_tree.max_depths args_idx in
let max_path = Discrimination_tree.max_path args_idx in
let path = arg_to_trie_path ~depth ~safe:true ~is_goal:false clause.args arg_depths max_depths mode max_path in
let max_list_length = max_lists_length clause.args in
(* Format.printf "[%d] Going to index [%a]\n%!" max_list_length (pplist pp_term ",") clause.args; *)
let path = arg_to_trie_path ~depth ~safe:true ~is_goal:false clause.args arg_depths max_depths mode max_path max_list_length in
[%spy "dev:disc-tree:depth-path" ~rid pp_string "Inst: MaxDepths " (pplist pp_int "") (Array.to_list max_depths)];
let args_idx = Discrimination_tree.index args_idx path clause in
let args_idx = Discrimination_tree.index args_idx path clause ~max_list_length in
IndexWithDiscriminationTree {
mode; arg_depths;
args_idx = args_idx
Expand Down Expand Up @@ -2899,8 +2931,10 @@ let get_clauses ~depth predicate goal { index = { idx = m } } =
| IndexWithDiscriminationTree {arg_depths; mode; args_idx} ->
let max_depths = Discrimination_tree.max_depths args_idx in
let max_path = Discrimination_tree.max_path args_idx in
let (path: Discrimination_tree.Path.t) = arg_to_trie_path ~safe:false ~depth ~is_goal:true (trie_goal_args goal) arg_depths max_depths mode max_path in
let max_list_length = Discrimination_tree.max_list_length args_idx in
let path = arg_to_trie_path ~safe:false ~depth ~is_goal:true (trie_goal_args goal) arg_depths max_depths mode max_path max_list_length in
[%spy "dev:disc-tree:depth-path" ~rid pp_string "Goal: MaxDepths " (pplist pp_int ";") (Array.to_list max_depths)];
[%spy "dev:disc-tree:list-size-path" ~rid pp_string "Goal: MaxListSize " pp_int max_list_length];
(* Format.(printf "Goal: MaxDepth is %a\n" (pp_print_list ~pp_sep:(fun fmt _ -> pp_print_string fmt " ") pp_print_int) (Discrimination_tree.max_depths args_idx |> Array.to_list)); *)
[%spy "dev:disc-tree:path" ~rid
Discrimination_tree.Path.pp path
Expand Down
Loading

0 comments on commit 9419bc8

Please sign in to comment.