Skip to content

Commit

Permalink
Merge pull request #2 from andersfugmann/andersfugmann/merge_messages
Browse files Browse the repository at this point in the history
Implement message merging
  • Loading branch information
andersfugmann authored Feb 13, 2024
2 parents 9004657 + 9140a26 commit e2247b2
Show file tree
Hide file tree
Showing 22 changed files with 885 additions and 218 deletions.
94 changes: 46 additions & 48 deletions src/ocaml_protoc_plugin/deserialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@ module S = Spec.Deserialize
module C = S.C
open S

type required = Required | Optional

type 'a reader = 'a -> Reader.t -> Field.field_type -> 'a
type 'a getter = 'a -> 'a
type ('a, 'b) getter = 'a -> 'b
type 'a field_spec = (int * 'a reader)
type 'a value = ('a field_spec list * required * 'a * 'a getter)
type _ value = Value: ('b field_spec list * 'b * ('b, 'a) getter) -> 'a value
type extensions = (int * Field.t) list

type (_, _) value_list =
| VNil : ('a, 'a) value_list
| VNil_ext : (extensions -> 'a, 'a) value_list
| VCons : ('a value) * ('b, 'c) value_list -> ('a -> 'b, 'c) value_list
| VCons : 'a value * ('b, 'c) value_list -> ('a -> 'b, 'c) value_list

type sentinel_field_spec = int * (Reader.t -> Field.field_type -> unit)
type 'a sentinel_getter = unit -> 'a
Expand Down Expand Up @@ -85,10 +83,10 @@ let read_of_spec: type a. a spec -> Field.field_type * (Reader.t -> a) = functio
let v = Bytes.create length in
Bytes.blit_string ~src:data ~src_pos:offset ~dst:v ~dst_pos:0 ~len:length;
v
| Message from_proto -> Length_delimited, fun reader ->
| Message (from_proto, _merge) -> Length_delimited, fun reader ->
let Field.{ offset; length; data } = Reader.read_length_delimited reader in
from_proto (Reader.create ~offset ~length data)

(*
let default_value: type a. a spec -> a = function
| Double -> 0.0
| Float -> 0.0
Expand All @@ -102,7 +100,7 @@ let default_value: type a. a spec -> a = function
| Fixed64 -> Int64.zero
| SFixed32 -> Int32.zero
| SFixed64 -> Int64.zero
| Message of_proto -> of_proto (Reader.create "")
| Message (of_proto, _merge) -> of_proto (Reader.create "")
| String -> ""
| Bytes -> Bytes.empty
| Int32_int -> 0
Expand All @@ -117,7 +115,7 @@ let default_value: type a. a spec -> a = function
| SFixed64_int -> 0
| Enum of_int -> of_int 0
| Bool -> false

*)
(** Identity function: returns its argument unchanged. *)
let id value = value
(** [keep_last previous latest] ignores [previous] and returns [latest]. *)
let keep_last _previous latest = latest

Expand All @@ -129,20 +127,29 @@ let read_field ~read:(expect, read_f) ~map v reader field_type =
error_wrong_field "Deserialize" field

