diff --git a/ml-proto/host/parser.mly b/ml-proto/host/parser.mly index 38b48ba0b3..2c4e15126f 100644 --- a/ml-proto/host/parser.mly +++ b/ml-proto/host/parser.mly @@ -67,7 +67,7 @@ let enter_func c = assert (VarMap.is_empty c.labels); {c with labels = VarMap.add "return" 0 c.labels; locals = empty ()} -let lookup_type c x = +let type_ c x = try VarMap.find x.it c.types.tmap with Not_found -> Error.error x.at ("unknown type " ^ x.it) @@ -78,7 +78,6 @@ let lookup category space x = let func c x = lookup "function" c.funcs x let import c x = lookup "import" c.imports x let local c x = lookup "local" c.locals x -let table c x = lookup "table" (empty ()) x let label c x = try VarMap.find x.it c.labels with Not_found -> Error.error x.at ("unknown label " ^ x.it) @@ -114,7 +113,7 @@ let anon_label c = {c with labels = VarMap.map ((+) 1) c.labels} let empty_type = {ins = []; out = None} let explicit_decl c name t at = - let x = name c lookup_type in + let x = name c type_ in if x.it < List.length c.types.tlist && t <> empty_type && t <> List.nth c.types.tlist x.it then @@ -126,7 +125,6 @@ let implicit_decl c t at = | None -> let i = List.length c.types.tlist in anon_type c t; i @@ at | Some i -> i @@ at - %} %token INT FLOAT TEXT VAR VALUE_TYPE LPAR RPAR @@ -231,7 +229,7 @@ expr1 : | CALL var expr_list { fun c -> call ($2 c func, $3 c) } | CALL_IMPORT var expr_list { fun c -> call_import ($2 c import, $3 c) } | CALL_INDIRECT var expr expr_list - { fun c -> call_indirect ($2 c table, $3 c, $4 c) } + { fun c -> call_indirect ($2 c type_, $3 c, $4 c) } | GET_LOCAL var { fun c -> get_local ($2 c local) } | SET_LOCAL var expr { fun c -> set_local ($2 c local, $3 c) } | LOAD expr { fun c -> load ($1, $2 c) } @@ -354,6 +352,11 @@ type_def : { fun c -> bind_type c $3 $6 } ; +table : + | LPAR TABLE var_list RPAR + { fun c -> $3 c func } +; + import : | LPAR IMPORT TEXT TEXT type_use RPAR { let at = at () in @@ -382,7 +385,7 @@ module_fields : | /* empty */ { fun c -> {memory = None; types = c.types.tlist; funcs = []; imports = []; - exports = []; tables = []} } + exports = []; table = []} } | func module_fields { fun c -> let f = $1 c in let m = $2 c in {m with funcs = f () :: m.funcs} } @@ -392,9 +395,9 @@ module_fields : | export module_fields { fun c -> let m = $2 c in {m with exports = $1 c :: m.exports} } - | LPAR TABLE var_list RPAR module_fields - { fun c -> let m = $5 c in - {m with tables = ($3 c func @@ ati 3) :: m.tables} } + | table module_fields + { fun c -> let m = $2 c in + {m with table = ($1 c) @ m.table} } | type_def module_fields { fun c -> $1 c; $2 c } | memory module_fields diff --git a/ml-proto/host/print.ml b/ml-proto/host/print.ml index e718723074..7ba63b784e 100644 --- a/ml-proto/host/print.ml +++ b/ml-proto/host/print.ml @@ -28,8 +28,8 @@ let print_func_sig m prefix i f = let print_export_sig m prefix n f = printf "%s \"%s\" : %s\n" prefix n (string_of_func_type (func_type m f)) -let print_table_sig prefix i t_opt = - printf "%s %d : %s\n" prefix i (string_of_table_type t_opt) +let print_table_elem i x = + printf "table [%d] = func %d\n" i x.it (* Ast *) @@ -40,19 +40,11 @@ let print_func m i f = let print_export m i ex = print_export_sig m "export" ex.it.name (List.nth m.it.funcs ex.it.func.it) -let print_table m i tab = - let t_opt = - match tab.it with - | [] -> None - | x::_ -> Some (func_type m (List.nth m.it.funcs x.it)) - in print_table_sig "table" i t_opt - - let print_module m = - let {funcs; exports; tables} = m.it in + let {funcs; exports; table} = m.it in List.iteri (print_func m) funcs; List.iteri (print_export m) exports; - List.iteri (print_table m) tables; + List.iteri print_table_elem table; flush_all () let print_module_sig m = diff --git a/ml-proto/spec/ast.ml b/ml-proto/spec/ast.ml index a52fd5056f..7e0cf8e90e 100644 --- a/ml-proto/spec/ast.ml +++ b/ml-proto/spec/ast.ml @@ -140,8 +140,6 @@ and import' = func_name : string; } -type table = var list Source.phrase - type module_ = module_' Source.phrase and module_' = { @@ -150,5 +148,5 @@ and module_' = funcs : func list; imports : import list; exports : export list; - tables : table list; + table : var list; } diff --git a/ml-proto/spec/check.ml b/ml-proto/spec/check.ml index 0ce2e17fd0..6b7ac6758f 100644 --- a/ml-proto/spec/check.ml +++ b/ml-proto/spec/check.ml @@ -20,7 +20,6 @@ type context = types : func_type list; funcs : func_type list; imports : func_type list; - tables : func_type list; locals : value_type list; return : expr_type; labels : expr_type list; @@ -31,10 +30,10 @@ let lookup category list x = try List.nth list x.it with Failure _ -> error x.at ("unknown " ^ category ^ " " ^ string_of_int x.it) +let type_ types x = lookup "function type" types x let func c x = lookup "function" c.funcs x let import c x = lookup "import" c.imports x let local c x = lookup "local" c.locals x -let table c x = lookup "table" c.tables x let label c x = lookup "label" c.labels x @@ -45,9 +44,6 @@ let check_type actual expected at = ("type mismatch: expression has type " ^ string_of_expr_type actual ^ " but the context requires " ^ string_of_expr_type expected) -let check_func_type actual expected at = - require (actual = expected) at "inconsistent function type in table" - (* Type Synthesis *) @@ -154,7 +150,7 @@ let rec check_expr c et e = check_type out et e.at | CallIndirect (x, e1, es) -> - let {ins; out} = table c x in + let {ins; out} = type_ c.types x in check_expr c (Some Int32Type) e1; check_exprs c ins es; check_type out et e.at @@ -268,25 +264,14 @@ and check_mem_type ty sz at = * s : func_type *) -let get_type types t = - require (t.it < List.length types) t.at "type index out of bounds"; - List.nth types t.it - let check_func c f = let {ftype; locals; body} = f.it in - let s = get_type c.types ftype in + let s = type_ c.types ftype in let c' = {c with locals = s.ins @ locals; return = s.out} in check_expr c' s.out body -let check_table funcs tables tab = - match tab.it with - | [] -> - error tab.at "empty table" - | x::xs -> - let func x = lookup "function" funcs x in - let s = func x in - List.iter (fun xI -> check_func_type (func xI) s xI.at) xs; - tables @ [s] +let check_elem c x = + ignore (func c x) module NameSet = Set.Make(String) @@ -315,16 +300,15 @@ let check_memory memory = ignore (List.fold_left (check_segment mem.initial) Int64.zero mem.segments) let check_module m = - let {memory; types; funcs; imports; exports; tables} = m.it in + let {memory; types; funcs; imports; exports; table} = m.it in Lib.Option.app check_memory memory; - let func_types = List.map (fun f -> get_type types f.it.ftype) funcs in let c = {types; - funcs = func_types; - imports = List.map (fun i -> get_type types i.it.itype) imports; - tables = List.fold_left (check_table func_types) [] tables; + funcs = List.map (fun f -> type_ types f.it.ftype) funcs; + imports = List.map (fun i -> type_ types i.it.itype) imports; locals = []; return = None; labels = []; has_memory = memory <> None} in List.iter (check_func c) funcs; + List.iter (check_elem c) table; ignore (List.fold_left (check_export c) NameSet.empty exports) diff --git a/ml-proto/spec/eval.ml b/ml-proto/spec/eval.ml index 32d16ba0a9..c8fdec62ac 100644 --- a/ml-proto/spec/eval.ml +++ b/ml-proto/spec/eval.ml @@ -24,7 +24,6 @@ type instance = module_ : module_; imports : import list; exports : export_map; - tables : func list list; memory : Memory.t option; host : host_params } @@ -45,12 +44,18 @@ let lookup category list x = try List.nth list x.it with Failure _ -> error x.at ("runtime: undefined " ^ category ^ " " ^ string_of_int x.it) +let type_ c x = lookup "type" c.instance.module_.it.types x let func c x = lookup "function" c.instance.module_.it.funcs x let import c x = lookup "import" c.instance.imports x -let table c x y = lookup "entry" (lookup "table" c.instance.tables x) y let local c x = lookup "local" c.locals x let label c x = lookup "label" c.labels x +let table_elem c i at = + if i < 0l || i <> Int32.of_int (Int32.to_int i) then + error at ("runtime: undefined table element " ^ Int32.to_string i); + let x = (Int32.to_int i) @@ at in + lookup "table element" c.instance.module_.it.table x + let export m x = try ExportMap.find x.it m.exports with Not_found -> @@ -114,10 +119,6 @@ let mem_overflow x = let callstack_exhaustion at = error at ("runtime: callstack exhausted") -let func_type instance f = - assert (f.it.ftype.it < List.length instance.module_.it.types); - List.nth instance.module_.it.types f.it.ftype.it - (* Evaluation *) @@ -171,11 +172,13 @@ let rec eval_expr (c : config) (e : expr) = let vs = List.map (fun ev -> some (eval_expr c ev) ev.at) es in (import c x) vs - | CallIndirect (x, e1, es) -> + | CallIndirect (ftype, e1, es) -> let i = int32 (eval_expr c e1) e1.at in let vs = List.map (fun vo -> some (eval_expr c vo) vo.at) es in - (* TODO: The conversion to int could overflow. *) - eval_func c.instance (table c x (Int32.to_int i @@ e1.at)) vs + let f = func c (table_elem c i e1.at) in + if ftype.it <> f.it.ftype.it then + error e1.at "runtime: indirect call signature mismatch"; + eval_func c.instance f vs | GetLocal x -> Some !(local c x) @@ -269,7 +272,7 @@ and eval_func instance f vs = let vars = List.map (fun t -> ref (default_value t)) f.it.locals in let locals = args @ vars in let c = {instance; locals; labels = []} in - coerce (func_type instance f).out (eval_expr c f.it.body) + coerce (type_ c f.it.ftype).out (eval_expr c f.it.body) and coerce et vo = if et = None then None else vo @@ -309,22 +312,22 @@ let init_memory {it = {initial; segments; _}} = Memory.init mem (List.map it segments); mem +let add_export funcs ex = + ExportMap.add ex.it.name (List.nth funcs ex.it.func.it) + let init m imports host = assert (List.length imports = List.length m.it.Ast.imports); assert (host.page_size > 0L); assert (Lib.Int64.is_power_of_two host.page_size); - let {memory; funcs; exports; tables; _} = m.it in - let memory' = Lib.Option.map init_memory memory in - let func x = List.nth funcs x.it in - let export ex = ExportMap.add ex.it.name (func ex.it.func) in - let exports = List.fold_right export exports ExportMap.empty in - let tables = List.map (fun tab -> List.map func tab.it) tables in - {module_ = m; imports; exports; tables; memory = memory'; host} + let {memory; funcs; exports; _} = m.it in + {module_ = m; + imports; + exports = List.fold_right (add_export funcs) exports ExportMap.empty; + memory = Lib.Option.map init_memory memory; + host} let invoke instance name vs = try - let f = export instance (name @@ no_region) in - assert (List.length vs = List.length (func_type instance f).ins); - eval_func instance f vs + eval_func instance (export instance (name @@ no_region)) vs with Stack_overflow -> callstack_exhaustion no_region diff --git a/ml-proto/test/func_ptrs.wast b/ml-proto/test/func_ptrs.wast index c79144f7bb..43d65bd289 100644 --- a/ml-proto/test/func_ptrs.wast +++ b/ml-proto/test/func_ptrs.wast @@ -30,5 +30,48 @@ (assert_return (invoke "three" (i32.const 13)) (i32.const 11)) (invoke "four" (i32.const 83)) -(assert_invalid (module (func (type 42))) "type index out of bounds") -(assert_invalid (module (import "stdio" "print" (type 43))) "type index out of bounds") +(assert_invalid (module (func (type 42))) "unknown function type 42") +(assert_invalid (module (import "stdio" "print" (type 43))) "unknown function type 43") + +(module + (type $T (func (param) (result i32))) + (type $U (func (param) (result i32))) + (table $t1 $t2 $t3 $u1 $u2 $t1 $t3) + + (func $t1 (type $T) (i32.const 1)) + (func $t2 (type $T) (i32.const 2)) + (func $t3 (type $T) (i32.const 3)) + (func $u1 (type $U) (i32.const 4)) + (func $u2 (type $U) (i32.const 5)) + + (func $callt (param $i i32) (result i32) + (call_indirect $T (get_local $i)) + ) + (export "callt" $callt) + + (func $callu (param $i i32) (result i32) + (call_indirect $U (get_local $i)) + ) + (export "callu" $callu) +) + +(assert_return (invoke "callt" (i32.const 0)) (i32.const 1)) +(assert_return (invoke "callt" (i32.const 1)) (i32.const 2)) +(assert_return (invoke "callt" (i32.const 2)) (i32.const 3)) +(assert_trap (invoke "callt" (i32.const 3)) "runtime: indirect call signature mismatch") +(assert_trap (invoke "callt" (i32.const 4)) "runtime: indirect call signature mismatch") +(assert_return (invoke "callt" (i32.const 5)) (i32.const 1)) +(assert_return (invoke "callt" (i32.const 6)) (i32.const 3)) +(assert_trap (invoke "callt" (i32.const 7)) "runtime: undefined table element 7") +(assert_trap (invoke "callt" (i32.const 100)) "runtime: undefined table element 100") +(assert_trap (invoke "callt" (i32.const -1)) "runtime: undefined table element -1") + +(assert_trap (invoke "callu" (i32.const 0)) "runtime: indirect call signature mismatch") +(assert_trap (invoke "callu" (i32.const 1)) "runtime: indirect call signature mismatch") +(assert_trap (invoke "callu" (i32.const 2)) "runtime: indirect call signature mismatch") +(assert_return (invoke "callu" (i32.const 3)) (i32.const 4)) +(assert_return (invoke "callu" (i32.const 4)) (i32.const 5)) +(assert_trap (invoke "callu" (i32.const 5)) "runtime: indirect call signature mismatch") +(assert_trap (invoke "callu" (i32.const 6)) "runtime: indirect call signature mismatch") +(assert_trap (invoke "callu" (i32.const 7)) "runtime: undefined table element 7") +(assert_trap (invoke "callu" (i32.const -1)) "runtime: undefined table element -1")