Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract function and node specific logic from TD3 #596

Merged
merged 19 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gobview
37 changes: 19 additions & 18 deletions src/framework/analyses.ml
Original file line number Diff line number Diff line change
Expand Up @@ -481,26 +481,24 @@ sig
val access: (D.t, G.t, C.t, V.t) ctx -> Queries.access -> A.t
end

type analyzed_data = {
solver_data: Obj.t;
}

type increment_data = {
server: bool;

old_data: analyzed_data option;
solver_data: Obj.t;
changes: CompareCIL.change_info;

(* Globals for which the constraint
system unknowns should be restarted *)
restarting: VarQuery.t list;
}

let empty_increment_data ?(server=false) () = {
server;
old_data = None;
changes = CompareCIL.empty_change_info ();
restarting = []
(** Abstract incremental change to constraint system.
@param 'v constrain system variable type *)
type 'v sys_change_info = {
obsolete: 'v list; (** Variables to destabilize. *)
delete: 'v list; (** Variables to delete. *)
reluctant: 'v list; (** Variables to solve reluctantly. *)
restart: 'v list; (** Variables to restart. *)
}

(** A side-effecting system. *)
Expand All @@ -522,10 +520,8 @@ sig
(** The system in functional form. *)
val system : v -> ((v -> d) -> (v -> d -> unit) -> d) m

(** Data used for incremental analysis *)
val increment : increment_data

val iter_vars: (v -> d) -> VarQuery.t -> v VarQuery.f -> unit
val sys_change: (v -> d) -> v sys_change_info
(** Compute incremental constraint system change from old solution. *)
end

(** Any system of side-effecting equations over lattices. *)
Expand All @@ -539,9 +535,8 @@ sig

module D : Lattice.S
module G : Lattice.S
val increment : increment_data
val system : LVar.t -> ((LVar.t -> D.t) -> (LVar.t -> D.t -> unit) -> (GVar.t -> G.t) -> (GVar.t -> G.t -> unit) -> D.t) option
val iter_vars: (LVar.t -> D.t) -> (GVar.t -> G.t) -> VarQuery.t -> LVar.t VarQuery.f -> GVar.t VarQuery.f -> unit
val sys_change: (LVar.t -> D.t) -> (GVar.t -> G.t) -> [`L of LVar.t | `G of GVar.t] sys_change_info
end

(** A solver is something that can translate a system into a solution (hash-table).
Expand All @@ -552,10 +547,13 @@ module type GenericEqBoxIncrSolverBase =
sig
type marshal

val copy_marshal: marshal -> marshal
val relift_marshal: marshal -> marshal

(** The hash-map that is the first component of [solve box xs vs] is a local solution for interesting variables [vs],
reached from starting values [xs].
As a second component the solver returns data structures for incremental serialization. *)
val solve : (S.v -> S.d -> S.d -> S.d) -> (S.v*S.d) list -> S.v list -> S.d H.t * marshal
val solve : (S.v -> S.d -> S.d -> S.d) -> (S.v*S.d) list -> S.v list -> marshal option -> S.d H.t * marshal
end

(** (Incremental) solver argument, indicating which postsolving should be performed by the solver. *)
Expand Down Expand Up @@ -590,10 +588,13 @@ module type GenericGlobSolver =
sig
type marshal

val copy_marshal: marshal -> marshal
val relift_marshal: marshal -> marshal

(** The hash-map that is the first component of [solve box xs vs] is a local solution for interesting variables [vs],
reached from starting values [xs].
As a second component the solver returns data structures for incremental serialization. *)
val solve : (S.LVar.t*S.D.t) list -> (S.GVar.t*S.G.t) list -> S.LVar.t list -> (S.D.t LH.t * S.G.t GH.t) * marshal
val solve : (S.LVar.t*S.D.t) list -> (S.GVar.t*S.G.t) list -> S.LVar.t list -> marshal option -> (S.D.t LH.t * S.G.t GH.t) * marshal
end

module ResultType2 (S:Spec) =
Expand Down
215 changes: 172 additions & 43 deletions src/framework/constraints.ml
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,40 @@ end

module type Increment =
sig
val increment: increment_data
val increment: increment_data option
end

(** Combined variables so that we can also use the more common [EqConstrSys]
that uses only one kind of a variable. *)
module Var2 (LV:VarType) (GV:VarType)
: VarType
with type t = [ `L of LV.t | `G of GV.t ]
=
struct
type t = [ `L of LV.t | `G of GV.t ] [@@deriving eq, ord, hash]
let relift = function
| `L x -> `L (LV.relift x)
| `G x -> `G (GV.relift x)

let pretty_trace () = function
| `L a -> LV.pretty_trace () a
| `G a -> GV.pretty_trace () a

let printXml f = function
| `L a -> LV.printXml f a
| `G a -> GV.printXml f a

let var_id = function
| `L a -> LV.var_id a
| `G a -> GV.var_id a

let node = function
| `L a -> LV.node a
| `G a -> GV.node a

let is_write_only = function
| `L a -> LV.is_write_only a
| `G a -> GV.is_write_only a
end

(** The main point of this file---generating a [GlobConstrSys] from a [Spec]. *)
Expand All @@ -460,9 +493,6 @@ struct
1. S.V -> S.G -- used for Spec
2. fundec -> set of S.C -- used for IterSysVars Node *)

(* Dummy module. No incremental analysis supported here*)
let increment = I.increment

let sync ctx =
match Cfg.prev ctx.prev_node with
| _ :: _ :: _ -> S.sync ctx `Join
Expand Down Expand Up @@ -760,6 +790,133 @@ struct
) cs
| _ ->
()

let sys_change getl getg =
let open CompareCIL in

let c = match I.increment with
| Some {changes; _} -> changes
| None -> empty_change_info ()
in
List.(Printf.printf "change_info = { unchanged = %d; changed = %d; added = %d; removed = %d }\n" (length c.unchanged) (length c.changed) (length c.added) (length c.removed));

let changed_funs = List.filter_map (function
| {old = GFun (f, _); diff = None; _} ->
print_endline ("Completely changed function: " ^ f.svar.vname);
Some f
| _ -> None
) c.changed
in
let part_changed_funs = List.filter_map (function
| {old = GFun (f, _); diff = Some nd; _} ->
print_endline ("Partially changed function: " ^ f.svar.vname);
Some (f, nd.primObsoleteNodes, nd.unchangedNodes)
| _ -> None
) c.changed
in
let removed_funs = List.filter_map (function
| GFun (f, _) ->
print_endline ("Removed function: " ^ f.svar.vname);
Some f
| _ -> None
) c.removed
in

let module HM = Hashtbl.Make (Var2 (LVar) (GVar)) in

let mark_node hm f node =
iter_vars getl getg (Node {node; fundec = Some f}) (fun v ->
HM.replace hm (`L v) ()
) (fun v ->
HM.replace hm (`G v) ()
)
in

let reluctant = GobConfig.get_bool "incremental.reluctant.enabled" in
let reanalyze_entry f =
(* destabilize the entry points of a changed function when reluctant is off,
or the function is to be force-reanalyzed *)
(not reluctant) || CompareCIL.VarinfoSet.mem f.svar c.exclude_from_rel_destab
in
let obsolete_ret = HM.create 103 in
let obsolete_entry = HM.create 103 in
let obsolete_prim = HM.create 103 in

(* When reluctant is on:
Only add function entry nodes to obsolete_entry if they are in force-reanalyze *)
List.iter (fun f ->
if reanalyze_entry f then
(* collect function entry for eager destabilization *)
mark_node obsolete_entry f (FunctionEntry f)
else
(* collect function return for reluctant analysis *)
mark_node obsolete_ret f (Function f)
) changed_funs;
(* Unknowns from partially changed functions need only to be collected for eager destabilization when reluctant is off *)
(* We utilize that force-reanalyzed functions are always considered as completely changed (and not partially changed) *)
if not reluctant then (
List.iter (fun (f, pn, _) ->
List.iter (fun n ->
mark_node obsolete_prim f n
) pn;
mark_node obsolete_ret f (Function f);
) part_changed_funs;
);
sim642 marked this conversation as resolved.
Show resolved Hide resolved

let obsolete = Enum.append (HM.keys obsolete_entry) (HM.keys obsolete_prim) |> List.of_enum in
let reluctant = HM.keys obsolete_ret |> List.of_enum in

let marked_for_deletion = HM.create 103 in

let dummy_pseudo_return_node f =
(* not the same as in CFG, but compares equal because of sid *)
Node.Statement ({Cil.dummyStmt with sid = CfgTools.get_pseudo_return_id f})
in
let add_nodes_of_fun (functions: fundec list) (withEntry: fundec -> bool) =
let add_stmts (f: fundec) =
List.iter (fun s ->
mark_node marked_for_deletion f (Statement s)
) f.sallstmts
in
List.iter (fun f ->
if withEntry f then
mark_node marked_for_deletion f (FunctionEntry f);
mark_node marked_for_deletion f (Function f);
add_stmts f;
mark_node marked_for_deletion f (dummy_pseudo_return_node f)
) functions;
in

add_nodes_of_fun changed_funs reanalyze_entry;
add_nodes_of_fun removed_funs (fun _ -> true);
(* it is necessary to remove all unknowns for changed pseudo-returns because they have static ids *)
let add_pseudo_return f un =
let pseudo = dummy_pseudo_return_node f in
if not (List.exists (Node.equal pseudo % fst) un) then
mark_node marked_for_deletion f (dummy_pseudo_return_node f)
in
List.iter (fun (f,_,un) ->
mark_node marked_for_deletion f (Function f);
add_pseudo_return f un
) part_changed_funs;

let delete = HM.keys marked_for_deletion |> List.of_enum in

let restart = match I.increment with
| Some data ->
let restart = ref [] in
List.iter (fun g ->
iter_vars getl getg g (fun v ->
restart := `L v :: !restart
) (fun v ->
restart := `G v :: !restart
)
) data.restarting;
!restart
| None -> []
in

{obsolete; delete; reluctant; restart}
end

(** Convert a non-incremental solver into an "incremental" solver.
Expand All @@ -771,45 +928,15 @@ module EqIncrSolverFromEqSolver (Sol: GenericEqBoxSolver): GenericEqBoxIncrSolve
module Post = PostSolver.MakeList (PostSolver.ListArgFromStdArg (S) (VH) (Arg))

type marshal = unit
let copy_marshal () = ()
let relift_marshal () = ()

let solve box xs vs =
let solve box xs vs _ =
let vh = Sol.solve box xs vs in
Post.post xs vs vh;
(vh, ())
end

(** Combined variables so that we can also use the more common [EqConstrSys]
that uses only one kind of a variable. *)
module Var2 (LV:VarType) (GV:VarType)
: VarType
with type t = [ `L of LV.t | `G of GV.t ]
=
struct
type t = [ `L of LV.t | `G of GV.t ] [@@deriving eq, ord, hash]
let relift = function
| `L x -> `L (LV.relift x)
| `G x -> `G (GV.relift x)

let pretty_trace () = function
| `L a -> LV.pretty_trace () a
| `G a -> GV.pretty_trace () a

let printXml f = function
| `L a -> LV.printXml f a
| `G a -> GV.printXml f a

let var_id = function
| `L a -> LV.var_id a
| `G a -> GV.var_id a

let node = function
| `L a -> LV.node a
| `G a -> GV.node a

let is_write_only = function
| `L a -> LV.is_write_only a
| `G a -> GV.is_write_only a
end

(** Translate a [GlobConstrSys] into a [EqConstrSys] *)
module EqConstrSysFromGlobConstrSys (S:GlobConstrSys)
Expand All @@ -828,7 +955,6 @@ struct
| `Lifted2 a -> S.D.printXml f a
| (`Bot | `Top) as x -> printXml f x
end
let increment = S.increment
type v = Var.t
type d = Dom.t

Expand Down Expand Up @@ -858,8 +984,8 @@ struct
| `G _ -> None
| `L x -> Option.map conv (S.system x)

let iter_vars get vq f =
S.iter_vars (getL % get % l) (getG % get % g) vq (f % l) (f % g)
let sys_change get =
S.sys_change (getL % get % l) (getG % get % g)
end

(** Splits a [EqConstrSys] solution into a [GlobConstrSys] solution with given [Hashtbl.S] for the [EqConstrSys]. *)
Expand Down Expand Up @@ -902,7 +1028,7 @@ end

(** Transforms a [GenericEqBoxIncrSolver] into a [GenericGlobSolver]. *)
module GlobSolverFromEqSolver (Sol:GenericEqBoxIncrSolverBase)
: GenericGlobSolver
(* : GenericGlobSolver *)
sim642 marked this conversation as resolved.
Show resolved Hide resolved
= functor (S:GlobConstrSys) ->
functor (LH:Hashtbl.S with type key=S.LVar.t) ->
functor (GH:Hashtbl.S with type key=S.GVar.t) ->
Expand All @@ -916,11 +1042,14 @@ module GlobSolverFromEqSolver (Sol:GenericEqBoxIncrSolverBase)

type marshal = Sol'.marshal

let solve ls gs l =
let copy_marshal = Sol'.copy_marshal
let relift_marshal = Sol'.relift_marshal

let solve ls gs l old_data =
let vs = List.map (fun (x,v) -> `L x, `Lifted2 v) ls
@ List.map (fun (x,v) -> `G x, `Lifted1 v) gs in
let sv = List.map (fun x -> `L x) l in
let hm, solver_data = Sol'.solve EqSys.box vs sv in
let hm, solver_data = Sol'.solve EqSys.box vs sv old_data in
Splitter.split_solution hm, solver_data
end

Expand Down
Loading