let value: type a. a compound -> a value = function
| Basic_req (index, spec) ->
let map _ v2 = Some v2 in
let read = read_field ~read:(read_of_spec spec) ~map in
let getter = function Some v -> v | None -> error_required_field_missing () in
Value ([(index, read)], None, getter)
| Basic (index, spec, default) ->
let read = read_field ~read:(read_of_spec spec) ~map:keep_last in
let required = match default with
| Some _ -> Optional
| None -> Required
let map = keep_last
in
let default = match default with
| None -> default_value spec
| Some default -> default
in
([(index, read)], required, default, id)
let read = read_field ~read:(read_of_spec spec) ~map in
Value ([(index, read)], default, id)
| Basic_opt (index, spec) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun _ v -> Some v) in
([(index, read)], Optional, None, id)
let map = match spec with
| Message (_, merge) ->
let map v1 v2 =
match v1 with
| None -> Some v2
| Some v1 -> Some (merge v1 v2)
in
map
| _ -> fun _ v -> Some v (* Keep last for all other non-repeated types *)
in
let read = read_field ~read:(read_of_spec spec) ~map in
Value ([(index, read)], None, id)
| Repeated (index, spec, Packed) ->
let field_type, read_f = read_of_spec spec in
let rec read_packed_values read_f acc reader =
Expand All @@ -161,16 +168,16 @@ let value: type a. a compound -> a value = function
let field = Reader.read_field_content ft reader in
error_wrong_field "Deserialize" field
in
([(index, read)], Optional, [], List.rev)
Value ([(index, read)], [], List.rev)
| Repeated (index, spec, Not_packed) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun vs v -> v :: vs) in
([(index, read)], Optional, [], List.rev)
Value ([(index, read)], [], List.rev)
| Oneof oneofs ->
let make_reader: a oneof -> a field_spec = fun (Oneof_elem (index, spec, constr)) ->
let read = read_field ~read:(read_of_spec spec) ~map:(fun _ -> constr) in
(index, read)
in
(List.map ~f:make_reader oneofs, Optional, `not_set, id)
Value (List.map ~f:make_reader oneofs, `not_set, id)

module IntMap = Map.Make(struct type t = int let compare = Int.compare end)

Expand All @@ -183,15 +190,12 @@ let deserialize_full: type constr a. extension_ranges -> (constr, a) value_list
| VNil -> NNil
| VNil_ext -> NNil_ext
(* Consider optimizing when optional is true *)
| VCons ((fields, required, default, getter), rest) ->
let v = ref (default, required) in
let get () = match !v with
| _, Required -> error_required_field_missing ();
| v, Optional-> getter v
in
| VCons (Value (fields, default, getter), rest) ->
let v = ref default in
let get () = getter !v in
let fields =
List.map ~f:(fun (index, read) ->
let read reader field_type = let v' = fst !v in v := (read v' reader field_type, Optional) in
let read reader field_type = (v := read !v reader field_type) in
(index, read)
) fields
in
Expand Down Expand Up @@ -263,11 +267,11 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
in

