Skip to content

Commit

Permalink
Add timeout to all TLS implicit read/write, add explicit termination …
Browse files Browse the repository at this point in the history
…of wait operations on shutdown. (#3681)
  • Loading branch information
toots committed Jan 30, 2024
1 parent 0bb3100 commit 352496b
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Fixed:
- Fixed memory leaks when using dynamically created sources (`input.harbor`, `input.ffmepg`, SRT sources and `request.dynamic`)
- Fixed invalid array fill in `add` (#3678)
- Fixed empty metadata when switching to new source (#3373)
- Fixed deadlock when connecting to a non-SSL icecast using the TLS transport (#3681)

---

Expand Down
82 changes: 47 additions & 35 deletions src/core/builtins/builtins_tls.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,35 @@ module Liq_tls = struct

let () = Mirage_crypto_rng_unix.initialize (module Mirage_crypto_rng.Fortuna)
let buf_len = 4096
let write_all fd data = Utils.write_all fd (Cstruct.to_bytes data)

let read h len =
let write_all ~timeout fd data =
Tutils.write_all ~timeout fd (Cstruct.to_bytes data)

let read ~timeout h len =
Tutils.wait_for (`Read h.fd) timeout;
let n = Unix.read h.fd h.buf 0 (min len buf_len) in
Cstruct.of_bytes ~len:n h.buf

let read_pending h = function
| None -> ()
| Some data -> Buffer.add_string h.read_pending (Cstruct.to_string data)

let write_response h = function
let write_response ~timeout h = function
| None -> ()
| Some data -> write_all h.fd data
| Some data -> write_all ~timeout h.fd data

let handshake h =
let handshake ~timeout h =
let rec f () =
if Tls.Engine.handshake_in_progress h.state then (
match Tls.Engine.handle_tls h.state (read h buf_len) with
match Tls.Engine.handle_tls h.state (read ~timeout h buf_len) with
| Ok (`Eof, _, _) ->
Runtime_error.raise ~pos:[]
~message:"Connection closed while negotiating TLS handshake!"
"tls"
| Ok ((`Ok _ as step), `Response response, `Data data)
| Ok ((`Alert _ as step), `Response response, `Data data) ->
read_pending h data;
write_response h response;
write_response ~timeout h response;
(match step with
| `Ok state -> h.state <- state
| `Alert alert ->
Expand All @@ -66,7 +69,7 @@ module Liq_tls = struct
"tls");
f ()
| Error (error, `Response response) ->
write_all h.fd response;
write_all ~timeout h.fd response;
Runtime_error.raise ~pos:[]
~message:
(Printf.sprintf "TLS handshake error: %s"
Expand All @@ -76,33 +79,40 @@ module Liq_tls = struct
in
f ()

let init_base ~state fd =
let init_base ~timeout ~state fd =
let buf = Bytes.create buf_len in
let read_pending = Buffer.create 4096 in
let h = { read_pending; fd; buf; state } in
handshake h;
handshake ~timeout h;
h

let init_server ~server fd =
let init_server ~timeout ~server fd =
let state = Tls.Engine.server server in
init_base ~state fd
init_base ~timeout ~state fd

let init_client ~client fd =
let init_client ~timeout ~client fd =
let state, hello = Tls.Engine.client client in
write_all fd hello;
init_base ~state fd
write_all ~timeout fd hello;
init_base ~timeout ~state fd

let write h b off len =
let write ?timeout h b off len =
let timeout = Option.value ~default:Harbor_base.conf_timeout#get timeout in
match
Tls.Engine.send_application_data h.state [Cstruct.of_bytes ~off ~len b]
with
| None -> len
| Some (state, data) ->
write_all h.fd data;
write_all ~timeout h.fd data;
h.state <- state;
len

let read h b off len =
let read ?read_timeout ?write_timeout h b off len =
let read_timeout =
Option.value ~default:Harbor_base.conf_timeout#get read_timeout
in
let write_timeout =
Option.value ~default:Harbor_base.conf_timeout#get write_timeout
in
let pending = Buffer.length h.read_pending in
if 0 < pending then (
let n = min pending len in
Expand All @@ -111,17 +121,23 @@ module Liq_tls = struct
n)
else (
let rec f () =
match Tls.Engine.handle_tls h.state (read h len) with
match
Tls.Engine.handle_tls h.state (read ~timeout:read_timeout h len)
with
| Ok (`Eof, _, _) -> 0
| Ok (`Alert alert, `Response response, _) ->
ignore (Option.map (write_all h.fd) response);
(match response with
| None -> ()
| Some r -> write_all ~timeout:write_timeout h.fd r);
Runtime_error.raise ~pos:[]
~message:
(Printf.sprintf "TLS read error: %s"
(Tls.Packet.alert_type_to_string alert))
"tls"
| Ok (`Ok state, `Response response, `Data data) -> (
ignore (Option.map (write_all h.fd) response);
(match response with
| None -> ()
| Some r -> write_all ~timeout:write_timeout h.fd r);
h.state <- state;
match data with
| None -> f ()
Expand All @@ -134,7 +150,7 @@ module Liq_tls = struct
(Cstruct.to_string data ~off:n ~len:(data_len - n));
n)
| Error (error, `Response response) ->
write_all h.fd response;
write_all ~timeout:write_timeout h.fd response;
Runtime_error.raise ~pos:[]
~message:
(Printf.sprintf "TLS read error: %s"
Expand All @@ -144,9 +160,9 @@ module Liq_tls = struct
in
f ())

let close h =
let close ?(timeout = 1.) h =
let state, data = Tls.Engine.send_close_notify h.state in
write_all h.fd data;
write_all ~timeout h.fd data;
h.state <- state;
Unix.close h.fd
end
Expand Down Expand Up @@ -191,15 +207,11 @@ let server ~read_timeout ~write_timeout ~certificate ~key transport =
object
method transport = transport

method accept ?timeout sock =
let fd, caller = Http.accept ?timeout sock in
method accept ?(timeout = 1.) sock =
let fd, caller = Http.accept ~timeout sock in
try
(match timeout with
| Some timeout ->
Http.set_socket_default ~read_timeout:timeout
~write_timeout:timeout fd
| None -> ());
let session = Liq_tls.init_server ~server fd in
Http.set_socket_default ~read_timeout:timeout ~write_timeout:timeout fd;
let session = Liq_tls.init_server ~timeout ~server fd in
Http.set_socket_default ~read_timeout ~write_timeout fd;
(tls_socket ~session transport, caller)
with exn ->
Expand All @@ -214,7 +226,7 @@ let transport ~read_timeout ~write_timeout ~certificate ~key () =
method protocol = "https"
method default_port = 443

method connect ?bind_address ?timeout ?prefer host port =
method connect ?bind_address ?(timeout = 1.) ?prefer host port =
let domain = Domain_name.host_exn (Domain_name.of_string_exn host) in
let authenticator = Result.get_ok (Ca_certs.authenticator ()) in
let certificate_authenticator =
Expand All @@ -238,8 +250,8 @@ let transport ~read_timeout ~write_timeout ~certificate ~key () =
if Result.is_ok r then r else authenticator ?ip ~host certs
in
let client = Tls.Config.client ~authenticator ~peer_name:domain () in
let fd = Http.connect ?bind_address ?timeout ?prefer host port in
let session = Liq_tls.init_client ~client fd in
let fd = Http.connect ?bind_address ~timeout ?prefer host port in
let session = Liq_tls.init_client ~timeout ~client fd in
tls_socket ~session self

method server = server ~read_timeout ~write_timeout ~certificate ~key self
Expand Down
60 changes: 39 additions & 21 deletions src/core/tools/tutils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -330,28 +330,33 @@ type event =
(* Wait for [`Read socket], [`Write socket] or [`Both socket] for at most
* [timeout] seconds on the given [socket]. Raises [Timeout elapsed_time]
* if timeout is reached. *)
let wait_for ?(log = fun _ -> ()) event timeout =
let start_time = Unix.gettimeofday () in
let max_time = start_time +. timeout in
let r, w =
match event with
| `Read socket -> ([socket], [])
| `Write socket -> ([], [socket])
| `Both socket -> ([socket], [socket])
in
let rec wait t =
let r, w, _ =
try Utils.select r w [] t
with Unix.Unix_error (Unix.EINTR, _, _) -> ([], [], [])
let wait_for =
let end_r, end_w = Unix.pipe ~cloexec:true () in
Lifecycle.before_core_shutdown ~name:"wait_for shutdown" (fun () ->
try ignore (Unix.write end_w (Bytes.create 1) 0 1) with _ -> ());
fun ?(log = fun _ -> ()) event timeout ->
let start_time = Unix.gettimeofday () in
let max_time = start_time +. timeout in
let r, w =
match event with
| `Read socket -> ([socket], [])
| `Write socket -> ([], [socket])
| `Both socket -> ([socket], [socket])
in
if r = [] && w = [] then (
let current_time = Unix.gettimeofday () in
if current_time >= max_time then (
log "Timeout reached!";
raise (Timeout (current_time -. start_time)))
else wait (min 1. (max_time -. current_time)))
in
wait (min 1. timeout)
let rec wait t =
let r, w, _ =
try Utils.select (end_r :: r) w [] t
with Unix.Unix_error (Unix.EINTR, _, _) -> ([], [], [])
in
if List.mem end_r r then raise Exit;
if r = [] && w = [] then (
let current_time = Unix.gettimeofday () in
if current_time >= max_time then (
log "Timeout reached!";
raise (Timeout (current_time -. start_time)))
else wait (min 1. (max_time -. current_time)))
in
wait (min 1. timeout)
let main () =
if Atomic.compare_and_set state `Starting `Running then wait_done ();
Expand Down Expand Up @@ -391,3 +396,16 @@ let lazy_cell f =
let v = f () in
c := Some v;
v)
let write_all ?timeout fd b =
let rec f ofs len =
(match timeout with
| None -> ()
| Some timeout -> wait_for (`Write fd) timeout);
match Unix.write fd b ofs len with
| 0 -> raise End_of_file
| n when n = len -> ()
| n -> f (ofs + n) (len - n)
in
let len = Bytes.length b in
if len > 0 then f 0 len
2 changes: 2 additions & 0 deletions src/core/tools/tutils.mli
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,5 @@ val seems_locked : Mutex.t -> bool

(** Thread-safe equivalent to Lazy.from_fun. *)
val lazy_cell : (unit -> 'a) -> unit -> 'a

val write_all : ?timeout:float -> Unix.file_descr -> bytes -> unit
10 changes: 0 additions & 10 deletions src/core/tools/utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -502,16 +502,6 @@ let concat_with_last ~last sep l =
| x :: l ->
Printf.sprintf "%s %s %s" (String.concat sep (List.rev l)) last x

let write_all fd b =
let rec f ofs len =
match Unix.write fd b ofs len with
| 0 -> raise End_of_file
| n when n = len -> ()
| n -> f (ofs + n) (len - n)
in
let len = Bytes.length b in
if len > 0 then f 0 len

(* Stdlib.abs_float is not inlined!. *)
let abs_float (f : float) = if f < 0. then -.f else f [@@inline always]

Expand Down

0 comments on commit 352496b

Please sign in to comment.