diff --git a/src/stack-direct/tcpip_stack_direct.ml b/src/stack-direct/tcpip_stack_direct.ml index ce3b13d34..afc7546ae 100644 --- a/src/stack-direct/tcpip_stack_direct.ml +++ b/src/stack-direct/tcpip_stack_direct.ml @@ -256,3 +256,226 @@ module MakeV6 end +type direct_ipv4v6_input = src:Ipaddr.t -> dst:Ipaddr.t -> Cstruct.t -> unit Lwt.t + +module type UDPV4V6_DIRECT = Mirage_protocols.UDP + with type ipaddr = Ipaddr.t + and type ipinput = direct_ipv4v6_input + +module type TCPV4V6_DIRECT = Mirage_protocols.TCP + with type ipaddr = Ipaddr.t + and type ipinput = direct_ipv4v6_input + +module IPV4V6 (Ipv4 : Mirage_protocols.IPV4) (Ipv6 : Mirage_protocols.IPV6) = struct + + type ipaddr = Ipaddr.t + type callback = src:ipaddr -> dst:ipaddr -> Cstruct.t -> unit Lwt.t + + let pp_ipaddr = Ipaddr.pp + + type error = [ Mirage_protocols.Ip.error | `Ipv4 of Ipv4.error | `Ipv6 of Ipv6.error | `Msg of string ] + + let pp_error ppf = function + | #Mirage_protocols.Ip.error as e -> Mirage_protocols.Ip.pp_error ppf e + | `Ipv4 e -> Ipv4.pp_error ppf e + | `Ipv6 e -> Ipv6.pp_error ppf e + | `Msg m -> Fmt.string ppf m + + type t = { ipv4 : Ipv4.t ; ipv6 : Ipv6.t } + + let connect ipv4 ipv6 = Lwt.return { ipv4 ; ipv6 } + + let disconnect _ = Lwt.return_unit + + let input t ~tcp ~udp ~default = + let tcp4 ~src ~dst payload = tcp ~src:(Ipaddr.V4 src) ~dst:(Ipaddr.V4 dst) payload + and tcp6 ~src ~dst payload = tcp ~src:(Ipaddr.V6 src) ~dst:(Ipaddr.V6 dst) payload + and udp4 ~src ~dst payload = udp ~src:(Ipaddr.V4 src) ~dst:(Ipaddr.V4 dst) payload + and udp6 ~src ~dst payload = udp ~src:(Ipaddr.V6 src) ~dst:(Ipaddr.V6 dst) payload + and default4 ~proto ~src ~dst payload = default ~proto ~src:(Ipaddr.V4 src) ~dst:(Ipaddr.V4 dst) payload + and default6 ~proto ~src ~dst payload = default ~proto ~src:(Ipaddr.V6 src) ~dst:(Ipaddr.V6 dst) payload + in + fun buf -> + if Cstruct.len buf >= 1 then + let v = Cstruct.get_uint8 buf 0 lsr 4 in + if v = 4 then + Ipv4.input t.ipv4 ~tcp:tcp4 ~udp:udp4 ~default:default4 buf + else if v = 6 then + Ipv6.input t.ipv6 ~tcp:tcp6 ~udp:udp6 ~default:default6 buf + else + Lwt.return_unit + else + Lwt.return_unit + + let write t ?fragment ?ttl ?src dst proto ?size headerf bufs = + match dst with + | Ipaddr.V4 dst -> + begin + match + match src with + | None -> Ok None + | Some (Ipaddr.V4 src) -> Ok (Some src) + | _ -> Error (`Msg "source must be V4 if dst is V4") + with + | Error e -> Lwt.return (Error e) + | Ok src -> + Ipv4.write t.ipv4 ?fragment ?ttl ?src dst proto ?size headerf bufs >|= function + | Ok () -> Ok () + | Error e -> Error (`Ipv4 e) + end + | Ipaddr.V6 dst -> + begin + match + match src with + | None -> Ok None + | Some (Ipaddr.V6 src) -> Ok (Some src) + | _ -> Error (`Msg "source must be V6 if dst is V6") + with + | Error e -> Lwt.return (Error e) + | Ok src -> + Ipv6.write t.ipv6 ?fragment ?ttl ?src dst proto ?size headerf bufs >|= function + | Ok () -> Ok () + | Error e -> Error (`Ipv6 e) + end + + let pseudoheader t ?src dst proto len = + match dst with + | Ipaddr.V4 dst -> + let src = + match src with + | None -> None + | Some (Ipaddr.V4 src) -> Some src + | _ -> None (* TODO *) + in + Ipv4.pseudoheader t.ipv4 ?src dst proto len + | Ipaddr.V6 dst -> + let src = + match src with + | None -> None + | Some (Ipaddr.V6 src) -> Some src + | _ -> None (* TODO *) + in + Ipv6.pseudoheader t.ipv6 ?src dst proto len + + let src t ~dst = + match dst with + | Ipaddr.V4 dst -> Ipaddr.V4 (Ipv4.src t.ipv4 ~dst) + | Ipaddr.V6 dst -> Ipaddr.V6 (Ipv6.src t.ipv6 ~dst) + + let get_ip t = + List.map (fun ip -> Ipaddr.V4 ip) (Ipv4.get_ip t.ipv4) @ + List.map (fun ip -> Ipaddr.V6 ip) (Ipv6.get_ip t.ipv6) + + let mtu t = + (* TODO incorrect for IPv4 *) + Ipv6.mtu t.ipv6 +end + +module MakeV4V6 + (Time : Mirage_time.S) + (Random : Mirage_random.S) + (Netif : Mirage_net.S) + (Ethernet : Mirage_protocols.ETHERNET) + (Arpv4 : Mirage_protocols.ARP) + (Ip : Mirage_protocols.IP with type ipaddr = Ipaddr.t) + (Icmpv4 : Mirage_protocols.ICMP with type ipaddr = Ipaddr.V4.t) + (Udp : UDPV4V6_DIRECT) + (Tcp : TCPV4V6_DIRECT) = struct + + module UDP = Udp + module TCP = Tcp + module IP = Ip + + type t = { + netif : Netif.t; + ethif : Ethernet.t; + arpv4 : Arpv4.t; + icmpv4 : Icmpv4.t; + ip : IP.t; + udp : Udp.t; + tcp : Tcp.t; + udp_listeners: (int, Udp.callback) Hashtbl.t; + tcp_listeners: (int, Tcp.listener) Hashtbl.t; + mutable task : unit Lwt.t option; + } + + let pp fmt t = + Format.fprintf fmt "mac=%a,ip=%a" Macaddr.pp (Ethernet.mac t.ethif) + (Fmt.list Ipaddr.pp) (IP.get_ip t.ip) + + let tcp { tcp; _ } = tcp + let udp { udp; _ } = udp + let ip { ip; _ } = ip + + let err_invalid_port p = Printf.sprintf "invalid port number (%d)" p + + let listen_udp t ~port callback = + if port < 0 || port > 65535 + then raise (Invalid_argument (err_invalid_port port)) + else Hashtbl.replace t.udp_listeners port callback + + let listen_tcp ?keepalive t ~port process = + if port < 0 || port > 65535 + then raise (Invalid_argument (err_invalid_port port)) + else Hashtbl.replace t.tcp_listeners port { Tcp.process; keepalive } + + let udp_listeners t ~dst_port = + try Some (Hashtbl.find t.udp_listeners dst_port) + with Not_found -> None + + let tcp_listeners t dst_port = + try Some (Hashtbl.find t.tcp_listeners dst_port) + with Not_found -> None + + let listen t = + Lwt.catch (fun () -> + Log.debug (fun f -> f "Establishing or updating listener for stack %a" pp t); + let tcp = Tcp.input t.tcp ~listeners:(tcp_listeners t) + and udp = Udp.input t.udp ~listeners:(udp_listeners t) + and default ~proto ~src ~dst buf = + match proto, src, dst with + | 1, Ipaddr.V4 src, Ipaddr.V4 dst -> Icmpv4.input t.icmpv4 ~src ~dst buf + | _ -> Lwt.return_unit + in + let ethif_listener = Ethernet.input + ~arpv4:(Arpv4.input t.arpv4) + ~ipv4:(IP.input ~tcp ~udp ~default t.ip) + ~ipv6:(IP.input ~tcp ~udp ~default t.ip) + t.ethif + in + Netif.listen t.netif ~header_size:Ethernet_wire.sizeof_ethernet ethif_listener + >>= function + | Error e -> + Log.warn (fun p -> p "%a" Netif.pp_error e) ; + (* XXX: error should be passed to the caller *) + Lwt.return_unit + | Ok _res -> + let nstat = Netif.get_stats_counters t.netif in + let open Mirage_net in + Log.info (fun f -> + f "listening loop of interface %s terminated regularly:@ %Lu bytes \ + (%lu packets) received, %Lu bytes (%lu packets) sent@ " + (Macaddr.to_string (Netif.mac t.netif)) + nstat.rx_bytes nstat.rx_pkts + nstat.tx_bytes nstat.tx_pkts) ; + Lwt.return_unit) + (function + | Lwt.Canceled -> + Log.info (fun f -> f "listen of %a cancelled" pp t); + Lwt.return_unit + | e -> Lwt.fail e) + + let connect netif ethif arpv4 ip icmpv4 udp tcp = + let udp_listeners = Hashtbl.create 7 in + let tcp_listeners = Hashtbl.create 7 in + let t = { netif; ethif; arpv4; ip; icmpv4; tcp; udp; + udp_listeners; tcp_listeners; task = None } in + Log.info (fun f -> f "stack assembled: %a" pp t); + Lwt.async (fun () -> let task = listen t in t.task <- Some task; task); + Lwt.return t + + let disconnect t = + Log.info (fun f -> f "disconnect called: %a" pp t); + (match t.task with None -> () | Some task -> Lwt.cancel task); + Lwt.return_unit +end diff --git a/src/stack-direct/tcpip_stack_direct.mli b/src/stack-direct/tcpip_stack_direct.mli index 3d3698292..e7a5a7dc1 100644 --- a/src/stack-direct/tcpip_stack_direct.mli +++ b/src/stack-direct/tcpip_stack_direct.mli @@ -74,7 +74,46 @@ module MakeV6 val connect : Netif.t -> Ethernet.t -> Ipv6.t -> Udpv6.t -> Tcpv6.t -> t Lwt.t (** [connect] assembles the arguments into a network stack, then calls `listen` on the assembled stack before returning it to the caller. The - initial `listen` functions to ensure that the lower-level layers are + initial `listen` functions to ensure that the lower-level layers are + functioning, so that if the user wishes to establish outbound connections, + they will be able to do so. *) +end + +type direct_ipv4v6_input = src:Ipaddr.t -> dst:Ipaddr.t -> Cstruct.t -> unit Lwt.t + +module type UDPV4V6_DIRECT = Mirage_protocols.UDP + with type ipaddr = Ipaddr.t + and type ipinput = direct_ipv4v6_input + +module type TCPV4V6_DIRECT = Mirage_protocols.TCP + with type ipaddr = Ipaddr.t + and type ipinput = direct_ipv4v6_input + +module IPV4V6 (Ipv4 : Mirage_protocols.IPV4) (Ipv6 : Mirage_protocols.IPV6) : sig + include Mirage_protocols.IP with type ipaddr = Ipaddr.t + + val connect : Ipv4.t -> Ipv6.t -> t Lwt.t +end + +module MakeV4V6 + (Time : Mirage_time.S) + (Random : Mirage_random.S) + (Netif : Mirage_net.S) + (Ethernet : Mirage_protocols.ETHERNET) + (Arpv4 : Mirage_protocols.ARP) + (Ip : Mirage_protocols.IP with type ipaddr = Ipaddr.t) + (Icmpv4 : Mirage_protocols.ICMP with type ipaddr = Ipaddr.V4.t) + (Udp : UDPV4V6_DIRECT) + (Tcp : TCPV4V6_DIRECT) : sig + include Mirage_stack.V4V6 + with module IP = Ip + and module TCP = Tcp + and module UDP = Udp + + val connect : Netif.t -> Ethernet.t -> Arpv4.t -> Ip.t -> Icmpv4.t -> Udp.t -> Tcp.t -> t Lwt.t + (** [connect] assembles the arguments into a network stack, then calls + `listen` on the assembled stack before returning it to the caller. The + initial `listen` functions to ensure that the lower-level layers are functioning, so that if the user wishes to establish outbound connections, they will be able to do so. *) end