let rec read_values: type constr a. extension_ranges -> Field.field_type -> int -> Reader.t -> constr -> extensions -> (constr, a) value_list -> a = fun extension_ranges tpe idx reader constr extensions ->
let rec read_repeated tpe index read_f default get reader =
let rec read_repeated tpe index read_f default reader =
let default = read_f default reader tpe in
let (tpe, idx) = next_field reader in
match idx = index with
| true -> read_repeated tpe index read_f default get reader
| true -> read_repeated tpe index read_f default reader
| false -> default, tpe, idx
in
function
Expand All @@ -276,34 +280,27 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
| VNil_ext when idx = Int.max_int ->
constr (List.rev extensions)
(* All fields read successfully. Apply extensions and return result. *)
| VCons (([index, read_f], _required, default, get), vs) when index = idx ->
| VCons (Value ([index, read_f], default, get), vs) when index = idx ->
(* Read all values, and apply constructor once all fields have been read.
This pattern is the most likely to be matched for all values, and is added
as an optimization to avoid reconstructing the value list for each recursion.
*)
let default, tpe, idx = read_repeated tpe index read_f default get reader in
let default, tpe, idx = read_repeated tpe index read_f default reader in
let constr = (constr (get default)) in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (((index, read_f) :: fields, _required, default, get), vs) when index = idx ->
| VCons (Value ((index, read_f) :: fields, default, get), vs) when index = idx ->
(* Read all values for the given field *)
let default, tpe, idx = read_repeated tpe index read_f default get reader in
read_values extension_ranges tpe idx reader constr extensions (VCons ((fields, Optional, default, get), vs))
let default, tpe, idx = read_repeated tpe index read_f default reader in
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, default, get), vs))
| vs when in_extension_ranges extension_ranges idx ->
(* Extensions may be sent inline. Store all valid extensions, before starting to apply constructors *)
let extensions = (idx, Reader.read_field_content tpe reader) :: extensions in
let (tpe, idx) = next_field reader in
read_values extension_ranges tpe idx reader constr extensions vs
| VCons (([], Required, _default, _get), _vs) ->
(* If there are no more fields to be read we will never find the value.
If all values are read, then raise, else revert to full deserialization *)
begin match (idx = Int.max_int) with
| true -> error_required_field_missing ()
| false -> raise Restart_full
end
| VCons ((_ :: fields, optional, default, get), vs) ->
| VCons (Value (_ :: fields, default, get), vs) ->
(* Drop the field, as we don't expect to find it. *)
read_values extension_ranges tpe idx reader constr extensions (VCons ((fields, optional, default, get), vs))
| VCons (([], Optional, default, get), vs) ->
read_values extension_ranges tpe idx reader constr extensions (VCons (Value (fields, default, get), vs))
| VCons (Value ([], default, get), vs) ->
(* Apply destructor. This case is only relevant for oneof fields *)
read_values extension_ranges tpe idx reader (constr (get default)) extensions vs
| VNil | VNil_ext ->
Expand All @@ -321,6 +318,7 @@ let deserialize: type constr a. (constr, a) compound_list -> constr -> Reader.t
let (tpe, idx) = next_field reader in
try
read_values extension_ranges tpe idx reader constr [] values
with Restart_full ->
with (Restart_full | Result.Error `Required_field_missing) ->
(* Revert to full deserialization *)
Reader.reset reader offset;
deserialize_full extension_ranges values constr reader
1 change: 1 addition & 0 deletions src/ocaml_protoc_plugin/extensions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ let compare _ _ = 0
(** [index_of_spec spec] returns the integer index stored as the first
    component of a single-field serialization spec.
    Raises [Failure] when given a [Oneof] spec. *)
let index_of_spec: type a. a Spec.Serialize.compound -> int = function
  | Oneof _ -> failwith "Oneof fields not allowed in extensions"
  | Repeated (idx, _, _) -> idx
  | Basic_req (idx, _) -> idx
  | Basic_opt (idx, _) -> idx
  | Basic (idx, _, _) -> idx

Expand Down
29 changes: 29 additions & 0 deletions src/ocaml_protoc_plugin/merge.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
(** [merge spec prev next] merges two values of the field described by [spec].
    The spec is inspected so that message-typed fields can be merged
    recursively using the merge function stored in their [Message] spec.

    - [Basic]: [next] is returned unless it equals the default held in the
      spec, in which case [prev] is returned. A [Basic] spec carrying a
      [Message] raises [Failure].
    - [Basic_req]: messages are merged; any other type returns [next].
    - [Basic_opt]: messages are merged when both values are set, otherwise the
      one that is set is returned; any other type returns [next] when it is
      [Some _] and [prev] otherwise.
    - [Repeated]: returns [prev @ next].
    - [Oneof]: raises [Failure]. *)
let merge: type t. t Spec.Deserialize.compound -> t -> t -> t = fun spec prev next ->
  match spec with
  | Spec.Deserialize.Basic (_, Message _, _) ->
    failwith "Messages with defaults cannot happen"
  | Spec.Deserialize.Basic (_, _, default) ->
    (* A [next] equal to the default is treated as "not set". *)
    if next = default then prev else next

  (* The spec states that proto2 required fields must be transmitted exactly
     once, so strictly speaking messages containing required fields cannot be
     merged. This implementation chooses to ignore that restriction:
     required messages are merged, and any other required value keeps the
     last one seen. *)
  | Spec.Deserialize.Basic_req (_, Message (_, merge_message)) ->
    merge_message prev next
  | Spec.Deserialize.Basic_req (_, _) -> next

  | Spec.Deserialize.Basic_opt (_, Message (_, merge_message)) ->
    (match prev, next with
     | Some older, Some newer -> Some (merge_message older newer)
     | Some older, None -> Some older
     | None, Some newer -> Some newer
     | None, None -> None)
  | Spec.Deserialize.Basic_opt (_, _) ->
    (match next with
     | None -> prev
     | Some _ -> next)

  | Spec.Deserialize.Repeated (_, _, _) -> prev @ next

  (* Disabled alternative, kept for reference:
     | Spec.Deserialize.Oneof _ when next = `not_set -> prev *)
  | Spec.Deserialize.Oneof _ -> failwith "Implementation is part of generated code"
1 change: 1 addition & 0 deletions src/ocaml_protoc_plugin/ocaml_protoc_plugin.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Serialize = Serialize
module Deserialize = Deserialize
module Spec = Spec
module Runtime = Runtime
module Field = Field
(**/**)

module Reader = Reader
Expand Down
1 change: 1 addition & 0 deletions src/ocaml_protoc_plugin/runtime.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ module Runtime' = struct
module Extensions = Extensions
module Reader = Reader
module Writer = Writer
module Merge = Merge
end
19 changes: 9 additions & 10 deletions src/ocaml_protoc_plugin/serialize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,15 @@ let rec write: type a. a compound -> Writer.t -> a -> unit = function
*)
| Basic (index, spec, default) -> begin
let write = write_field spec index in
match default with
| Some default ->
fun writer v -> begin
match v with
| v when v = default -> ()
| v -> write v writer
end
| None ->
fun writer v -> write v writer
let writer writer = function
| v when v = default -> ()
| v -> write v writer
in
writer
end
| Basic_req (index, spec) ->
let write = write_field spec index in
fun writer v -> write v writer
| Basic_opt (index, spec) -> begin
let write = write_field spec index in
fun writer v ->
Expand All @@ -145,7 +144,7 @@ let rec write: type a. a compound -> Writer.t -> a -> unit = function
(* Wonder if we could get the specs before calling v. Wonder what f is? *)
(* We could prob. return a list of all possible values + f v -> v. *)
let Oneof_elem (index, spec, v) = f v in
write (Basic (index, spec, None)) writer v
write (Basic_req (index, spec)) writer v
end

let in_extension_ranges extension_ranges index =
Expand Down
20 changes: 18 additions & 2 deletions src/ocaml_protoc_plugin/spec.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Make(T : T) = struct
type packed = Packed | Not_packed
type extension_ranges = (int * int) list
type extensions = (int * Field.t) list
type 'a merge = 'a -> 'a -> 'a

type _ spec =
| Double : float spec
Expand Down Expand Up @@ -40,20 +41,34 @@ module Make(T : T) = struct
| String : string spec
| Bytes : bytes spec
| Enum : ('a, int -> 'a, 'a -> int) T.dir -> 'a spec
| Message : ('a, Reader.t -> 'a, Writer.t -> 'a -> Writer.t) T.dir -> 'a spec
| Message : ('a, ((Reader.t -> 'a) * 'a merge), Writer.t -> 'a -> Writer.t) T.dir -> 'a spec

(* Existential types *)
type espec = Espec: _ spec -> espec

type _ oneof =
| Oneof_elem : int * 'b spec * ('a, ('b -> 'a), 'b) T.dir -> 'a oneof

type _ compound =
| Basic : int * 'a spec * 'a option -> 'a compound
(* A field, where the default value is known (and set). This cannot be used for message types *)
| Basic : int * 'a spec * 'a -> 'a compound

(* Proto2/proto3 optional fields. *)
| Basic_opt : int * 'a spec -> 'a option compound

(* Proto2 required fields (and oneof fields) *)
| Basic_req : int * 'a spec -> 'a compound

(* Repeated fields *)
| Repeated : int * 'a spec * packed -> 'a list compound
| Oneof : ('a, 'a oneof list, 'a -> unit oneof) T.dir -> ([> `not_set ] as 'a) compound

type (_, _) compound_list =
| Nil : ('a, 'a) compound_list

(* Nil_ext denotes that the message contains extensions *)
| Nil_ext: extension_ranges -> (extensions -> 'a, 'a) compound_list

| Cons : ('a compound) * ('b, 'c) compound_list -> ('a -> 'b, 'c) compound_list

module C = struct
Expand Down Expand Up @@ -93,6 +108,7 @@ module Make(T : T) = struct

let repeated (i, s, p) = Repeated (i, s, p)
let basic (i, s, d) = Basic (i, s, d)
let basic_req (i, s) = Basic_req (i, s)
let basic_opt (i, s) = Basic_opt (i, s)
let oneof s = Oneof s
let oneof_elem (a, b, c) = Oneof_elem (a, b, c)
Expand Down
9 changes: 7 additions & 2 deletions src/plugin/code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ let emit t indent fmt =
| n -> String.sub ~pos:0 ~len:(String.length s - n) s
in
let prepend s =
String.split_on_char ~sep:'\n' s
|> List.iter ~f:(fun s -> t.code <- (trim_end ~char:' ' (t.indent ^ s)) :: t.code)
match String.split_on_char ~sep:'\n' s with
| line :: lines ->
t.code <- (trim_end ~char:' ' (t.indent ^ line)) :: t.code;
incr t;
List.iter lines ~f:(fun line -> t.code <- (trim_end ~char:' ' (t.indent ^ line)) :: t.code);
decr t;
| [] -> ()
in
let emit s =
match indent with
Expand Down
8 changes: 5 additions & 3 deletions src/plugin/emit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,16 @@ let rec emit_message ~params ~syntax scope
| Some _name ->
let is_map_entry = is_map_entry options in
let is_cyclic = Scope.is_cyclic scope in
let Types.{ type'; constructor; apply; deserialize_spec; serialize_spec; default_constructor_sig; default_constructor_impl } =
let Types.{ type'; constructor; apply; deserialize_spec; serialize_spec;
default_constructor_sig; default_constructor_impl; merge_impl } =
Types.make ~params ~syntax ~is_cyclic ~is_map_entry ~extension_ranges ~scope ~fields oneof_decls
in
ignore (default_constructor_sig, default_constructor_impl);
ignore (merge_impl);

Code.emit signature `None "val name': unit -> string";
Code.emit signature `None "type t = %s %s" type' params.annot;
Code.emit signature `None "val make: %s" default_constructor_sig;
Code.emit signature `None "val merge: t -> t -> t";
Code.emit signature `None "val to_proto': Runtime'.Writer.t -> t -> Runtime'.Writer.t";
Code.emit signature `None "val to_proto: t -> Runtime'.Writer.t";
Code.emit signature `None "val from_proto: Runtime'.Reader.t -> (t, [> Runtime'.Result.error]) result";
Expand All @@ -227,6 +229,7 @@ let rec emit_message ~params ~syntax scope
Code.emit implementation `None "let name' () = \"%s\"" (Scope.get_current_scope scope);
Code.emit implementation `None "type t = %s%s" type' params.annot;
Code.emit implementation `None "let make %s" default_constructor_impl;
Code.emit implementation `None "let merge = (%s)" merge_impl;

Code.emit implementation `Begin "let to_proto' =";
Code.emit implementation `None "let spec = %s in" serialize_spec;
Expand All @@ -240,7 +243,6 @@ let rec emit_message ~params ~syntax scope
Code.emit implementation `None "let constructor = %s in" constructor;
Code.emit implementation `None "let spec = %s in" deserialize_spec;
Code.emit implementation `None "Runtime'.Deserialize.deserialize spec constructor";
(* TODO: No need to have a function here. We could drop deserialize thing here *)
Code.emit implementation `End "let from_proto writer = Runtime'.Result.catch (fun () -> from_proto_exn writer)";
| None -> ()
in
Expand Down
Loading

0 comments on commit e2247b2

Please sign in to comment.