Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ALPN extension #38

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ AC_CHECK_LIB(ssl,TLSv1_1_method,[CFLAGS="$CFLAGS -DHAVE_TLS11"],,[$LIBS])
AC_CHECK_LIB(ssl,TLSv1_2_method,[CFLAGS="$CFLAGS -DHAVE_TLS12"],,[$LIBS])
AC_CHECK_LIB(crypto,EC_KEY_free,[CFLAGS="$CFLAGS -DHAVE_EC"],,[$LIBS])
AC_CHECK_DECL([SSL_set_tlsext_host_name], [CFLAGS="$CFLAGS -DHAVE_SNI"], [], [[#include <openssl/ssl.h>]])
AC_CHECK_DECL([SSL_set_alpn_protos], [CFLAGS="$CFLAGS -DHAVE_ALPN"], [], [[#include <openssl/ssl.h>]])

# Finally create the Makefile and samples
AC_CONFIG_FILES([Makefile],[chmod a-w Makefile])
Expand Down
93 changes: 93 additions & 0 deletions examples/alpn.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@

let test_client proto_list =
Ssl.init ();
let ctx = Ssl.create_context Ssl.TLSv1_2 Ssl.Client_context in
Ssl.set_context_alpn_protos ctx proto_list;
let sockaddr = Unix.ADDR_INET(Unix.inet_addr_of_string "127.0.0.1", 4433) in
let ssl = Ssl.open_connection_with_context ctx sockaddr in
let () =
match Ssl.get_negotiated_alpn_protocol ssl with
| None -> print_endline "No protocol selected"
| Some proto -> print_endline ("Selected protocol: " ^ proto)
in
Ssl.shutdown ssl


let test_server proto_list =
let certfile = "cert.pem" in
let privkey = "privkey.key" in
let log s =
Printf.printf "[II] %s\n%!" s
in
Ssl.init ();
let sockaddr = Unix.ADDR_INET(Unix.inet_addr_of_string "127.0.0.1", 4433) in
let domain =
begin match sockaddr with
| Unix.ADDR_UNIX _ -> Unix.PF_UNIX
| Unix.ADDR_INET (_, _) -> Unix.PF_INET
end
in
let sock = Unix.socket domain Unix.SOCK_STREAM 0 in
let ctx = Ssl.create_context Ssl.TLSv1_2 Ssl.Server_context in
Ssl.use_certificate ctx certfile privkey;
let rec first_match l1 = function
| [] -> None
| x::_ when List.mem x l1 -> Some x
| _::xs -> first_match l1 xs
in
Ssl.set_context_alpn_select_callback ctx (fun client_protos ->
first_match client_protos proto_list
);
Unix.setsockopt sock Unix.SO_REUSEADDR true;
Unix.bind sock sockaddr;
Unix.listen sock 100;
log "listening for connections";
let (s, caller) = Unix.accept sock in
let ssl_s = Ssl.embed_socket s ctx in
let () =
try Ssl.accept ssl_s with
| e -> Printexc.to_string e |> print_endline
in
let inet_addr_of_sockaddr = function
| Unix.ADDR_INET (n, _) -> n
| Unix.ADDR_UNIX _ -> Unix.inet_addr_any
in
let inet_addr = inet_addr_of_sockaddr caller in
let ip = Unix.string_of_inet_addr inet_addr in
log (Printf.sprintf "openning connection for [%s]" ip);
let () =
match Ssl.get_negotiated_alpn_protocol ssl_s with
| None -> log "no protocol selected"
| Some proto -> log (Printf.sprintf "selected protocol: %s" proto)
in
Ssl.shutdown ssl_s

let () =
let usage = "usage: ./alpn (server|client) protocol[,protocol]" in
let split_on_char sep s =
let r = ref [] in
let j = ref (String.length s) in
for i = String.length s - 1 downto 0 do
if s.[i] = sep then begin
r := String.sub s (i + 1) (!j - i - 1) :: !r;
j := i
end
done;
String.sub s 0 !j :: !r
in
let typ = ref "" in
let protocols = ref [] in
Arg.parse [
"-t", Arg.String (fun t -> typ := t), "Type (server or client)";
"-p", Arg.String (fun p -> protocols := split_on_char ',' p), "Comma-separated protocols";
] (fun _ -> ()) usage;
match !typ with
| "server" -> test_server !protocols
| "client" -> test_client !protocols
| _ -> failwith "Invalid type, use server or client."

(* Usage:
ocamlfind ocamlc alpn.ml -g -o alpn -package ssl -linkpkg -ccopt -L/path/to/openssl/lib -cclib -lssl -cclib -lcrypto
./alpn -t server -p h2,http/1.1
./alpn -t client -p h2,http/1.1
*)
8 changes: 8 additions & 0 deletions src/ssl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ external set_verify_depth : context -> int -> unit = "ocaml_ssl_ctx_set_verify_d

external set_client_CA_list_from_file : context -> string -> unit = "ocaml_ssl_ctx_set_client_CA_list_from_file"

external set_context_alpn_protos : context -> string list -> unit = "ocaml_ssl_ctx_set_alpn_protos"

external set_context_alpn_select_callback : context -> (string list -> string option) -> unit = "ocaml_ssl_ctx_set_alpn_select_callback"

type cipher

external get_cipher : socket -> cipher = "ocaml_ssl_get_current_cipher"
Expand All @@ -219,6 +223,10 @@ external file_descr_of_socket : socket -> Unix.file_descr = "ocaml_ssl_get_file_

external set_client_SNI_hostname : socket -> string -> unit = "ocaml_ssl_set_client_SNI_hostname"

external set_alpn_protos : socket -> string list -> unit = "ocaml_ssl_set_alpn_protos"

external get_negotiated_alpn_protocol : socket -> string option = "ocaml_ssl_get_negotiated_alpn_protocol"

external connect : socket -> unit = "ocaml_ssl_connect"

external verify : socket -> unit = "ocaml_ssl_verify"
Expand Down
12 changes: 12 additions & 0 deletions src/ssl.mli
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,12 @@ val set_verify : context -> verify_mode list -> verify_callback option -> unit
(** Set the maximum depth for the certificate chain verification that shall be allowed. *)
val set_verify_depth : context -> int -> unit

(** Set the list of supported ALPN protocols for negotiation to the context. *)
val set_context_alpn_protos : context -> string list -> unit

(** Set the callback to allow server to select the preferred protocol from client's available protocols. *)
val set_context_alpn_select_callback : context -> (string list -> string option) -> unit


(** {2 Ciphers} *)

Expand Down Expand Up @@ -369,6 +375,12 @@ val shutdown_connection : socket -> unit
* Name Indication (SNI) TLS extension. *)
val set_client_SNI_hostname : socket -> string -> unit

(** Set the list of supported ALPN protocols for negotiation to the connection. *)
val set_alpn_protos : socket -> string list -> unit

(** Get the negotiated protocol from the connection. *)
val get_negotiated_alpn_protocol : socket -> string option

(** Connect an SSL socket. *)
val connect : socket -> unit

Expand Down
199 changes: 198 additions & 1 deletion src/ssl_stubs.c
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,19 @@ static struct custom_operations socket_ops =
custom_deserialize_default
};

/* Option types */

#define Val_none Val_int(0)

static value Val_some(value v)
{
CAMLparam1(v);
CAMLlocal1(some);
some = caml_alloc(1, 0);
Store_field(some, 0, v);
CAMLreturn(some);
}


/******************
* Initialization *
Expand Down Expand Up @@ -535,6 +548,144 @@ CAMLprim value ocaml_ssl_ctx_set_client_CA_list_from_file(value context, value v
CAMLreturn(Val_unit);
}

#ifdef HAVE_ALPN
static int get_alpn_buffer_length(value vprotos)
{
value protos_tl = vprotos;
int total_len = 0;
while (protos_tl != Val_emptylist)
{
total_len += caml_string_length(Field(protos_tl, 0)) + 1;
protos_tl = Field(protos_tl, 1);
}
return total_len;
}

static void build_alpn_protocol_buffer(value vprotos, unsigned char *protos)
{
int proto_idx = 0;
while (vprotos != Val_emptylist)
{
value head = Field(vprotos, 0);
int len = caml_string_length(head);
protos[proto_idx++] = len;

int i;
for (i = 0; i < len; i++)
protos[proto_idx++] = Byte_u(head, i);
vprotos = Field(vprotos, 1);
}
}

CAMLprim value ocaml_ssl_ctx_set_alpn_protos(value context, value vprotos)
{
CAMLparam2(context, vprotos);
SSL_CTX *ctx = Ctx_val(context);

int total_len = get_alpn_buffer_length(vprotos);
unsigned char protos[total_len];
build_alpn_protocol_buffer(vprotos, protos);

caml_enter_blocking_section();
SSL_CTX_set_alpn_protos(ctx, protos, sizeof(protos));
caml_leave_blocking_section();

CAMLreturn(Val_unit);
}

static value build_alpn_protocol_list(const unsigned char *protocol_buffer, unsigned int len)
{
CAMLparam0();
CAMLlocal3(protocol_list, current, tail);

int idx = 0;
protocol_list = Val_emptylist;

while (idx < len)
{
int proto_len = (int) protocol_buffer[idx++];
char proto[proto_len + 1];
int i;
for (i = 0; i < proto_len; i++)
proto[i] = (char) protocol_buffer[idx++];
proto[proto_len] = '\0';

tail = caml_alloc(2, 0);
Store_field(tail, 0, caml_copy_string(proto));
Store_field(tail, 1, Val_emptylist);

if (protocol_list == Val_emptylist)
protocol_list = tail;
else
Store_field(current, 1, tail);

current = tail;
}

CAMLreturn(protocol_list);
}

static int alpn_select_cb(SSL *ssl,
const unsigned char **out,
unsigned char *outlen,
const unsigned char *in,
unsigned int inlen,
void *arg)
{
CAMLparam0();
CAMLlocal3(protocol_list, selected_protocol, selected_protocol_opt);

int len;

caml_leave_blocking_section();
protocol_list = build_alpn_protocol_list(in, inlen);
selected_protocol_opt = caml_callback(*((value*)arg), protocol_list);

if (selected_protocol_opt == Val_none)
return SSL_TLSEXT_ERR_NOACK;

selected_protocol = Field(selected_protocol_opt, 0);
len = caml_string_length(selected_protocol);
*out = Bytes_val(selected_protocol);
*outlen = len;
caml_enter_blocking_section();

return SSL_TLSEXT_ERR_OK;
}

CAMLprim value ocaml_ssl_ctx_set_alpn_select_callback(value context, value cb)
{
CAMLparam2(context, cb);
SSL_CTX *ctx = Ctx_val(context);

value *select_cb;

select_cb = malloc(sizeof(value));
*select_cb = cb;
caml_register_global_root(select_cb);

caml_enter_blocking_section();
SSL_CTX_set_alpn_select_cb(ctx, alpn_select_cb, select_cb);
caml_leave_blocking_section();

CAMLreturn(Val_unit);
}
#else
CAMLprim value ocaml_ssl_ctx_set_alpn_protos(value context, value vprotos)
{
CAMLparam2(context, vprotos);
caml_raise_constant(*caml_named_value("ssl_exn_method_error"));
CAMLreturn(Val_unit);
}

CAMLprim value ocaml_ssl_ctx_set_alpn_select_callback(value context, value cb)
{
CAMLparam2(context, cb);
caml_raise_constant(*caml_named_value("ssl_exn_method_error"));
CAMLreturn(Val_unit);
}
#endif

static int pem_passwd_cb(char *buf, int size, int rwflag, void *userdata)
{
value s;
Expand Down Expand Up @@ -940,6 +1091,52 @@ CAMLprim value ocaml_ssl_set_client_SNI_hostname(value socket, value vhostname)
}
#endif

#ifdef HAVE_ALPN
CAMLprim value ocaml_ssl_set_alpn_protos(value socket, value vprotos)
{
CAMLparam2(socket, vprotos);
SSL *ssl = SSL_val(socket);

int total_len = get_alpn_buffer_length(vprotos);
unsigned char protos[total_len];
build_alpn_protocol_buffer(vprotos, protos);

caml_enter_blocking_section();
SSL_set_alpn_protos(ssl, protos, sizeof(protos));
caml_leave_blocking_section();

CAMLreturn(Val_unit);
}

CAMLprim value ocaml_ssl_get_negotiated_alpn_protocol(value socket)
{
CAMLparam1(socket);
SSL *ssl = SSL_val(socket);

const unsigned char *data;
unsigned int len;
SSL_get0_alpn_selected(ssl, &data, &len);

if (len == 0) CAMLreturn(Val_none);

CAMLreturn(Val_some(caml_copy_string((const char*) data)));
}
#else
CAMLprim value ocaml_ssl_set_alpn_protos(value socket, value vprotos)
{
CAMLparam2(socket, vprotos);
caml_raise_constant(*caml_named_value("ssl_exn_method_error"));
CAMLreturn(Val_unit);
}

CAMLprim value ocaml_ssl_get_negotiated_alpn_protocol(value socket)
{
CAMLparam1(socket);
caml_raise_constant(*caml_named_value("ssl_exn_method_error"));
CAMLreturn(Val_unit);
}
#endif

CAMLprim value ocaml_ssl_connect(value socket)
{
CAMLparam1(socket);
Expand Down Expand Up @@ -1342,7 +1539,7 @@ static DH *load_dh_param(const char *dhfile)
{
DH *ret=NULL;
BIO *bio;

if ((bio=BIO_new_file(dhfile,"r")) == NULL)
goto err;
ret=PEM_read_bio_DHparams(bio,NULL,NULL,NULL);
Expand Down