diff --git a/configure.ac b/configure.ac index 4da41d9..1174f33 100644 --- a/configure.ac +++ b/configure.ac @@ -25,6 +25,7 @@ AC_CHECK_LIB(ssl,TLSv1_1_method,[CFLAGS="$CFLAGS -DHAVE_TLS11"],,[$LIBS]) AC_CHECK_LIB(ssl,TLSv1_2_method,[CFLAGS="$CFLAGS -DHAVE_TLS12"],,[$LIBS]) AC_CHECK_LIB(crypto,EC_KEY_free,[CFLAGS="$CFLAGS -DHAVE_EC"],,[$LIBS]) AC_CHECK_DECL([SSL_set_tlsext_host_name], [CFLAGS="$CFLAGS -DHAVE_SNI"], [], [[#include ]]) +AC_CHECK_DECL([SSL_set_alpn_protos], [CFLAGS="$CFLAGS -DHAVE_ALPN"], [], [[#include ]]) # Finally create the Makefile and samples AC_CONFIG_FILES([Makefile],[chmod a-w Makefile]) diff --git a/examples/alpn.ml b/examples/alpn.ml new file mode 100644 index 0000000..b3f4d26 --- /dev/null +++ b/examples/alpn.ml @@ -0,0 +1,93 @@ + +let test_client proto_list = + Ssl.init (); + let ctx = Ssl.create_context Ssl.TLSv1_2 Ssl.Client_context in + Ssl.set_context_alpn_protos ctx proto_list; + let sockaddr = Unix.ADDR_INET(Unix.inet_addr_of_string "127.0.0.1", 4433) in + let ssl = Ssl.open_connection_with_context ctx sockaddr in + let () = + match Ssl.get_negotiated_alpn_protocol ssl with + | None -> print_endline "No protocol selected" + | Some proto -> print_endline ("Selected protocol: " ^ proto) + in + Ssl.shutdown ssl + + +let test_server proto_list = + let certfile = "cert.pem" in + let privkey = "privkey.key" in + let log s = + Printf.printf "[II] %s\n%!" s + in + Ssl.init (); + let sockaddr = Unix.ADDR_INET(Unix.inet_addr_of_string "127.0.0.1", 4433) in + let domain = + begin match sockaddr with + | Unix.ADDR_UNIX _ -> Unix.PF_UNIX + | Unix.ADDR_INET (_, _) -> Unix.PF_INET + end + in + let sock = Unix.socket domain Unix.SOCK_STREAM 0 in + let ctx = Ssl.create_context Ssl.TLSv1_2 Ssl.Server_context in + Ssl.use_certificate ctx certfile privkey; + let rec first_match l1 = function + | [] -> None + | x::_ when List.mem x l1 -> Some x + | _::xs -> first_match l1 xs + in + Ssl.set_context_alpn_select_callback ctx (fun client_protos -> + first_match client_protos proto_list + ); + Unix.setsockopt sock Unix.SO_REUSEADDR true; + Unix.bind sock sockaddr; + Unix.listen sock 100; + log "listening for connections"; + let (s, caller) = Unix.accept sock in + let ssl_s = Ssl.embed_socket s ctx in + let () = + try Ssl.accept ssl_s with + | e -> Printexc.to_string e |> print_endline + in + let inet_addr_of_sockaddr = function + | Unix.ADDR_INET (n, _) -> n + | Unix.ADDR_UNIX _ -> Unix.inet_addr_any + in + let inet_addr = inet_addr_of_sockaddr caller in + let ip = Unix.string_of_inet_addr inet_addr in + log (Printf.sprintf "openning connection for [%s]" ip); + let () = + match Ssl.get_negotiated_alpn_protocol ssl_s with + | None -> log "no protocol selected" + | Some proto -> log (Printf.sprintf "selected protocol: %s" proto) + in + Ssl.shutdown ssl_s + +let () = + let usage = "usage: ./alpn (server|client) protocol[,protocol]" in + let split_on_char sep s = + let r = ref [] in + let j = ref (String.length s) in + for i = String.length s - 1 downto 0 do + if s.[i] = sep then begin + r := String.sub s (i + 1) (!j - i - 1) :: !r; + j := i + end + done; + String.sub s 0 !j :: !r + in + let typ = ref "" in + let protocols = ref [] in + Arg.parse [ + "-t", Arg.String (fun t -> typ := t), "Type (server or client)"; + "-p", Arg.String (fun p -> protocols := split_on_char ',' p), "Comma-separated protocols"; + ] (fun _ -> ()) usage; + match !typ with + | "server" -> test_server !protocols + | "client" -> test_client !protocols + | _ -> failwith "Invalid type, use server or client." + +(* Usage: +ocamlfind ocamlc alpn.ml -g -o alpn -package ssl -linkpkg -ccopt -L/path/to/openssl/lib -cclib -lssl -cclib -lcrypto +./alpn -t server -p h2,http/1.1 +./alpn -t client -p h2,http/1.1 +*) diff --git a/src/ssl.ml b/src/ssl.ml index c52c412..23646ac 100644 --- a/src/ssl.ml +++ b/src/ssl.ml @@ -193,6 +193,10 @@ external set_verify_depth : context -> int -> unit = "ocaml_ssl_ctx_set_verify_d external set_client_CA_list_from_file : context -> string -> unit = "ocaml_ssl_ctx_set_client_CA_list_from_file" +external set_context_alpn_protos : context -> string list -> unit = "ocaml_ssl_ctx_set_alpn_protos" + +external set_context_alpn_select_callback : context -> (string list -> string option) -> unit = "ocaml_ssl_ctx_set_alpn_select_callback" + type cipher external get_cipher : socket -> cipher = "ocaml_ssl_get_current_cipher" @@ -219,6 +223,10 @@ external file_descr_of_socket : socket -> Unix.file_descr = "ocaml_ssl_get_file_ external set_client_SNI_hostname : socket -> string -> unit = "ocaml_ssl_set_client_SNI_hostname" +external set_alpn_protos : socket -> string list -> unit = "ocaml_ssl_set_alpn_protos" + +external get_negotiated_alpn_protocol : socket -> string option = "ocaml_ssl_get_negotiated_alpn_protocol" + external connect : socket -> unit = "ocaml_ssl_connect" external verify : socket -> unit = "ocaml_ssl_verify" diff --git a/src/ssl.mli b/src/ssl.mli index 45d47ba..921ff0b 100644 --- a/src/ssl.mli +++ b/src/ssl.mli @@ -275,6 +275,12 @@ val set_verify : context -> verify_mode list -> verify_callback option -> unit (** Set the maximum depth for the certificate chain verification that shall be allowed. *) val set_verify_depth : context -> int -> unit +(** Set the list of supported ALPN protocols for negotiation to the context. *) +val set_context_alpn_protos : context -> string list -> unit + +(** Set the callback to allow server to select the preferred protocol from client's available protocols. *) +val set_context_alpn_select_callback : context -> (string list -> string option) -> unit + (** {2 Ciphers} *) @@ -369,6 +375,12 @@ val shutdown_connection : socket -> unit * Name Indication (SNI) TLS extension. *) val set_client_SNI_hostname : socket -> string -> unit +(** Set the list of supported ALPN protocols for negotiation to the connection. *) +val set_alpn_protos : socket -> string list -> unit + +(** Get the negotiated protocol from the connection. *) +val get_negotiated_alpn_protocol : socket -> string option + (** Connect an SSL socket. *) val connect : socket -> unit diff --git a/src/ssl_stubs.c b/src/ssl_stubs.c index fa4cbbf..7ea375b 100644 --- a/src/ssl_stubs.c +++ b/src/ssl_stubs.c @@ -102,6 +102,19 @@ static struct custom_operations socket_ops = custom_deserialize_default }; +/* Option types */ + +#define Val_none Val_int(0) + +static value Val_some(value v) +{ + CAMLparam1(v); + CAMLlocal1(some); + some = caml_alloc(1, 0); + Store_field(some, 0, v); + CAMLreturn(some); +} + /****************** * Initialization * @@ -535,6 +548,144 @@ CAMLprim value ocaml_ssl_ctx_set_client_CA_list_from_file(value context, value v CAMLreturn(Val_unit); } +#ifdef HAVE_ALPN +static int get_alpn_buffer_length(value vprotos) +{ + value protos_tl = vprotos; + int total_len = 0; + while (protos_tl != Val_emptylist) + { + total_len += caml_string_length(Field(protos_tl, 0)) + 1; + protos_tl = Field(protos_tl, 1); + } + return total_len; +} + +static void build_alpn_protocol_buffer(value vprotos, unsigned char *protos) +{ + int proto_idx = 0; + while (vprotos != Val_emptylist) + { + value head = Field(vprotos, 0); + int len = caml_string_length(head); + protos[proto_idx++] = len; + + int i; + for (i = 0; i < len; i++) + protos[proto_idx++] = Byte_u(head, i); + vprotos = Field(vprotos, 1); + } +} + +CAMLprim value ocaml_ssl_ctx_set_alpn_protos(value context, value vprotos) +{ + CAMLparam2(context, vprotos); + SSL_CTX *ctx = Ctx_val(context); + + int total_len = get_alpn_buffer_length(vprotos); + unsigned char protos[total_len]; + build_alpn_protocol_buffer(vprotos, protos); + + caml_enter_blocking_section(); + SSL_CTX_set_alpn_protos(ctx, protos, sizeof(protos)); + caml_leave_blocking_section(); + + CAMLreturn(Val_unit); +} + +static value build_alpn_protocol_list(const unsigned char *protocol_buffer, unsigned int len) +{ + CAMLparam0(); + CAMLlocal3(protocol_list, current, tail); + + int idx = 0; + protocol_list = Val_emptylist; + + while (idx < len) + { + int proto_len = (int) protocol_buffer[idx++]; + char proto[proto_len + 1]; + int i; + for (i = 0; i < proto_len; i++) + proto[i] = (char) protocol_buffer[idx++]; + proto[proto_len] = '\0'; + + tail = caml_alloc(2, 0); + Store_field(tail, 0, caml_copy_string(proto)); + Store_field(tail, 1, Val_emptylist); + + if (protocol_list == Val_emptylist) + protocol_list = tail; + else + Store_field(current, 1, tail); + + current = tail; + } + + CAMLreturn(protocol_list); +} + +static int alpn_select_cb(SSL *ssl, + const unsigned char **out, + unsigned char *outlen, + const unsigned char *in, + unsigned int inlen, + void *arg) +{ + CAMLparam0(); + CAMLlocal3(protocol_list, selected_protocol, selected_protocol_opt); + + int len; + + caml_leave_blocking_section(); + protocol_list = build_alpn_protocol_list(in, inlen); + selected_protocol_opt = caml_callback(*((value*)arg), protocol_list); + + if (selected_protocol_opt == Val_none) + return SSL_TLSEXT_ERR_NOACK; + + selected_protocol = Field(selected_protocol_opt, 0); + len = caml_string_length(selected_protocol); + *out = Bytes_val(selected_protocol); + *outlen = len; + caml_enter_blocking_section(); + + return SSL_TLSEXT_ERR_OK; +} + +CAMLprim value ocaml_ssl_ctx_set_alpn_select_callback(value context, value cb) +{ + CAMLparam2(context, cb); + SSL_CTX *ctx = Ctx_val(context); + + value *select_cb; + + select_cb = malloc(sizeof(value)); + *select_cb = cb; + caml_register_global_root(select_cb); + + caml_enter_blocking_section(); + SSL_CTX_set_alpn_select_cb(ctx, alpn_select_cb, select_cb); + caml_leave_blocking_section(); + + CAMLreturn(Val_unit); +} +#else +CAMLprim value ocaml_ssl_ctx_set_alpn_protos(value context, value vprotos) +{ + CAMLparam2(context, vprotos); + caml_raise_constant(*caml_named_value("ssl_exn_method_error")); + CAMLreturn(Val_unit); +} + +CAMLprim value ocaml_ssl_ctx_set_alpn_select_callback(value context, value cb) +{ + CAMLparam2(context, cb); + caml_raise_constant(*caml_named_value("ssl_exn_method_error")); + CAMLreturn(Val_unit); +} +#endif + static int pem_passwd_cb(char *buf, int size, int rwflag, void *userdata) { value s; @@ -940,6 +1091,52 @@ CAMLprim value ocaml_ssl_set_client_SNI_hostname(value socket, value vhostname) } #endif +#ifdef HAVE_ALPN +CAMLprim value ocaml_ssl_set_alpn_protos(value socket, value vprotos) +{ + CAMLparam2(socket, vprotos); + SSL *ssl = SSL_val(socket); + + int total_len = get_alpn_buffer_length(vprotos); + unsigned char protos[total_len]; + build_alpn_protocol_buffer(vprotos, protos); + + caml_enter_blocking_section(); + SSL_set_alpn_protos(ssl, protos, sizeof(protos)); + caml_leave_blocking_section(); + + CAMLreturn(Val_unit); +} + +CAMLprim value ocaml_ssl_get_negotiated_alpn_protocol(value socket) +{ + CAMLparam1(socket); + SSL *ssl = SSL_val(socket); + + const unsigned char *data; + unsigned int len; + SSL_get0_alpn_selected(ssl, &data, &len); + + if (len == 0) CAMLreturn(Val_none); + + CAMLreturn(Val_some(caml_copy_string((const char*) data))); +} +#else +CAMLprim value ocaml_ssl_set_alpn_protos(value socket, value vprotos) +{ + CAMLparam2(socket, vprotos); + caml_raise_constant(*caml_named_value("ssl_exn_method_error")); + CAMLreturn(Val_unit); +} + +CAMLprim value ocaml_ssl_get_negotiated_alpn_protocol(value socket) +{ + CAMLparam1(socket); + caml_raise_constant(*caml_named_value("ssl_exn_method_error")); + CAMLreturn(Val_unit); +} +#endif + CAMLprim value ocaml_ssl_connect(value socket) { CAMLparam1(socket); @@ -1342,7 +1539,7 @@ static DH *load_dh_param(const char *dhfile) { DH *ret=NULL; BIO *bio; - + if ((bio=BIO_new_file(dhfile,"r")) == NULL) goto err; ret=PEM_read_bio_DHparams(bio,NULL,NULL,NULL);