diff --git a/Cargo.lock b/Cargo.lock index 78e496e4a0..b56a707afd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -74,12 +74,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c5d78ce20460b82d3fa150275ed9d55e21064fc7951177baacf86a145c4a4b1f" -[[package]] -name = "array-init" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6945cc5422176fc5e602e590c2878d2c2acd9a4fe20a4baa7c28022521698ec6" - [[package]] name = "ascii-canvas" version = "3.0.0" @@ -1706,21 +1700,16 @@ version = "0.1.0" dependencies = [ "futures", "gateway-messages", - "http", - "hyper", "omicron-common", "omicron-test-utils", "once_cell", - "ringbuffer", "serde", "serde_with", "slog", "thiserror", "tokio", "tokio-stream", - "tokio-tungstenite", "usdt", - "uuid", ] [[package]] @@ -2905,7 +2894,6 @@ dependencies = [ "omicron-test-utils", "openapi-lint", "openapiv3", - "ringbuffer", "schemars", "serde", "serde_json", @@ -4329,15 +4317,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "ringbuffer" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b30a00730a27595dcf899dce512aa031dd650f86aafcb132fd8dd9f409b369d0" -dependencies = [ - "array-init", -] - [[package]] name = "riscv" version = "0.7.0" diff --git a/gateway-cli/src/main.rs b/gateway-cli/src/main.rs index 3337e890d8..d877eb9470 100644 --- a/gateway-cli/src/main.rs +++ b/gateway-cli/src/main.rs @@ -18,6 +18,7 @@ use slog::o; use slog::Drain; use slog::Level; use slog::Logger; +use std::borrow::Cow; use std::net::IpAddr; use std::net::ToSocketAddrs; use std::time::Duration; @@ -25,6 +26,8 @@ use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; use tokio::net::TcpStream; use tokio::select; +use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; +use tokio_tungstenite::tungstenite::protocol::CloseFrame; use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; @@ -281,6 +284,10 @@ async fn main() -> Result<()> { match c { None => { // channel is closed + _ = ws.close(Some(CloseFrame { + code: CloseCode::Normal, + reason: Cow::Borrowed("client closed stdin"), + })).await; break; } Some(c) => { diff --git a/gateway-sp-comms/Cargo.toml b/gateway-sp-comms/Cargo.toml index 30d6e2f14d..abce837549 100644 --- a/gateway-sp-comms/Cargo.toml +++ b/gateway-sp-comms/Cargo.toml @@ -6,16 +6,11 @@ license = "MPL-2.0" [dependencies] futures = "0.3.21" -http = "0.2.7" -hyper = "0.14.20" -ringbuffer = "0.8" serde = { version = "1.0", features = ["derive"] } serde_with = "2.0.0" thiserror = "1.0.32" -tokio-tungstenite = "0.17" tokio-stream = "0.1.8" usdt = "0.3.1" -uuid = "1.1.0" gateway-messages = { path = "../gateway-messages", features = ["std"] } omicron-common = { path = "../common" } diff --git a/gateway-sp-comms/src/communicator.rs b/gateway-sp-comms/src/communicator.rs index b7e6ad74d3..20a92a9155 100644 --- a/gateway-sp-comms/src/communicator.rs +++ b/gateway-sp-comms/src/communicator.rs @@ -6,11 +6,10 @@ use crate::error::BadResponseType; use crate::error::Error; -use crate::error::SpCommunicationError; use crate::error::StartupError; use crate::management_switch::ManagementSwitch; -use crate::management_switch::SpSocket; use crate::management_switch::SwitchPort; +use crate::single_sp::AttachedSerialConsole; use crate::Elapsed; use crate::SpIdentifier; use crate::SwitchConfig; @@ -22,21 +21,12 @@ use gateway_messages::BulkIgnitionState; use gateway_messages::DiscoverResponse; use gateway_messages::IgnitionCommand; use gateway_messages::IgnitionState; -use gateway_messages::RequestKind; use gateway_messages::ResponseKind; -use gateway_messages::SerialConsole; -use gateway_messages::SpComponent; use gateway_messages::SpState; -use hyper::header; -use hyper::upgrade; -use hyper::Body; use slog::info; use slog::o; use slog::Logger; -use std::sync::Arc; -use std::time::Duration; use tokio::time::Instant; -use tokio_tungstenite::tungstenite::handshake; /// Helper trait that allows us to return an `impl FuturesUnordered<_>` where /// the caller can call `.is_empty()` without knowing the type of the future @@ -67,8 +57,7 @@ impl Communicator { ) -> Result { let log = log.new(o!("component" => "SpCommunicator")); let switch = - ManagementSwitch::new(config, discovery_deadline, log.clone()) - .await?; + ManagementSwitch::new(config, discovery_deadline, &log).await?; info!(&log, "started SP communicator"); Ok(Self { switch }) @@ -112,51 +101,31 @@ impl Communicator { /// known to this communicator. pub fn address_known(&self, sp: SpIdentifier) -> bool { let port = self.switch.switch_port(sp).unwrap(); - self.switch.sp_socket(port).is_some() + self.switch.sp(port).is_some() } /// Ask the local ignition controller for the ignition state of a given SP. pub async fn get_ignition_state( &self, sp: SpIdentifier, - timeout: Timeout, ) -> Result { let controller = self .switch .ignition_controller() .ok_or(Error::LocalIgnitionControllerAddressUnknown)?; let port = self.id_to_port(sp)?; - let request = - RequestKind::IgnitionState { target: port.as_ignition_target() }; - - self.request_response( - &controller, - request, - ResponseKindExt::expect_ignition_state, - Some(timeout), - ) - .await + Ok(controller.ignition_state(port.as_ignition_target()).await?) } /// Ask the local ignition controller for the ignition state of all SPs. pub async fn get_ignition_state_all( &self, - timeout: Timeout, ) -> Result, Error> { let controller = self .switch .ignition_controller() .ok_or(Error::LocalIgnitionControllerAddressUnknown)?; - let request = RequestKind::BulkIgnitionState; - - let bulk_state = self - .request_response( - &controller, - request, - ResponseKindExt::expect_bulk_ignition_state, - Some(timeout), - ) - .await?; + let bulk_state = controller.bulk_ignition_state().await?; // deserializing checks that `num_targets` is reasonably sized, so we // don't need to guard that here @@ -172,7 +141,7 @@ impl Communicator { let port = self .switch .switch_port_from_ignition_target(target) - .ok_or(SpCommunicationError::BadIgnitionTarget(target))?; + .ok_or(Error::BadIgnitionTarget(target))?; let id = self.port_to_id(port); Ok((id, state)) }) @@ -185,31 +154,17 @@ impl Communicator { &self, target_sp: SpIdentifier, command: IgnitionCommand, - timeout: Timeout, ) -> Result<(), Error> { let controller = self .switch .ignition_controller() .ok_or(Error::LocalIgnitionControllerAddressUnknown)?; let target = self.id_to_port(target_sp)?.as_ignition_target(); - let request = RequestKind::IgnitionCommand { target, command }; - - self.request_response( - &controller, - request, - ResponseKindExt::expect_ignition_command_ack, - Some(timeout), - ) - .await + Ok(controller.ignition_command(target, command).await?) } - /// Set up a websocket connection that forwards data to and from the given - /// SP component's serial console. - // TODO: Much of the implementation of this function is shamelessly copied - // from propolis. Should dropshot provide some of this? Is there another - // common place it could live? - // - // NOTE / TODO: This currently does not actually contact the target SP; it + /// Attach to the serial console of `sp`. + // TODO-cleanup: This currently does not actually contact the target SP; it // sets up the websocket connection in the current process which knows how // to relay any information sent or received on that connection to the SP // via UDP. SPs will continuously broadcast any serial console data, even if @@ -221,159 +176,31 @@ impl Communicator { // connection will start working if we later discover the address, but this // is probably not the behavior we want. pub async fn serial_console_attach( - self: &Arc, - request: &mut http::Request, + &self, sp: SpIdentifier, - component: SpComponent, - sp_ack_timeout: Duration, - ) -> Result, Error> { + ) -> Result { let port = self.id_to_port(sp)?; - - if !request - .headers() - .get(header::CONNECTION) - .and_then(|hv| hv.to_str().ok()) - .map(|hv| { - hv.split(|c| c == ',' || c == ' ') - .any(|vs| vs.eq_ignore_ascii_case("upgrade")) - }) - .unwrap_or(false) - { - return Err(Error::BadWebsocketConnection( - "expected connection upgrade", - )); - } - if !request - .headers() - .get(header::UPGRADE) - .and_then(|v| v.to_str().ok()) - .map(|v| { - v.split(|c| c == ',' || c == ' ') - .any(|v| v.eq_ignore_ascii_case("websocket")) - }) - .unwrap_or(false) - { - return Err(Error::BadWebsocketConnection( - "unexpected protocol for upgrade", - )); - } - if request - .headers() - .get(header::SEC_WEBSOCKET_VERSION) - .map(|v| v.as_bytes()) - != Some(b"13") - { - return Err(Error::BadWebsocketConnection( - "missing or invalid websocket version", - )); - } - let accept_key = request - .headers() - .get(header::SEC_WEBSOCKET_KEY) - .map(|hv| hv.as_bytes()) - .map(|key| handshake::derive_accept_key(key)) - .ok_or(Error::BadWebsocketConnection("missing websocket key"))?; - - self.switch.serial_console_attach( - Arc::clone(self), - port, - component, - sp_ack_timeout, - upgrade::on(request), - )?; - - // `.body()` only fails if our headers are bad, which they aren't - // (unless `hyper::handshake` gives us a bogus accept key?), so we're - // safe to unwrap this - Ok(http::Response::builder() - .status(http::StatusCode::SWITCHING_PROTOCOLS) - .header(header::CONNECTION, "Upgrade") - .header(header::UPGRADE, "websocket") - .header(header::SEC_WEBSOCKET_ACCEPT, accept_key) - .body(Body::empty()) - .unwrap()) + let sp = self.switch.sp(port).ok_or(Error::SpAddressUnknown(sp))?; + Ok(sp.serial_console_attach().await?) } /// Detach any existing connection to the given SP component's serial /// console. - /// - /// If there is an existing websocket connection to this SP component, it - /// will be closed. If there isn't, this method does nothing. pub async fn serial_console_detach( &self, sp: SpIdentifier, - component: &SpComponent, ) -> Result<(), Error> { let port = self.id_to_port(sp)?; - self.switch.serial_console_detach(port, component) - } - - /// Send `packet` to the given SP component's serial console. - pub(crate) async fn serial_console_send_packet( - &self, - port: SwitchPort, - packet: SerialConsole, - timeout: Timeout, - ) -> Result<(), Error> { - // We can only send to an SP's serial console if we've attached to it, - // which means we know its address. - // - // TODO how do we handle SP "disconnects"? If `self.switch` keeps the - // old addr around and we send data into the ether until a reconnection - // is established this is fine, but if it detects them and clears out - // addresses this could panic and needs better handling. - let sp = - self.switch.sp_socket(port).expect("lost address of attached SP"); - - self.request_response( - &sp, - RequestKind::SerialConsoleWrite(packet), - ResponseKindExt::expect_serial_console_write_ack, - Some(timeout), - ) - .await + let sp = self.switch.sp(port).ok_or(Error::SpAddressUnknown(sp))?; + sp.serial_console_detach().await; + Ok(()) } /// Get the state of a given SP. - pub async fn get_state( - &self, - sp: SpIdentifier, - timeout: Timeout, - ) -> Result { - self.get_state_maybe_timeout(sp, Some(timeout)).await - } - - /// Get the state of a given SP without a timeout; it is the caller's - /// responsibility to ensure a reasonable timeout is applied higher up in - /// the chain. - // TODO we could have one method that takes `Option` for a timeout, - // and/or apply that to _all_ the methods in this class. I don't want to - // make it easy to accidentally call a method without providing a timeout, - // though, so went with the current design. - pub async fn get_state_without_timeout( - &self, - sp: SpIdentifier, - ) -> Result { - self.get_state_maybe_timeout(sp, None).await - } - - async fn get_state_maybe_timeout( - &self, - sp: SpIdentifier, - timeout: Option, - ) -> Result { + pub async fn get_state(&self, sp: SpIdentifier) -> Result { let port = self.id_to_port(sp)?; - let sp = - self.switch.sp_socket(port).ok_or(Error::SpAddressUnknown(sp))?; - let request = RequestKind::SpState; - - self.request_response( - &sp, - request, - ResponseKindExt::expect_sp_state, - timeout, - ) - .await + let sp = self.switch.sp(port).ok_or(Error::SpAddressUnknown(sp))?; + Ok(sp.state().await?) } /// Query all online SPs. @@ -419,19 +246,6 @@ impl Communicator { }) .collect::>() } - - async fn request_response( - &self, - sp: &SpSocket<'_>, - kind: RequestKind, - map_response_kind: F, - timeout: Option, - ) -> Result - where - F: FnMut(ResponseKind) -> Result, - { - self.switch.request_response(sp, kind, map_response_kind, timeout).await - } } // When we send a request we expect a specific kind of response; the boilerplate diff --git a/gateway-sp-comms/src/error.rs b/gateway-sp-comms/src/error.rs index f3c2b3cc08..6ce8a8cd2e 100644 --- a/gateway-sp-comms/src/error.rs +++ b/gateway-sp-comms/src/error.rs @@ -7,14 +7,44 @@ use crate::SpIdentifier; use gateway_messages::ResponseError; use std::io; -use std::net::SocketAddr; +use std::net::SocketAddrV6; use std::time::Duration; use thiserror::Error; +#[derive(Debug, Error)] +pub enum SpCommunicationError { + #[error("failed to send UDP packet to {addr}: {err}")] + UdpSendTo { addr: SocketAddrV6, err: io::Error }, + #[error("failed to recv UDP packet: {0}")] + UdpRecv(io::Error), + #[error("failed to deserialize SP message from {peer}: {err}")] + Deserialize { peer: SocketAddrV6, err: gateway_messages::HubpackError }, + #[error("RPC call failed (gave up after {0} attempts)")] + ExhaustedNumAttempts(usize), + #[error(transparent)] + BadResponseType(#[from] BadResponseType), + #[error("Error response from SP: {0}")] + SpError(#[from] ResponseError), +} + +#[derive(Debug, Error)] +pub enum UpdateError { + #[error("update image is too large")] + ImageTooLarge, + #[error("error starting update: {0}")] + Start(SpCommunicationError), + #[error("error sending update chunk at offset {offset}: {err}")] + Chunk { offset: u32, err: SpCommunicationError }, +} + +#[derive(Debug, Error)] +#[error("serial console already attached")] +pub struct SerialConsoleAlreadyAttached; + #[derive(Debug, Error)] pub enum StartupError { #[error("error binding to UDP address {addr}: {err}")] - UdpBind { addr: SocketAddr, err: io::Error }, + UdpBind { addr: SocketAddrV6, err: io::Error }, #[error("invalid configuration file: {}", .reasons.join(", "))] InvalidConfig { reasons: Vec }, #[error("error communicating with SP: {0}")] @@ -39,24 +69,18 @@ pub enum Error { "timeout ({timeout:?}) elapsed communicating with {sp:?} on port {port}" )] Timeout { timeout: Duration, port: usize, sp: Option }, + #[error("bogus SP response: specified unknown ignition target {0}")] + BadIgnitionTarget(usize), #[error("error communicating with SP: {0}")] SpCommunicationFailed(#[from] SpCommunicationError), #[error("serial console is already attached")] SerialConsoleAttached, - #[error("websocket connection failure: {0}")] - BadWebsocketConnection(&'static str), } -#[derive(Debug, Error)] -pub enum SpCommunicationError { - #[error("failed to send UDP packet to {addr}: {err}")] - UdpSend { addr: SocketAddr, err: io::Error }, - #[error("error reported by SP: {0}")] - SpError(#[from] ResponseError), - #[error(transparent)] - BadResponseType(#[from] BadResponseType), - #[error("bogus SP response: specified unknown ignition target {0}")] - BadIgnitionTarget(usize), +impl From for Error { + fn from(_: SerialConsoleAlreadyAttached) -> Self { + Self::SerialConsoleAttached + } } #[derive(Debug, Error)] diff --git a/gateway-sp-comms/src/lib.rs b/gateway-sp-comms/src/lib.rs index 68042d5b47..0a8b31fb5f 100644 --- a/gateway-sp-comms/src/lib.rs +++ b/gateway-sp-comms/src/lib.rs @@ -14,8 +14,7 @@ mod communicator; mod management_switch; -mod recv_handler; -pub mod single_sp; +mod single_sp; mod timeout; pub use usdt::register_probes; @@ -30,5 +29,10 @@ pub use management_switch::SpIdentifier; pub use management_switch::SpType; pub use management_switch::SwitchConfig; pub use management_switch::SwitchPortConfig; +pub use single_sp::AttachedSerialConsole; +pub use single_sp::AttachedSerialConsoleRecv; +pub use single_sp::AttachedSerialConsoleSend; +pub use single_sp::SingleSp; +pub use single_sp::DISCOVERY_MULTICAST_ADDR; pub use timeout::Elapsed; pub use timeout::Timeout; diff --git a/gateway-sp-comms/src/management_switch.rs b/gateway-sp-comms/src/management_switch.rs index be2f571693..111cfdb8d7 100644 --- a/gateway-sp-comms/src/management_switch.rs +++ b/gateway-sp-comms/src/management_switch.rs @@ -18,47 +18,26 @@ pub use self::location_map::LocationDeterminationConfig; use self::location_map::LocationMap; pub use self::location_map::SwitchPortConfig; -use crate::error::BadResponseType; -use crate::error::Error; -use crate::error::SpCommunicationError; use crate::error::StartupError; -use crate::recv_handler::RecvHandler; -use crate::Communicator; -use crate::Elapsed; -use crate::Timeout; -use futures::stream::FuturesUnordered; -use futures::Future; -use futures::StreamExt; -use gateway_messages::version; -use gateway_messages::Request; -use gateway_messages::RequestKind; -use gateway_messages::ResponseError; -use gateway_messages::ResponseKind; -use gateway_messages::SerializedSize; -use gateway_messages::SpComponent; -use gateway_messages::SpMessage; -use hyper::upgrade::OnUpgrade; -use omicron_common::backoff; -use omicron_common::backoff::Backoff; +use crate::single_sp::SingleSp; use serde::Deserialize; use serde::Serialize; use serde_with::serde_as; use serde_with::DisplayFromStr; -use slog::debug; +use slog::o; use slog::Logger; use std::collections::HashMap; -use std::io; -use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::net::UdpSocket; -use tokio::task::JoinHandle; use tokio::time::Instant; #[serde_as] #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] pub struct SwitchConfig { pub local_ignition_controller_port: usize, + pub rpc_max_attempts: usize, + pub rpc_per_attempt_timeout_millis: u64, pub location: LocationConfig, #[serde_as(as = "HashMap")] pub port: HashMap, @@ -108,27 +87,15 @@ impl SwitchPort { #[derive(Debug)] pub(crate) struct ManagementSwitch { local_ignition_controller_port: SwitchPort, - recv_handler: Arc, - sockets: Arc>, + sockets: Arc>, location_map: LocationMap, - log: Logger, - - // handle to the running task that calls recv on all `switch_ports` sockets; - // we keep this handle only to kill it when we're dropped - recv_task: JoinHandle<()>, -} - -impl Drop for ManagementSwitch { - fn drop(&mut self) { - self.recv_task.abort(); - } } impl ManagementSwitch { pub(crate) async fn new( config: SwitchConfig, discovery_deadline: Instant, - log: Logger, + log: &Logger, ) -> Result { // begin by binding to all our configured ports; insert them into a map // keyed by the switch port they're listening on @@ -143,7 +110,18 @@ impl ManagementSwitch { .map_err(|err| StartupError::UdpBind { addr, err })?; let port = SwitchPort(port); - sockets.insert(port, socket); + sockets.insert( + port, + SingleSp::new( + socket, + port_config.multicast_addr, + config.rpc_max_attempts, + Duration::from_millis( + config.rpc_per_attempt_timeout_millis, + ), + log.new(o!("switch_port" => port.0)), + ), + ); ports.insert(port, port_config); } @@ -159,44 +137,19 @@ impl ManagementSwitch { }); } - // set up a handler for incoming packets - let recv_handler = - RecvHandler::new(sockets.keys().copied(), log.clone()); - - // spawn background task that listens for incoming packets on all ports - // and passes them to `recv_handler` - let sockets = Arc::new(sockets); - let recv_task = { - let recv_handler = Arc::clone(&recv_handler); - tokio::spawn(recv_task( - Arc::clone(&sockets), - move |port, addr, data| { - recv_handler.handle_incoming_packet(port, addr, data) - }, - log.clone(), - )) - }; - // run discovery to figure out the physical location of ourselves (and // therefore all SPs we talk to) + let sockets = Arc::new(sockets); let location_map = LocationMap::run_discovery( config.location, ports, Arc::clone(&sockets), - Arc::clone(&recv_handler), discovery_deadline, - &log, + log, ) .await?; - Ok(Self { - local_ignition_controller_port, - recv_handler, - location_map, - sockets, - log, - recv_task, - }) + Ok(Self { local_ignition_controller_port, location_map, sockets }) } /// Get the name of our location. @@ -209,21 +162,13 @@ impl ManagementSwitch { /// Get the socket to use to communicate with an SP and the socket address /// of that SP. - pub(crate) fn sp_socket(&self, port: SwitchPort) -> Option> { - self.recv_handler.remote_addr(port).map(|addr| { - let socket = self.sockets.get(&port).unwrap(); - SpSocket { - location_map: Some(&self.location_map), - socket, - addr, - port, - } - }) + pub(crate) fn sp(&self, port: SwitchPort) -> Option<&SingleSp> { + self.sockets.get(&port) } /// Get the socket connected to the local ignition controller. - pub(crate) fn ignition_controller(&self) -> Option> { - self.sp_socket(self.local_ignition_controller_port) + pub(crate) fn ignition_controller(&self) -> Option<&SingleSp> { + self.sp(self.local_ignition_controller_port) } pub(crate) fn switch_port_from_ignition_target( @@ -245,492 +190,4 @@ impl ManagementSwitch { pub(crate) fn switch_port_to_id(&self, port: SwitchPort) -> SpIdentifier { self.location_map.port_to_id(port) } - - /// Spawn a tokio task responsible for forwarding serial console data - /// between the SP component on `port` and the websocket connection provided - /// by `upgrade_fut`. - pub(crate) fn serial_console_attach( - &self, - communicator: Arc, - port: SwitchPort, - component: SpComponent, - sp_ack_timeout: Duration, - upgrade_fut: OnUpgrade, - ) -> Result<(), Error> { - self.recv_handler.serial_console_attach( - communicator, - port, - component, - sp_ack_timeout, - upgrade_fut, - ) - } - - /// Shut down the serial console task associated with the given port and - /// component, if one exists and is attached. - pub(crate) fn serial_console_detach( - &self, - port: SwitchPort, - component: &SpComponent, - ) -> Result<(), Error> { - self.recv_handler.serial_console_detach(port, component) - } - - pub(crate) async fn request_response( - &self, - sp: &SpSocket<'_>, - kind: RequestKind, - map_response_kind: F, - timeout: Option, - ) -> Result - where - F: FnMut(ResponseKind) -> Result, - { - sp.request_response( - &self.recv_handler, - kind, - map_response_kind, - timeout, - &self.log, - ) - .await - } -} - -/// Wrapper for a UDP socket on one of our switch ports that knows the address -/// of the SP connected to this port. -#[derive(Debug)] -pub(crate) struct SpSocket<'a> { - location_map: Option<&'a LocationMap>, - socket: &'a UdpSocket, - addr: SocketAddr, - port: SwitchPort, -} - -impl SpSocket<'_> { - // TODO The `timeout` we take here is the overall timeout for receiving a - // response. We only resend the request if the SP sends us a "busy" - // response; if the SP doesn't answer at all we never resend the request. - // Should we take a separate timeout for individual sends? E.g., with an - // overall timeout of 5 sec and a per-request timeout of 1 sec, we could - // treat "no response at 1 sec" the same as a "busy" and resend the request. - async fn request_response( - &self, - recv_handler: &RecvHandler, - mut kind: RequestKind, - mut map_response_kind: F, - timeout: Option, - log: &Logger, - ) -> Result - where - F: FnMut(ResponseKind) -> Result, - { - // helper to wrap a future in a timeout if we have one - async fn maybe_with_timeout( - timeout: Option, - fut: F, - ) -> Result - where - F: Future, - { - match timeout { - Some(t) => t.timeout_at(fut).await, - None => Ok(fut.await), - } - } - - // We'll use exponential backoff if and only if the SP responds with - // "busy"; any other error will cause the loop below to terminate. - let mut backoff = backoff::internal_service_policy(); - - loop { - // It would be nicer to use `backoff::retry()` instead of manually - // stepping the backoff policy, but the dance we do with `kind` to - // avoid cloning it is hard to map into `retry()` in a way that - // satisfies the borrow checker. ("The dance we do with `kind` to - // avoid cloning it" being that we move it into `request` below, and - // on a busy response from the SP we move it back out into the - // `kind` local var.) - let duration = backoff - .next_backoff() - .expect("internal backoff policy gave up"); - maybe_with_timeout(timeout, tokio::time::sleep(duration)) - .await - .map_err(|err| Error::Timeout { - timeout: err.duration(), - port: self.port.0, - sp: self.location_map.map(|lm| lm.port_to_id(self.port)), - })?; - - // update our recv_handler to expect a response for this request ID - let (request_id, response_fut) = - recv_handler.register_request_id(self.port); - - // Serialize and send our request. We know `buf` is large enough for - // any `Request`, so unwrapping here is fine. - let request = Request { version: version::V1, request_id, kind }; - let mut buf = [0; Request::MAX_SIZE]; - let n = gateway_messages::serialize(&mut buf, &request).unwrap(); - let serialized_request = &buf[..n]; - - // Actual communication, guarded by `timeout` if it's not `None`. - let result = maybe_with_timeout(timeout, async { - debug!( - log, "sending request"; - "request" => ?request, - "dest_addr" => %self.addr, - "port" => ?self.port, - ); - self.socket - .send_to(serialized_request, self.addr) - .await - .map_err(|err| SpCommunicationError::UdpSend { - addr: self.addr, - err, - })?; - - Ok::(response_fut.await?) - }) - .await - .map_err(|err| Error::Timeout { - timeout: err.duration(), - port: self.port.0, - sp: self.location_map.map(|lm| lm.port_to_id(self.port)), - })?; - - match result { - Ok(response_kind) => { - return map_response_kind(response_kind) - .map_err(SpCommunicationError::from) - .map_err(Error::from) - } - Err(SpCommunicationError::SpError(ResponseError::Busy)) => { - debug!( - log, - "SP busy; sleeping before retrying send"; - "dest_addr" => %self.addr, - "port" => ?self.port, - ); - - // move `kind` back into local var; required to satisfy - // borrow check of this loop - kind = request.kind; - } - Err(err) => return Err(err.into()), - } - } - } -} - -async fn recv_task( - ports: Arc>, - mut recv_handler: F, - log: Logger, -) where - F: FnMut(SwitchPort, SocketAddr, &[u8]), -{ - // helper function to tag a socket's `.readable()` future with an index; we - // need this to make rustc happy about the types we push into - // `recv_all_sockets` below - async fn readable_with_port( - port: SwitchPort, - sock: &UdpSocket, - ) -> (SwitchPort, &UdpSocket, io::Result<()>) { - let result = sock.readable().await; - (port, sock, result) - } - - // set up collection of futures tracking readability of all switch port - // sockets - let mut recv_all_sockets = FuturesUnordered::new(); - for (port, sock) in ports.iter() { - recv_all_sockets.push(readable_with_port(*port, &sock)); - } - - let mut buf = [0; SpMessage::MAX_SIZE]; - - loop { - // `recv_all_sockets.next()` will never return `None` because we - // immediately push a new future into it every time we pull one out - // (to reregister readable interest in the corresponding socket) - let (port, sock, result) = recv_all_sockets.next().await.unwrap(); - - // checking readability of the socket can't fail without violating some - // internal state in tokio in a presumably-strage way; at that point we - // don't know how to recover, so just panic and let something restart us - if let Err(err) = result { - panic!("error in socket readability: {} (port={:?})", err, port); - } - - match sock.try_recv_from(&mut buf) { - Ok((n, addr)) => { - let buf = &buf[..n]; - probes::recv_packet!(|| ( - &addr, - &port, - buf.as_ptr() as usize as u64, - buf.len() as u64 - )); - debug!( - log, "received {} bytes", n; - "port" => ?port, "addr" => %addr, - ); - recv_handler(port, addr, &buf[..n]); - } - // spurious wakeup; no need to log, just continue - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => (), - // other kinds of errors should be rare/impossible given our use of - // UDP and the way tokio is structured; we don't know how to recover - // from them, so just panic and let something restart us - Err(err) => { - panic!("error in recv_from: {} (port={:?})", err, port); - } - } - - // push a new future requesting readability interest, as noted above - recv_all_sockets.push(readable_with_port(port, sock)); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use futures::future; - use omicron_test_utils::dev::poll; - use omicron_test_utils::dev::poll::CondCheckError; - use std::collections::HashMap; - use std::convert::Infallible; - use std::mem; - use std::sync::Mutex; - use std::time::Duration; - - // collection of sockets that act like SPs for the purposes of these tests - struct Harness { - switches: Vec, - sleds: Vec, - pscs: Vec, - } - - impl Harness { - async fn new() -> Self { - const NUM_SWITCHES: usize = 1; - const NUM_SLEDS: usize = 2; - const NUM_POWER_CONTROLLERS: usize = 1; - - let mut switches = Vec::with_capacity(NUM_SWITCHES); - let mut sleds = Vec::with_capacity(NUM_SLEDS); - let mut pscs = Vec::with_capacity(NUM_POWER_CONTROLLERS); - for _ in 0..NUM_SWITCHES { - switches.push(UdpSocket::bind("127.0.0.1:0").await.unwrap()); - } - for _ in 0..NUM_SLEDS { - sleds.push(UdpSocket::bind("127.0.0.1:0").await.unwrap()); - } - for _ in 0..NUM_POWER_CONTROLLERS { - pscs.push(UdpSocket::bind("127.0.0.1:0").await.unwrap()); - } - - Self { switches, sleds, pscs } - } - - fn all_sockets(&self) -> impl Iterator { - self.switches - .iter() - .chain(self.sleds.iter()) - .chain(self.pscs.iter()) - } - - async fn make_management_switch( - &self, - mut recv_callback: F, - ) -> ManagementSwitch - where - F: FnMut(SwitchPort, &[u8]) + Send + 'static, - { - let log = Logger::root(slog::Discard, slog::o!()); - - // Skip the discovery process by constructing a `ManagementSwitch` - // by hand - let mut sockets = HashMap::new(); - let mut port_to_id = HashMap::new(); - let mut sp_addrs = HashMap::new(); - for (typ, sp_sockets) in [ - (SpType::Switch, &self.switches), - (SpType::Sled, &self.sleds), - (SpType::Power, &self.pscs), - ] { - for (slot, sp_sock) in sp_sockets.iter().enumerate() { - let port = SwitchPort(sockets.len()); - let id = SpIdentifier { typ, slot }; - let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); - - sp_addrs.insert(port, sp_sock.local_addr().unwrap()); - port_to_id.insert(port, id); - sockets.insert(port, socket); - } - } - let sockets = Arc::new(sockets); - - let recv_handler = - RecvHandler::new(sockets.keys().copied(), log.clone()); - - // Since we skipped the discovery process, we have to tell - // `recv_handler` what all the fake SP addresses are. - for (port, addr) in sp_addrs { - recv_handler.set_remote_addr(port, addr); - } - - let recv_task = { - let recv_handler = Arc::clone(&recv_handler); - tokio::spawn(recv_task( - Arc::clone(&sockets), - move |port, addr, data| { - recv_handler.handle_incoming_packet(port, addr, data); - recv_callback(port, data); - }, - log.clone(), - )) - }; - - let location_map = - LocationMap::new_raw(String::from("test"), port_to_id); - let local_ignition_controller_port = location_map - .id_to_port(SpIdentifier { typ: SpType::Switch, slot: 0 }) - .unwrap(); - - ManagementSwitch { - local_ignition_controller_port, - recv_handler, - location_map, - sockets, - log, - recv_task, - } - } - } - - #[tokio::test] - async fn test_recv_task() { - let harness = Harness::new().await; - - // create a switch, pointed at our harness's fake SPs, with a - // callback that accumlates received packets into a hashmap - let received: Arc>>>> = - Arc::default(); - - let switch = harness - .make_management_switch({ - let received = Arc::clone(&received); - move |port, data: &[u8]| { - let mut received = received.lock().unwrap(); - received.entry(port).or_default().push(data.to_vec()); - } - }) - .await; - - // Actual test - send a bunch of data to each of the ports... - let mut expected: HashMap>> = HashMap::new(); - for i in 0..10 { - for (port_num, sock) in harness.all_sockets().enumerate() { - let port = SwitchPort(port_num); - let data = format!("message {} to {:?}", i, port).into_bytes(); - - let addr = - switch.sockets.get(&port).unwrap().local_addr().unwrap(); - sock.send_to(&data, addr).await.unwrap(); - expected.entry(port).or_default().push(data); - } - } - // ... and confirm we received them all. messages should be in order per - // socket, but we don't check ordering of messages across sockets since - // that may vary - { - let received = Arc::clone(&received); - poll::wait_for_condition( - move || { - let result = if expected == *received.lock().unwrap() { - Ok(()) - } else { - Err(CondCheckError::::NotYet) - }; - future::ready(result) - }, - &Duration::from_millis(10), - &Duration::from_secs(1), - ) - .await - .unwrap(); - } - - // before dropping `switch`, confirm that the count on `received` is - // exactly 2: us and the receive task. after dropping `switch` it will - // be only us. - assert_eq!(Arc::strong_count(&received), 2); - - // dropping `switch` should cancel its corresponding recv task, which - // we can confirm by checking that the ref count on `received` drops to - // 1 (just us). we have to poll for this since it's not necessarily - // immediate; recv_task is presumably running on a tokio thread - mem::drop(switch); - poll::wait_for_condition( - move || { - let result = match Arc::strong_count(&received) { - 1 => Ok(()), - 2 => Err(CondCheckError::::NotYet), - n => panic!("bogus count {}", n), - }; - future::ready(result) - }, - &Duration::from_millis(10), - &Duration::from_secs(1), - ) - .await - .unwrap(); - } - - #[tokio::test] - async fn test_sp_socket() { - let harness = Harness::new().await; - - let switch = harness.make_management_switch(|_, _| {}).await; - - // confirm messages sent to the switch's sp sockets show up on our - // harness sockets - let mut buf = [0; SpMessage::MAX_SIZE]; - for (typ, sp_sockets) in [ - (SpType::Switch, &harness.switches), - (SpType::Sled, &harness.sleds), - (SpType::Power, &harness.pscs), - ] { - for (slot, sp_sock) in sp_sockets.iter().enumerate() { - let port = switch - .location_map - .id_to_port(SpIdentifier { typ, slot }) - .unwrap(); - let sock = switch.sp_socket(port).unwrap(); - let local_addr = sock.socket.local_addr().unwrap(); - - let message = format!("{:?} {}", typ, slot).into_bytes(); - sock.socket.send_to(&message, sock.addr).await.unwrap(); - - let (n, addr) = sp_sock.recv_from(&mut buf).await.unwrap(); - - // confirm we received the expected message from the - // corresponding switch port - assert_eq!(&buf[..n], message); - assert_eq!(addr, local_addr); - } - } - } -} - -#[usdt::provider(provider = "gateway_sp_comms")] -mod probes { - fn recv_packet( - _source: &SocketAddr, - _port: &SwitchPort, - _data: u64, // TODO actually a `*const u8`, but that isn't allowed by usdt - _len: u64, - ) { - } } diff --git a/gateway-sp-comms/src/management_switch/location_map.rs b/gateway-sp-comms/src/management_switch/location_map.rs index e52f0c825b..f97de075d2 100644 --- a/gateway-sp-comms/src/management_switch/location_map.rs +++ b/gateway-sp-comms/src/management_switch/location_map.rs @@ -5,19 +5,13 @@ // Copyright 2022 Oxide Computer Company use super::SpIdentifier; -use super::SpSocket; use super::SwitchPort; -use crate::communicator::ResponseKindExt; use crate::error::StartupError; -use crate::recv_handler::RecvHandler; -use crate::Timeout; +use crate::single_sp::SingleSp; use futures::stream::FuturesUnordered; use futures::Stream; use futures::StreamExt; -use gateway_messages::RequestKind; use gateway_messages::SpPort; -use omicron_common::backoff; -use omicron_common::backoff::Backoff; use serde::Deserialize; use serde::Serialize; use slog::debug; @@ -26,10 +20,8 @@ use slog::Logger; use std::collections::HashMap; use std::collections::HashSet; use std::convert::TryFrom; -use std::net::SocketAddr; +use std::net::SocketAddrV6; use std::sync::Arc; -use std::time::Duration; -use tokio::net::UdpSocket; use tokio::sync::mpsc; use tokio::time::Instant; use tokio_stream::wrappers::ReceiverStream; @@ -40,14 +32,14 @@ pub struct SwitchPortConfig { /// Data link addresses; this is the address on which we should bind a /// socket, which will be tagged with the appropriate VLAN for this switch /// port (see RFD 250). - pub data_link_addr: SocketAddr, + pub data_link_addr: SocketAddrV6, /// Multicast address used to find the SP connected to this port. // TODO: The multicast address used should be a single address, not a // per-port address. For now we configure it per-port to make dev/test on a // single system easier; we can run multiple simulated SPs that all listen // to different multicast addresses on one host. - pub multicast_addr: SocketAddr, + pub multicast_addr: SocketAddrV6, /// Map defining the logical identifier of the SP connected to this port for /// each of the possible locations where MGS is running (see @@ -93,25 +85,10 @@ pub(super) struct LocationMap { } impl LocationMap { - // For unit tests we don't want to have to run discovery, so allow - // construction of a canned `LocationMap`. - #[cfg(test)] - pub(super) fn new_raw( - location_name: String, - port_to_id: HashMap, - ) -> Self { - let mut id_to_port = HashMap::with_capacity(port_to_id.len()); - for (&port, &id) in port_to_id.iter() { - id_to_port.insert(id, port); - } - Self { location_name, port_to_id, id_to_port } - } - pub(super) async fn run_discovery( config: LocationConfig, ports: HashMap, - sockets: Arc>, - recv_handler: Arc, + sockets: Arc>, deadline: Instant, log: &Logger, ) -> Result { @@ -132,7 +109,6 @@ impl LocationMap { discover_sps( &sockets, ports, - &recv_handler, location_determination, refined_locations_tx, &log, @@ -358,63 +334,28 @@ impl TryFrom<(&'_ HashMap, LocationConfig)> /// and the list of locations we could be in based on the SP's response on that /// port. Our spawner is responsible for collecting/using those messages. async fn discover_sps( - sockets: &HashMap, + sockets: &HashMap, port_config: HashMap, - recv_handler: &RecvHandler, mut location_determination: Vec, refined_locations: mpsc::Sender<(SwitchPort, HashSet)>, log: &Logger, ) { - // Build a collection of futures that sends discovery packets on every port; - // each future runs until it hears back from an SP (possibly running forever - // if there is no SP listening on the other end of that port's connection). + // Build a collection of futures representing the results of discovering the + // SP address for each switch port in `sockets`. let mut futs = FuturesUnordered::new(); - for (port, config) in port_config { + for (switch_port, _config) in port_config { futs.push(async move { - // construct a socket pointed to a multicast addr instead of a - // specific, known addr - let socket = SpSocket { - location_map: None, - port, - addr: config.multicast_addr, - // all ports in `port_config` also get sockets bound to them; - // unwrapping this lookup is fine - socket: sockets.get(&port).unwrap(), - }; - - let mut backoff = backoff::internal_service_policy(); + // all ports in `port_config` also get sockets bound to them; + // unwrapping this lookup is fine + let sp = sockets.get(&switch_port).unwrap(); + + let mut addr_watch = sp.sp_addr_watch().clone(); loop { - let duration = backoff - .next_backoff() - .expect("internal backoff policy gave up"); - tokio::time::sleep(duration).await; - - let result = socket - .request_response( - &recv_handler, - RequestKind::Discover, - ResponseKindExt::expect_discover, - // TODO should this timeout be configurable or itself - // have some kind of backoff? we're inside a - // `backoff::retry()` loop, but if an SP is alive but - // slow (i.e., taking longer than this timeout to reply) - // we'll never hear it - the response will show up late - // and we'll ignore it. For now just leave it at some - // reasonably large number; this may solve itself when - // we move to some kind of authenticated comms channel. - Some(Timeout::from_now(Duration::from_secs(5))), - &log, - ) - .await; - - match result { - Ok(response) => return (port, response), - Err(err) => { - debug!( - log, "discovery failed; will retry"; - "port" => ?port, - "err" => %err, - ); + let current = *addr_watch.borrow(); + match current { + Some((_addr, sp_port)) => return (switch_port, sp_port), + None => { + addr_watch.changed().await.unwrap(); } } } @@ -422,25 +363,25 @@ async fn discover_sps( } // Wait for responses. - while let Some((port, response)) = futs.next().await { + while let Some((switch_port, sp_port)) = futs.next().await { // See if this port can participate in location determination. let pos = match location_determination .iter() - .position(|d| d.switch_port == port) + .position(|d| d.switch_port == switch_port) { Some(pos) => pos, None => { info!( log, "received discovery response (not used for location)"; - "port" => ?port, - "response" => ?response, + "switch_port" => ?switch_port, + "sp_port" => ?sp_port, ); continue; } }; let determination = location_determination.remove(pos); - let refined = match response.sp_port { + let refined = match sp_port { SpPort::One => determination.sp_port_1, SpPort::Two => determination.sp_port_2, }; @@ -448,7 +389,7 @@ async fn discover_sps( // the only failure possible here is that the receiver is gone; that's // harmless for us (e.g., maybe it's already fully determined the // location and doesn't care about more messages) - let _ = refined_locations.send((port, refined)).await; + let _ = refined_locations.send((switch_port, refined)).await; } // TODO If we're exiting, we've now heard from an SP on every port. Is there @@ -532,8 +473,8 @@ mod tests { let bad_ports = HashMap::from([( SwitchPort(0), SwitchPortConfig { - data_link_addr: "127.0.0.1:0".parse().unwrap(), - multicast_addr: "127.0.0.1:0".parse().unwrap(), + data_link_addr: "[::1]:0".parse().unwrap(), + multicast_addr: "[::1]:0".parse().unwrap(), location: HashMap::from([ (String::from("a"), SpIdentifier::new(SpType::Sled, 0)), // missing "b", has extraneous "c" @@ -600,8 +541,8 @@ mod tests { let good_ports = HashMap::from([( SwitchPort(0), SwitchPortConfig { - data_link_addr: "127.0.0.1:0".parse().unwrap(), - multicast_addr: "127.0.0.1:0".parse().unwrap(), + data_link_addr: "[::1]:0".parse().unwrap(), + multicast_addr: "[::1]:0".parse().unwrap(), location: HashMap::from([ (String::from("a"), SpIdentifier::new(SpType::Sled, 0)), (String::from("b"), SpIdentifier::new(SpType::Sled, 1)), diff --git a/gateway-sp-comms/src/recv_handler/mod.rs b/gateway-sp-comms/src/recv_handler/mod.rs deleted file mode 100644 index 8e88034087..0000000000 --- a/gateway-sp-comms/src/recv_handler/mod.rs +++ /dev/null @@ -1,568 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -// Copyright 2022 Oxide Computer Company - -use crate::error::Error; -use crate::management_switch::SwitchPort; -use crate::Communicator; -use crate::Timeout; -use futures::future::Fuse; -use futures::Future; -use futures::FutureExt; -use futures::SinkExt; -use futures::StreamExt; -use futures::TryFutureExt; -use gateway_messages::sp_impl::SerialConsolePacketizer; -use gateway_messages::version; -use gateway_messages::ResponseError; -use gateway_messages::ResponseKind; -use gateway_messages::SerialConsole; -use gateway_messages::SpComponent; -use gateway_messages::SpMessage; -use gateway_messages::SpMessageKind; -use hyper::upgrade::OnUpgrade; -use hyper::upgrade::Upgraded; -use slog::debug; -use slog::error; -use slog::info; -use slog::trace; -use slog::warn; -use slog::Logger; -use std::borrow::Cow; -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::collections::VecDeque; -use std::convert::TryInto; -use std::net::SocketAddr; -use std::sync::atomic::AtomicU32; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use std::sync::Mutex; -use std::time::Duration; -use tokio::sync::mpsc; -use tokio::sync::oneshot; -use tokio_tungstenite::tungstenite::protocol::Role; -use tokio_tungstenite::tungstenite::protocol::WebSocketConfig; -use tokio_tungstenite::WebSocketStream; - -mod request_response_map; - -use self::request_response_map::RequestResponseMap; -use self::request_response_map::ResponseIngestResult; - -/// Handler for incoming packets received on management switch ports. -/// -/// When [`RecvHandler::new()`] is created, it starts an indefinite tokio task -/// that calls [`RecvHandler::handle_incoming_packet()`] for each UDP packet -/// received. Those packets come in two flavors: responses to requests that we -/// sent to an SP, and unprompted messages generated by an SP. The flow for -/// request / response is: -/// -/// 1. [`RecvHandler::register_request_id()`] is called with the 32-bit request -/// ID associated with the request. This returns a future that will be -/// fulfilled with the response once it arrives. This method should be called -/// after determining the request ID but before actually sending the request -/// to avoid a race window where the response could be received before the -/// receive handler task knows to expect it. Internally, this future holds -/// the receiving half of a [`tokio::oneshot`] channel. -/// 2. The request packet is sent to the SP. -/// 3. The requester `.await`s the future returned by `register_request_id()`. -/// 4. If the future is dropped before a response is received (e.g., due to a -/// timeout), dropping the future will unregister the request ID provided in -/// step 1 (see [`RequestResponseMap::wait_for_response()`] for details). -/// 5. Assuming the future has not been dropped when the response arrives, -/// [`RecvHandler::handle_incoming_packet()`] will look up the request ID and -/// send the response on the sending half of the [`tokio::oneshot`] channel -/// corresponding to the future returned in step 1, fulfilling it. -#[derive(Debug)] -pub(crate) struct RecvHandler { - request_id: AtomicU32, - sp_state: HashMap, - log: Logger, -} - -impl RecvHandler { - /// Create a new `RecvHandler` that is aware of all ports described by - /// `switch`. - pub(crate) fn new( - ports: impl ExactSizeIterator, - log: Logger, - ) -> Arc { - // prime `sp_state` with all known ports of the switch - let mut sp_state = HashMap::with_capacity(ports.len()); - for port in ports { - sp_state.insert(port, SingleSpState::default()); - } - - // TODO: Should we init our request_id randomly instead of always - // starting at 0? - Arc::new(Self { request_id: AtomicU32::new(0), sp_state, log }) - } - - fn sp_state(&self, port: SwitchPort) -> &SingleSpState { - // SwitchPort instances can only be created by `ManagementSwitch`, so we - // should never be able to instantiate a port that we don't have in - // `self.sp_state` (which we initialize with all ports declared by the - // switch we were given). - self.sp_state.get(&port).expect("invalid switch port") - } - - /// Spawn a tokio task responsible for forwarding serial console data - /// between the SP component on `port` and the websocket connection provided - /// by `upgrade_fut`. - pub(crate) fn serial_console_attach( - &self, - communicator: Arc, - port: SwitchPort, - component: SpComponent, - sp_ack_timeout: Duration, - upgrade_fut: OnUpgrade, - ) -> Result<(), Error> { - // lazy closure to spawn the task; called below _unless_ we already have - // at attached task with this SP - let spawn_task = move || { - let (detach, detach_rx) = oneshot::channel(); - let (packets_from_sp, packets_from_sp_rx) = - mpsc::unbounded_channel(); - let log = self.log.new(slog::o!( - "port" => format!("{:?}", port), - "component" => format!("{:?}", component), - )); - - tokio::spawn(async move { - let upgraded = match upgrade_fut.await { - Ok(u) => u, - Err(e) => { - error!(log, "serial task failed"; "err" => %e); - return; - } - }; - let config = WebSocketConfig { - max_send_queue: Some(4096), - ..Default::default() - }; - let ws_stream = WebSocketStream::from_raw_socket( - upgraded, - Role::Server, - Some(config), - ) - .await; - - let task = SerialConsoleTask { - communicator, - port, - component, - detach: detach_rx, - packets: packets_from_sp_rx, - ws_stream, - sp_ack_timeout, - }; - match task.run(&log).await { - Ok(()) => debug!(log, "serial task complete"), - Err(e) => { - error!(log, "serial task failed"; "err" => %e) - } - } - }); - - SerialConsoleTaskHandle { packets_from_sp, detach } - }; - - let mut tasks = - self.sp_state(port).serial_console_tasks.lock().unwrap(); - match tasks.entry(component) { - Entry::Occupied(mut slot) => { - if slot.get().is_attached() { - Err(Error::SerialConsoleAttached) - } else { - // old task is already detached; replace it - let _ = slot.insert(spawn_task()); - Ok(()) - } - } - Entry::Vacant(slot) => { - slot.insert(spawn_task()); - Ok(()) - } - } - } - - /// Shut down the serial console task associated with the given port and - /// component, if one exists and is attached. - pub(crate) fn serial_console_detach( - &self, - port: SwitchPort, - component: &SpComponent, - ) -> Result<(), Error> { - let mut tasks = - self.sp_state(port).serial_console_tasks.lock().unwrap(); - if let Some(task) = tasks.remove(component) { - // If send fails, we're already detached. - let _ = task.detach.send(()); - } - Ok(()) - } - - /// Returns a new request ID and a future that will complete when we receive - /// a response on the given `port` with that request ID. - /// - /// Panics if `port` is not one of the ports defined by the `switch` given - /// to this `RecvHandler` when it was constructed. - pub(crate) fn register_request_id( - &self, - port: SwitchPort, - ) -> (u32, impl Future> + '_) - { - let request_id = self.request_id.fetch_add(1, Ordering::Relaxed); - (request_id, self.sp_state(port).requests.wait_for_response(request_id)) - } - - /// Returns the address of the SP connected to `port`, if we know it. - pub(crate) fn remote_addr(&self, port: SwitchPort) -> Option { - *self.sp_state(port).addr.lock().unwrap() - } - - // Only available for tests: set the remote address of a given port. - #[cfg(test)] - pub(crate) fn set_remote_addr(&self, port: SwitchPort, addr: SocketAddr) { - *self.sp_state(port).addr.lock().unwrap() = Some(addr); - } - - pub(crate) fn handle_incoming_packet( - &self, - port: SwitchPort, - addr: SocketAddr, - buf: &[u8], - ) { - trace!(&self.log, "received {} bytes from {:?}", buf.len(), port); - - // the first four bytes of packets we expect is always a version number; - // check for that first - // - // TODO? We're (ab)using our knowledge of our packet wire format here - - // knowledge contained in another crate - both in expecting that the - // first four bytes are the version (perfectly reasonable) and that - // they're stored in little endian (a bit less reasonable, but still - // probably okay). We could consider moving this check into - // `gateway_messages` itself? - let version_raw = match buf.get(0..4) { - Some(bytes) => u32::from_le_bytes(bytes.try_into().unwrap()), - None => { - error!(&self.log, "discarding too-short packet"); - return; - } - }; - match version_raw { - version::V1 => (), - _ => { - error!( - &self.log, - "discarding message with unsupported version {}", - version_raw - ); - return; - } - } - - // parse into an `SpMessage` - let sp_msg = match gateway_messages::deserialize::(buf) { - Ok((msg, _extra)) => { - // TODO should we check that `extra` is empty? if the - // response is maximal size any extra data is silently - // discarded anyway, so probably not? - msg - } - Err(err) => { - error!( - &self.log, "discarding malformed message"; - "err" => %err, - "raw" => ?buf, - ); - return; - } - }; - debug!(&self.log, "received {:?} from {:?}", sp_msg, port); - - // update our knowledge of the sender's address - let state = self.sp_state(port); - match state.addr.lock().unwrap().replace(addr) { - None => { - // expected but rare: our first packet on this port - debug!( - &self.log, "discovered remote address for port"; - "port" => ?port, - "addr" => %addr, - ); - } - Some(old) if old == addr => { - // expected; we got another packet from the expected address - } - Some(old) => { - // unexpected; the remote address changed - // TODO-security - What should we do here? Could the sled have - // been physically replaced and we're now hearing from a new SP? - // This question/TODO may go away on its own if we add an - // authenticated channel? - warn!( - &self.log, "updated remote address for port"; - "port" => ?port, - "old_addr" => %old, - "new_addr" => %addr, - ); - } - } - - // decide whether this is a response to an outstanding request or an - // unprompted message - match sp_msg.kind { - SpMessageKind::Response { request_id, result } => { - self.handle_response(port, request_id, result); - } - SpMessageKind::SerialConsole(serial_console) => { - self.handle_serial_console(port, serial_console); - } - } - } - - fn handle_response( - &self, - port: SwitchPort, - request_id: u32, - result: Result, - ) { - probes::recv_response!(|| (&port, request_id, &result)); - match self.sp_state(port).requests.ingest_response(&request_id, result) - { - ResponseIngestResult::Ok => (), - ResponseIngestResult::UnknownRequestId => { - error!( - &self.log, - "discarding unexpected response {} from {:?} (possibly past timeout?)", - request_id, - port, - ); - } - } - } - - fn handle_serial_console(&self, port: SwitchPort, packet: SerialConsole) { - probes::recv_serial_console!(|| ( - &port, - &packet.component, - packet.offset, - packet.data.as_ptr() as usize as u64, - u64::from(packet.len) - )); - - let tasks = self.sp_state(port).serial_console_tasks.lock().unwrap(); - - // if we have a task managing an active client connection, forward this - // packet to it; otherwise, drop it on the floor. - match tasks.get(&packet.component).map(|task| { - let data = packet.data[..usize::from(packet.len)].to_vec(); - task.recv_data_from_sp(data) - }) { - Some(Ok(())) => { - debug!( - self.log, - "forwarded serial console packet to attached client" - ); - } - - // The only error we can get from `recv_data_from_sp()` is "the task - // went away", which from our point of view is the same as it not - // existing in the first place. - Some(Err(_)) | None => { - debug!( - self.log, - "discarding serial console packet (no attached client)" - ); - } - } - } -} - -#[derive(Debug, Default)] -struct SingleSpState { - addr: Mutex>, - requests: RequestResponseMap>, - serial_console_tasks: Mutex>, -} - -#[derive(Debug)] -struct SerialConsoleTaskHandle { - packets_from_sp: mpsc::UnboundedSender>, - detach: oneshot::Sender<()>, -} - -impl SerialConsoleTaskHandle { - fn is_attached(&self) -> bool { - !self.detach.is_closed() - } - fn recv_data_from_sp( - &self, - data: Vec, - ) -> Result<(), mpsc::error::SendError>> { - self.packets_from_sp.send(data) - } -} - -struct SerialConsoleTask { - communicator: Arc, - port: SwitchPort, - component: SpComponent, - detach: oneshot::Receiver<()>, - packets: mpsc::UnboundedReceiver>, - ws_stream: WebSocketStream, - sp_ack_timeout: Duration, -} - -impl SerialConsoleTask { - async fn run(mut self, log: &Logger) -> Result<(), SerialTaskError> { - use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; - use tokio_tungstenite::tungstenite::protocol::CloseFrame; - use tokio_tungstenite::tungstenite::Message; - - let (mut ws_sink, mut ws_stream) = self.ws_stream.split(); - - // TODO Currently we do not apply any backpressure on the SP and are - // willing to buffer up an arbitrary amount of data in memory. Is it - // reasonable to apply backpressure to the SP over UDP? Should we have - // caps on memory and start discarding data if we exceed them? We _do_ - // apply backpressure to the websocket, delaying reading from it if we - // still have data waiting to be sent to the SP. - let mut data_from_sp: VecDeque> = VecDeque::new(); - let mut data_to_sp: Vec = Vec::new(); - let mut packetizer_to_sp = SerialConsolePacketizer::new(self.component); - - loop { - let ws_send = if let Some(data) = data_from_sp.pop_front() { - ws_sink.send(Message::Binary(data)).fuse() - } else { - Fuse::terminated() - }; - - let ws_recv; - let sp_send; - if data_to_sp.is_empty() { - sp_send = Fuse::terminated(); - ws_recv = ws_stream.next().fuse(); - } else { - ws_recv = Fuse::terminated(); - - let (packet, _remaining) = - packetizer_to_sp.first_packet(data_to_sp.as_slice()); - let packet_data_len = usize::from(packet.len); - - sp_send = self - .communicator - .serial_console_send_packet( - self.port, - packet, - Timeout::from_now(self.sp_ack_timeout), - ) - .map_ok(move |()| packet_data_len) - .fuse(); - } - - tokio::select! { - // Poll in the order written - biased; - - // It's important we always poll the detach channel first - // so that a constant stream of incoming/outgoing messages - // don't cause us to ignore a detach - _ = &mut self.detach => { - info!(log, "detaching from serial console"); - let close = CloseFrame { - code: CloseCode::Policy, - reason: Cow::Borrowed("serial console was detached"), - }; - ws_sink.send(Message::Close(Some(close))).await?; - break; - } - - // Send a UDP packet to the SP - send_success = sp_send => { - let n = send_success?; - data_to_sp.drain(..n); - } - - // Receive a UDP packet from the SP. - data = self.packets.recv() => { - // The sending half of `packets` is held by the - // `SerialConsoleTask` that created us. It is only dropped - // if we are detached; i.e., no longer running; therefore, - // we can safely unwrap this recv. - let data = data.expect("sending half dropped"); - data_from_sp.push_back(data); - } - - // Send a previously-received UDP packet of data to the websocket - // client - write_success = ws_send => { - write_success?; - } - - // Receive from the websocket to send to the SP. - msg = ws_recv => { - match msg { - Some(Ok(Message::Binary(mut data))) => { - // we only populate ws_recv when we have no data - // currently queued up; sanity check that here - assert!(data_to_sp.is_empty()); - data_to_sp.append(&mut data); - } - Some(Ok(Message::Close(_))) | None => { - info!( - log, - "remote end closed websocket; terminating task", - ); - break; - } - Some(other) => { - let wrong_message = other?; - error!( - log, - "bogus websocket message; terminating task"; - "message" => ?wrong_message, - ); - break; - } - } - } - } - } - - Ok(()) - } -} - -#[derive(Debug, thiserror::Error)] -enum SerialTaskError { - #[error(transparent)] - Error(#[from] Error), - #[error(transparent)] - TungsteniteError(#[from] tokio_tungstenite::tungstenite::Error), -} - -#[usdt::provider(provider = "gateway_sp_comms")] -mod probes { - fn recv_response( - _port: &SwitchPort, - _request_id: u32, - _result: &Result, - ) { - } - - fn recv_serial_console( - _port: &SwitchPort, - _component: &SpComponent, - _offset: u64, - _data: u64, // TODO actually a `*const u8`, but that isn't allowed by usdt - _len: u64, - ) { - } -} diff --git a/gateway-sp-comms/src/recv_handler/request_response_map.rs b/gateway-sp-comms/src/recv_handler/request_response_map.rs deleted file mode 100644 index f1400d55ba..0000000000 --- a/gateway-sp-comms/src/recv_handler/request_response_map.rs +++ /dev/null @@ -1,170 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -// Copyright 2022 Oxide Computer Company - -use futures::Future; -use std::collections::HashMap; -use std::hash::Hash; -use std::sync::Mutex; -use tokio::sync::oneshot; -use tokio::sync::oneshot::Receiver; -use tokio::sync::oneshot::Sender; - -/// Possible outcomes of ingesting a response. -#[derive(Debug, Clone, Copy, PartialEq)] -pub(super) enum ResponseIngestResult { - Ok, - UnknownRequestId, -} - -/// An in-memory map providing paired request/response functionality to async -/// tasks. Both tasks must know the key of the request. -#[derive(Debug)] -pub(super) struct RequestResponseMap { - requests: Mutex>>, -} - -// this can't be derived for T: !Default, but we don't need T: Default -impl Default for RequestResponseMap { - fn default() -> Self { - Self { requests: Mutex::default() } - } -} - -impl RequestResponseMap -where - K: Hash + Eq + Clone, -{ - /// Register a key to be paired with a future call to `ingest_response()`. - /// Returns a [`Future`] that will complete once that call is made. - /// - /// Panics if `key` is currently in use by another - /// `wait_for_response()` call on `self`. - pub(super) fn wait_for_response( - &self, - key: K, - ) -> impl Future + '_ { - // construct a oneshot channel, and wrap the receiving half in a type - // that will clean up `self.requests` if the future we return is dropped - // (typically due to timeout/cancellation) - let (tx, rx) = oneshot::channel(); - let mut wrapped_rx = - RemoveOnDrop { rx, requests: &self.requests, key: key.clone() }; - - // insert the sending half into `self.requests` for use by - // `ingest_response()` later. - let old = self.requests.lock().unwrap().insert(key, tx); - - // we always clean up `self.requests` after receiving a response (or - // timing out), so we should never see request ID reuse even if they - // roll over. assert that here to ensure we don't get mysterious - // misbehavior if that turns out to be incorrect - assert!(old.is_none(), "request ID reuse"); - - async move { - // wait for someone to call `ingest_response()` with our request id - match (&mut wrapped_rx.rx).await { - Ok(inner_result) => inner_result, - Err(_recv_error) => { - // we hold the sending half in `self.requests` until either - // `wrapped_rx`'s `Drop` impl removes it (in which case - // we're not running anymore), or it's consumed to send the - // result to us. receiving therefore can't fail - unreachable!() - } - } - } - } - - /// Ingest a response, which will cause the corresponding - /// `wait_for_response()` future to be fulfilled (if it exists). - pub(super) fn ingest_response( - &self, - key: &K, - response: T, - ) -> ResponseIngestResult { - // get the sending half of the channel created in `wait_for_response()`, - // if it exists (i.e., `wait_for_response()` was actually called with - // `key`, and the returned future from it hasn't been dropped) - let tx = match self.requests.lock().unwrap().remove(key) { - Some(tx) => tx, - None => return ResponseIngestResult::UnknownRequestId, - }; - - // we got `tx`, so the receiving end existed a moment ago, but there's a - // race here where it could be dropped before we're able to send - // `result` through, so we can't unwrap this send; we treat this failure - // the same as not finding `tx`, because if we had tried to get `tx` a - // moment later it would not have been there. - match tx.send(response) { - Ok(()) => ResponseIngestResult::Ok, - Err(_) => ResponseIngestResult::UnknownRequestId, - } - } -} - -/// `RemoveOnDrop` is a light wrapper around a [`Receiver`] that removes the -/// associated key from the `HashMap` it's given once the response has been -/// received (or the `RemoveOnDrop` itself is dropped). -#[derive(Debug)] -struct RemoveOnDrop<'a, K, T> -where - K: Hash + Eq, -{ - rx: Receiver, - requests: &'a Mutex>>, - key: K, -} - -impl Drop for RemoveOnDrop<'_, K, T> -where - K: Hash + Eq, -{ - fn drop(&mut self) { - // we don't care to check the return value here; this will be `Some(_)` - // if we're being dropped before a response has been received and - // forwarded on to our caller, and `None` if not (because a caller will - // have already extracted the sending channel) - let _ = self.requests.lock().unwrap().remove(&self.key); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::mem; - - #[tokio::test] - async fn basic_usage() { - let os = RequestResponseMap::default(); - - // ingesting a response before waiting for it doesn't work - assert_eq!( - os.ingest_response(&1, "hi"), - ResponseIngestResult::UnknownRequestId - ); - - // ingesting a response after waiting for it works, and the receiving - // half gets the ingested response - let resp = os.wait_for_response(1); - assert_eq!(os.ingest_response(&1, "hello"), ResponseIngestResult::Ok); - assert_eq!(resp.await, "hello"); - } - - #[tokio::test] - async fn dropping_future_cleans_up_key() { - let os = RequestResponseMap::default(); - - // register interest in a response, but then drop the returned future - let resp = os.wait_for_response(1); - mem::drop(resp); - - // attempting to ingest the corresponding response should now fail - assert_eq!( - os.ingest_response(&1, "hi"), - ResponseIngestResult::UnknownRequestId - ); - } -} diff --git a/gateway-sp-comms/src/single_sp.rs b/gateway-sp-comms/src/single_sp.rs index 1dcae35d90..5865d05dfe 100644 --- a/gateway-sp-comms/src/single_sp.rs +++ b/gateway-sp-comms/src/single_sp.rs @@ -6,7 +6,15 @@ //! Interface for communicating with a single SP. +use crate::communicator::ResponseKindExt; +use crate::error::BadResponseType; +use crate::error::SerialConsoleAlreadyAttached; +use crate::error::SpCommunicationError; +use crate::error::UpdateError; use gateway_messages::version; +use gateway_messages::BulkIgnitionState; +use gateway_messages::IgnitionCommand; +use gateway_messages::IgnitionState; use gateway_messages::Request; use gateway_messages::RequestKind; use gateway_messages::ResponseError; @@ -28,11 +36,10 @@ use slog::trace; use slog::warn; use slog::Logger; use std::convert::TryInto; -use std::io; use std::net::Ipv6Addr; +use std::net::SocketAddr; use std::net::SocketAddrV6; use std::time::Duration; -use thiserror::Error; use tokio::net::UdpSocket; use tokio::sync::mpsc; use tokio::sync::oneshot; @@ -41,9 +48,6 @@ use tokio::task::JoinHandle; use tokio::time; use tokio::time::timeout; -use crate::communicator::ResponseKindExt; -use crate::error::BadResponseType; - pub const DISCOVERY_MULTICAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff15, 0, 0, 0, 0, 0, 0x1de, 0); @@ -53,35 +57,7 @@ pub const DISCOVERY_MULTICAST_ADDR: Ipv6Addr = // TODO-correctness/TODO-security What do we do if the SP address changes? const DISCOVERY_INTERVAL_IDLE: Duration = Duration::from_secs(60); -#[derive(Debug, Error)] -pub enum Error { - #[error("failed to send UDP packet to {addr}: {err}")] - UdpSendTo { addr: SocketAddrV6, err: io::Error }, - #[error("failed to recv UDP packet: {0}")] - UdpRecv(io::Error), - #[error("failed to deserialize SP message from {peer}: {err}")] - Deserialize { peer: SocketAddrV6, err: gateway_messages::HubpackError }, - #[error("RPC call failed (gave up after {0} attempts)")] - ExhaustedNumAttempts(usize), - #[error("serial console already attached")] - SerialConsoleAlreadyAttached, - #[error(transparent)] - BadResponseType(#[from] BadResponseType), - #[error("Error response from SP: {0}")] - SpError(#[from] ResponseError), -} - -#[derive(Debug, Error)] -pub enum UpdateError { - #[error("update image is too large")] - ImageTooLarge, - #[error("error starting update: {0}")] - Start(Error), - #[error("error updating chunk at offset {offset}: {err}")] - Chunk { offset: u32, err: Error }, -} - -type Result = std::result::Result; +type Result = std::result::Result; #[derive(Debug)] pub struct SingleSp { @@ -139,6 +115,43 @@ impl SingleSp { &self.sp_addr_rx } + /// Request the state of an ignition target. + /// + /// This will fail if this SP is not connected to an ignition controller. + pub async fn ignition_state(&self, target: u8) -> Result { + self.rpc(RequestKind::IgnitionState { target }).await.and_then( + |(_peer, response)| { + response.expect_ignition_state().map_err(Into::into) + }, + ) + } + + /// Request the state of all ignition targets. + /// + /// This will fail if this SP is not connected to an ignition controller. + pub async fn bulk_ignition_state(&self) -> Result { + self.rpc(RequestKind::BulkIgnitionState).await.and_then( + |(_peer, response)| { + response.expect_bulk_ignition_state().map_err(Into::into) + }, + ) + } + + /// Send an ignition command to the given target. + /// + /// This will fail if this SP is not connected to an ignition controller. + pub async fn ignition_command( + &self, + target: u8, + command: IgnitionCommand, + ) -> Result<()> { + self.rpc(RequestKind::IgnitionCommand { target, command }) + .await + .and_then(|(_peer, response)| { + response.expect_ignition_command_ack().map_err(Into::into) + }) + } + /// Request the state of the SP. pub async fn state(&self) -> Result { self.rpc(RequestKind::SpState).await.and_then(|(_peer, response)| { @@ -239,12 +252,12 @@ impl SingleSp { // SP wasn't expecting a reset trigger (because it has reset!). match self.rpc(RequestKind::SysResetTrigger).await { Ok((_peer, response)) => { - Err(Error::BadResponseType(BadResponseType { + Err(SpCommunicationError::BadResponseType(BadResponseType { expected: "system-reset", got: response.name(), })) } - Err(Error::SpError( + Err(SpCommunicationError::SpError( ResponseError::SysResetTriggerWithoutPrepare, )) => Ok(()), Err(other) => Err(other), @@ -253,23 +266,32 @@ impl SingleSp { /// "Attach" to the serial console, setting up a tokio channel for all /// incoming serial console packets from the SP. - pub async fn serial_console_attach(&self) -> Result { + pub async fn serial_console_attach( + &self, + ) -> Result { let (tx, rx) = oneshot::channel(); // `Inner::run()` doesn't exit until we are dropped, so unwrapping here // only panics if it itself panicked. self.cmds_tx.send(InnerCommand::SerialConsoleAttach(tx)).await.unwrap(); - let rx = rx.await.unwrap()?; + let attachment = rx.await.unwrap()?; - Ok(AttachedSerialConsole { rx, inner_tx: self.cmds_tx.clone() }) + Ok(AttachedSerialConsole { + key: attachment.key, + rx: attachment.incoming, + inner_tx: self.cmds_tx.clone(), + }) } - /// Detach an existing attached serial console connection. + /// Detach any existing attached serial console connection. pub async fn serial_console_detach(&self) { // `Inner::run()` doesn't exit until we are dropped, so unwrapping here // only panics if it itself panicked. - self.cmds_tx.send(InnerCommand::SerialConsoleDetach).await.unwrap(); + self.cmds_tx + .send(InnerCommand::SerialConsoleDetach(None)) + .await + .unwrap(); } pub(crate) async fn rpc( @@ -298,6 +320,7 @@ async fn rpc( #[derive(Debug)] pub struct AttachedSerialConsole { + key: u64, rx: mpsc::Receiver, inner_tx: mpsc::Sender, } @@ -307,7 +330,10 @@ impl AttachedSerialConsole { self, ) -> (AttachedSerialConsoleSend, AttachedSerialConsoleRecv) { ( - AttachedSerialConsoleSend { inner_tx: self.inner_tx }, + AttachedSerialConsoleSend { + key: self.key, + inner_tx: self.inner_tx, + }, AttachedSerialConsoleRecv { rx: self.rx }, ) } @@ -315,6 +341,7 @@ impl AttachedSerialConsole { #[derive(Debug)] pub struct AttachedSerialConsoleSend { + key: u64, inner_tx: mpsc::Sender, } @@ -327,6 +354,14 @@ impl AttachedSerialConsoleSend { response.expect_serial_console_write_ack().map_err(Into::into) }) } + + /// Detach this serial console connection. + pub async fn detach(&self) { + self.inner_tx + .send(InnerCommand::SerialConsoleDetach(Some(self.key))) + .await + .unwrap(); + } } #[derive(Debug)] @@ -350,14 +385,29 @@ struct RpcRequest { response: oneshot::Sender>, } +#[derive(Debug)] +struct SerialConsoleAttachment { + key: u64, + incoming: mpsc::Receiver, +} + #[derive(Debug)] // `Rpc` is the large variant, which is by far the most common, so silence // clippy's warning that recommends boxing it. #[allow(clippy::large_enum_variant)] enum InnerCommand { Rpc(RpcRequest), - SerialConsoleAttach(oneshot::Sender>>), - SerialConsoleDetach, + SerialConsoleAttach( + oneshot::Sender< + Result, + >, + ), + // The associated value is the connection key; if `Some(_)`, only detach if + // the currently-attached key number matches. If `None`, detach any current + // connection. These correspond to "detach the current session" (performed + // automatically when a connection is closed) and "force-detach any session" + // (performed by a user). + SerialConsoleDetach(Option), } struct Inner { @@ -370,6 +420,7 @@ struct Inner { serial_console_tx: Option>, cmds_rx: mpsc::Receiver, request_id: u32, + serial_console_connection_key: u64, } impl Inner { @@ -392,6 +443,7 @@ impl Inner { serial_console_tx: None, cmds_rx, request_id: 0, + serial_console_connection_key: 0, } } @@ -519,16 +571,24 @@ impl Inner { } InnerCommand::SerialConsoleAttach(response_tx) => { let resp = if self.serial_console_tx.is_some() { - Err(Error::SerialConsoleAlreadyAttached) + Err(SerialConsoleAlreadyAttached) } else { let (tx, rx) = mpsc::channel(SERIAL_CONSOLE_CHANNEL_DEPTH); self.serial_console_tx = Some(tx); - Ok(rx) + self.serial_console_connection_key += 1; + Ok(SerialConsoleAttachment { + key: self.serial_console_connection_key, + incoming: rx, + }) }; response_tx.send(resp).unwrap(); } - InnerCommand::SerialConsoleDetach => { - self.serial_console_tx = None; + InnerCommand::SerialConsoleDetach(key) => { + if key.is_none() + || key == Some(self.serial_console_connection_key) + { + self.serial_console_tx = None; + } } } } @@ -619,7 +679,7 @@ impl Inner { } } - Err(Error::ExhaustedNumAttempts(self.max_attempts)) + Err(SpCommunicationError::ExhaustedNumAttempts(self.max_attempts)) } async fn rpc_call_one_attempt( @@ -714,11 +774,11 @@ async fn send( socket: &UdpSocket, addr: SocketAddrV6, data: &[u8], -) -> Result<(), Error> { +) -> Result<()> { let n = socket .send_to(data, addr) .await - .map_err(|err| Error::UdpSendTo { addr, err })?; + .map_err(|err| SpCommunicationError::UdpSendTo { addr, err })?; // `send_to` should never write a partial packet; this is UDP. assert_eq!(data.len(), n, "partial UDP packet sent to {}?!", addr); @@ -730,15 +790,19 @@ async fn recv( socket: &UdpSocket, incoming_buf: &mut [u8; SpMessage::MAX_SIZE], log: &Logger, -) -> Result<(SocketAddrV6, SpMessage), Error> { +) -> Result<(SocketAddrV6, SpMessage)> { let (n, peer) = socket .recv_from(&mut incoming_buf[..]) .await - .map_err(Error::UdpRecv)?; + .map_err(SpCommunicationError::UdpRecv)?; + + probes::recv_packet!(|| { + (peer, incoming_buf.as_ptr() as usize as u64, n as u64) + }); let peer = match peer { - std::net::SocketAddr::V6(addr) => addr, - std::net::SocketAddr::V4(_) => { + SocketAddr::V6(addr) => addr, + SocketAddr::V4(_) => { // We're exclusively using IPv6; we can't get a response from an // IPv4 peer. unreachable!() @@ -747,7 +811,7 @@ async fn recv( let (message, _n) = gateway_messages::deserialize::(&incoming_buf[..n]) - .map_err(|err| Error::Deserialize { peer, err })?; + .map_err(|err| SpCommunicationError::Deserialize { peer, err })?; trace!( log, "received message from SP"; @@ -771,3 +835,13 @@ fn sp_busy_policy() -> backoff::ExponentialBackoff { ..Default::default() } } + +#[usdt::provider(provider = "gateway_sp_comms")] +mod probes { + fn recv_packet( + _source: &SocketAddr, + _data: u64, // TODO actually a `*const u8`, but that isn't allowed by usdt + _len: u64, + ) { + } +} diff --git a/gateway/Cargo.toml b/gateway/Cargo.toml index 8f7506bc99..f610a25453 100644 --- a/gateway/Cargo.toml +++ b/gateway/Cargo.toml @@ -11,11 +11,11 @@ futures = "0.3.21" hex = "0.4" http = "0.2.7" hyper = "0.14.20" -ringbuffer = "0.8" schemars = "0.8" serde = { version = "1.0", features = ["derive"] } slog-dtrace = "0.2" thiserror = "1.0.32" +tokio-tungstenite = "0.17" toml = "0.5.9" uuid = "1.1.0" diff --git a/gateway/examples/config.toml b/gateway/examples/config.toml index 321455c321..ff53ef0230 100644 --- a/gateway/examples/config.toml +++ b/gateway/examples/config.toml @@ -10,6 +10,13 @@ id = "8afcb12d-f625-4df9-bdf2-f495c3bbd323" # our contact to the ignition controller) local_ignition_controller_port = 0 +# when sending UDP RPC packets to an SP, how many total attempts do we make +# before giving up? +rpc_max_attempts = 5 + +# sleep time between UDP RPC resends (up to `rpc_max_attempts`) +rpc_per_attempt_timeout_millis = 2000 + [switch.location] # possible locations where MGS could be running; these names appear in logs and # are used in the remainder of the `[switch.*]` configuration to define port @@ -68,8 +75,6 @@ switch1 = ["sled", 1] [timeouts] discovery_millis = 1_000 -ignition_controller_millis = 1_000 -sp_request_millis = 1_000 bulk_request_default_millis = 5_000 bulk_request_max_millis = 60_000 bulk_request_page_millis = 1_000 diff --git a/gateway/faux-mgs/src/main.rs b/gateway/faux-mgs/src/main.rs index 1ef8e7a5cd..e7b40b356a 100644 --- a/gateway/faux-mgs/src/main.rs +++ b/gateway/faux-mgs/src/main.rs @@ -9,8 +9,8 @@ use anyhow::Context; use anyhow::Result; use clap::Parser; use clap::Subcommand; -use gateway_sp_comms::single_sp::SingleSp; -use gateway_sp_comms::single_sp::DISCOVERY_MULTICAST_ADDR; +use gateway_sp_comms::SingleSp; +use gateway_sp_comms::DISCOVERY_MULTICAST_ADDR; use slog::info; use slog::o; use slog::Drain; diff --git a/gateway/faux-mgs/src/usart.rs b/gateway/faux-mgs/src/usart.rs index deb200e212..8eb68e2fe2 100644 --- a/gateway/faux-mgs/src/usart.rs +++ b/gateway/faux-mgs/src/usart.rs @@ -8,8 +8,8 @@ use anyhow::Context; use anyhow::Result; use gateway_messages::sp_impl::SerialConsolePacketizer; use gateway_messages::SpComponent; -use gateway_sp_comms::single_sp::AttachedSerialConsoleSend; -use gateway_sp_comms::single_sp::SingleSp; +use gateway_sp_comms::AttachedSerialConsoleSend; +use gateway_sp_comms::SingleSp; use slog::trace; use slog::Logger; use std::io; diff --git a/gateway/src/bin/gateway.rs b/gateway/src/bin/gateway.rs index 1fec74f63d..2a4b679fad 100644 --- a/gateway/src/bin/gateway.rs +++ b/gateway/src/bin/gateway.rs @@ -20,8 +20,12 @@ struct Args { )] openapi: bool, - #[clap(name = "CONFIG_FILE_PATH", action)] - config_file_path: PathBuf, + #[clap( + name = "CONFIG_FILE_PATH", + action, + required_unless_present = "openapi" + )] + config_file_path: Option, } #[tokio::main] @@ -34,12 +38,14 @@ async fn main() { async fn do_run() -> Result<(), CmdError> { let args = Args::parse(); - let config = Config::from_file(args.config_file_path) - .map_err(|e| CmdError::Failure(e.to_string()))?; - if args.openapi { run_openapi().map_err(CmdError::Failure) } else { + // `.unwrap()` here is fine because our clap config requires + // `config_file_path` to be passed if `openapi` is not. + let config = Config::from_file(args.config_file_path.unwrap()) + .map_err(|e| CmdError::Failure(e.to_string()))?; + run_server(config).await.map_err(CmdError::Failure) } } diff --git a/gateway/src/bulk_state_get.rs b/gateway/src/bulk_state_get.rs index 8f85aff94a..8462b9d3f0 100644 --- a/gateway/src/bulk_state_get.rs +++ b/gateway/src/bulk_state_get.rs @@ -154,7 +154,7 @@ impl BulkSpStateRequests { self.requests.lock().unwrap().insert(id, Arc::clone(&collector)); // query ignition controller to find out which SPs are powered on - let all_sps = self.communicator.get_ignition_state_all(timeout).await?; + let all_sps = self.communicator.get_ignition_state_all().await?; // build collection of futures to contact all SPs let communicator = Arc::clone(&self.communicator); @@ -163,7 +163,7 @@ impl BulkSpStateRequests { timeout, move |sp| { let communicator = Arc::clone(&communicator); - async move { communicator.get_state(sp, timeout).await } + async move { communicator.get_state(sp).await } }, ); diff --git a/gateway/src/config.rs b/gateway/src/config.rs index 34e225028e..496316e9cd 100644 --- a/gateway/src/config.rs +++ b/gateway/src/config.rs @@ -17,10 +17,6 @@ pub struct Timeouts { /// Timeout for running the discovery process to determine logical mappings /// of switches/sleds. pub discovery_millis: u64, - /// Timeout for messages to our local ignition controller SP. - pub ignition_controller_millis: u64, - /// Timeout for requests sent to arbitrary SPs. - pub sp_request_millis: u64, /// Default timeout for requests that collect responses from multiple /// targets, if the client doesn't provide one. pub bulk_request_default_millis: u64, diff --git a/gateway/src/context.rs b/gateway/src/context.rs index 66a31dda45..7b29aebf68 100644 --- a/gateway/src/context.rs +++ b/gateway/src/context.rs @@ -14,11 +14,10 @@ pub struct ServerContext { pub sp_comms: Arc, pub bulk_sp_state_requests: BulkSpStateRequests, pub timeouts: Timeouts, + pub log: Logger, } pub struct Timeouts { - pub ignition_controller: Duration, - pub sp_request: Duration, pub bulk_request_default: Duration, pub bulk_request_max: Duration, pub bulk_request_page: Duration, @@ -28,10 +27,6 @@ pub struct Timeouts { impl From<&'_ crate::config::Timeouts> for Timeouts { fn from(timeouts: &'_ crate::config::Timeouts) -> Self { Self { - ignition_controller: Duration::from_millis( - timeouts.ignition_controller_millis, - ), - sp_request: Duration::from_millis(timeouts.sp_request_millis), bulk_request_default: Duration::from_millis( timeouts.bulk_request_default_millis, ), @@ -63,6 +58,7 @@ impl ServerContext { sp_comms: Arc::clone(&comms), bulk_sp_state_requests: BulkSpStateRequests::new(comms, log), timeouts: Timeouts::from(&timeouts), + log: log.clone(), })) } } diff --git a/gateway/src/error.rs b/gateway/src/error.rs index 201c000f1f..d1db127f23 100644 --- a/gateway/src/error.rs +++ b/gateway/src/error.rs @@ -13,6 +13,8 @@ use gateway_sp_comms::error::Error as SpCommsError; pub(crate) enum Error { #[error("invalid page token ({0})")] InvalidPageToken(InvalidPageToken), + #[error("websocket connection failure: {0}")] + BadWebsocketConnection(&'static str), #[error(transparent)] CommunicationsError(#[from] SpCommsError), } @@ -33,6 +35,10 @@ impl From for HttpError { err.to_string(), ), Error::CommunicationsError(err) => http_err_from_comms_err(err), + Error::BadWebsocketConnection(_) => HttpError::for_bad_request( + Some("BadWebsocketConnection".to_string()), + err.to_string(), + ), } } } @@ -51,12 +57,9 @@ where Some("SerialConsoleAttached".to_string()), err.to_string(), ), - SpCommsError::BadWebsocketConnection(_) => HttpError::for_bad_request( - Some("BadWebsocketConnection".to_string()), - err.to_string(), - ), SpCommsError::SpAddressUnknown(_) | SpCommsError::Timeout { .. } + | SpCommsError::BadIgnitionTarget(_) | SpCommsError::LocalIgnitionControllerAddressUnknown | SpCommsError::SpCommunicationFailed(_) => { HttpError::for_internal_error(err.to_string()) diff --git a/gateway/src/http_entrypoints.rs b/gateway/src/http_entrypoints.rs index f7bba5b6e7..cad76a04cc 100644 --- a/gateway/src/http_entrypoints.rs +++ b/gateway/src/http_entrypoints.rs @@ -318,6 +318,7 @@ async fn sp_list( // need more refined errors? SpCommsError::Timeout { .. } | SpCommsError::SpCommunicationFailed(_) + | SpCommsError::BadIgnitionTarget(_) | SpCommsError::LocalIgnitionControllerAddressUnknown | SpCommsError::SpAddressUnknown(_) => { SpState::Unresponsive @@ -325,7 +326,6 @@ async fn sp_list( // These errors should not be possible for the request we // made. SpCommsError::SpDoesNotExist(_) - | SpCommsError::BadWebsocketConnection(_) | SpCommsError::SerialConsoleAttached => { unreachable!("impossible error {}", err) } @@ -350,9 +350,6 @@ async fn sp_list( } /// Get info on an SP -/// -/// As communication with SPs may be unreliable, consumers may specify an -/// optional timeout to override the default. #[endpoint { method = GET, path = "/sp/{type}/{slot}", @@ -360,34 +357,21 @@ async fn sp_list( async fn sp_get( rqctx: Arc>>, path: Path, - query: Query, ) -> Result, HttpError> { let apictx = rqctx.context(); let comms = &apictx.sp_comms; let sp = path.into_inner().sp; - // TODO should we construct this here or after our `ignition_get`? By - // putting it here, the time it takes us to query ignition counts against - // the client's timeout; that seems right but puts us in a bind if their - // timeout expires while we're still waiting for ignition. - let timeout = SpTimeout::from_now( - query - .into_inner() - .timeout_millis - .map(|n| Duration::from_millis(u64::from(n))) - .unwrap_or(apictx.timeouts.sp_request), - ); - // ping the ignition controller first; if it says the SP is off or otherwise // unavailable, we're done. let state = comms - .get_ignition_state(sp.into(), timeout) + .get_ignition_state(sp.into()) .await .map_err(http_err_from_comms_err)?; let details = if state.is_powered_on() { // ignition indicates the SP is on; ask it for its state - match comms.get_state(sp.into(), timeout).await { + match comms.get_state(sp.into()).await { Ok(state) => SpState::from(state), Err(SpCommsError::Timeout { .. }) => SpState::Unresponsive, Err(other) => return Err(http_err_from_comms_err(other)), @@ -463,16 +447,15 @@ async fn sp_component_serial_console_attach( let component = component_from_str(&component)?; let mut request = rqctx.request.lock().await; - apictx - .sp_comms - .serial_console_attach( - &mut request, - sp.into(), - component, - apictx.timeouts.sp_request, - ) - .await - .map_err(http_err_from_comms_err) + let sp = sp.into(); + Ok(crate::serial_console::attach( + &apictx.sp_comms, + sp, + component, + &mut request, + apictx.log.new(slog::o!("sp" => format!("{sp:?}"))), + ) + .await?) } /// Detach the websocket connection attached to the given SP component's serial @@ -486,12 +469,13 @@ async fn sp_component_serial_console_detach( path: Path, ) -> Result { let comms = &rqctx.context().sp_comms; - let PathSpComponent { sp, component } = path.into_inner(); - let component = component_from_str(&component)?; + // TODO-cleanup: "component" support for the serial console is half baked; + // we don't use it at all to detach. + let PathSpComponent { sp, component: _ } = path.into_inner(); comms - .serial_console_detach(sp.into(), &component) + .serial_console_detach(sp.into()) .await .map_err(http_err_from_comms_err)?; @@ -571,9 +555,7 @@ async fn ignition_list( let sp_comms = &apictx.sp_comms; let all_state = sp_comms - .get_ignition_state_all(SpTimeout::from_now( - apictx.timeouts.ignition_controller, - )) + .get_ignition_state_all() .await .map_err(http_err_from_comms_err)?; @@ -602,10 +584,7 @@ async fn ignition_get( let state = apictx .sp_comms - .get_ignition_state( - sp.into(), - SpTimeout::from_now(apictx.timeouts.ignition_controller), - ) + .get_ignition_state(sp.into()) .await .map_err(http_err_from_comms_err)?; @@ -627,11 +606,7 @@ async fn ignition_power_on( apictx .sp_comms - .send_ignition_command( - sp.into(), - IgnitionCommand::PowerOn, - SpTimeout::from_now(apictx.timeouts.ignition_controller), - ) + .send_ignition_command(sp.into(), IgnitionCommand::PowerOn) .await .map_err(http_err_from_comms_err)?; @@ -652,11 +627,7 @@ async fn ignition_power_off( apictx .sp_comms - .send_ignition_command( - sp.into(), - IgnitionCommand::PowerOff, - SpTimeout::from_now(apictx.timeouts.ignition_controller), - ) + .send_ignition_command(sp.into(), IgnitionCommand::PowerOff) .await .map_err(http_err_from_comms_err)?; diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index 008088a7c6..3c5835bea2 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -6,6 +6,7 @@ mod bulk_state_get; mod config; mod context; mod error; +mod serial_console; pub mod http_entrypoints; // TODO pub only for testing - is this right? diff --git a/gateway/src/serial_console.rs b/gateway/src/serial_console.rs new file mode 100644 index 0000000000..26a97c0cfb --- /dev/null +++ b/gateway/src/serial_console.rs @@ -0,0 +1,282 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +// Copyright 2022 Oxide Computer Company + +use crate::error::Error; +use futures::future::Fuse; +use futures::FutureExt; +use futures::SinkExt; +use futures::StreamExt; +use futures::TryFutureExt; +use gateway_messages::sp_impl::SerialConsolePacketizer; +use gateway_messages::SpComponent; +use gateway_sp_comms::AttachedSerialConsole; +use gateway_sp_comms::AttachedSerialConsoleSend; +use gateway_sp_comms::Communicator; +use gateway_sp_comms::SpIdentifier; +use http::header; +use hyper::upgrade; +use hyper::upgrade::Upgraded; +use hyper::Body; +use slog::debug; +use slog::error; +use slog::info; +use slog::Logger; +use std::borrow::Cow; +use std::collections::VecDeque; +use std::ops::Deref; +use tokio_tungstenite::tungstenite::handshake; +use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; +use tokio_tungstenite::tungstenite::protocol::CloseFrame; +use tokio_tungstenite::tungstenite::protocol::Role; +use tokio_tungstenite::tungstenite::protocol::WebSocketConfig; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::WebSocketStream; + +pub(crate) async fn attach( + sp_comms: &Communicator, + sp: SpIdentifier, + component: SpComponent, + request: &mut http::Request, + log: Logger, +) -> Result, Error> { + if !request + .headers() + .get(header::CONNECTION) + .and_then(|hv| hv.to_str().ok()) + .map(|hv| { + hv.split(|c| c == ',' || c == ' ') + .any(|vs| vs.eq_ignore_ascii_case("upgrade")) + }) + .unwrap_or(false) + { + return Err(Error::BadWebsocketConnection( + "expected connection upgrade", + )); + } + if !request + .headers() + .get(header::UPGRADE) + .and_then(|v| v.to_str().ok()) + .map(|v| { + v.split(|c| c == ',' || c == ' ') + .any(|v| v.eq_ignore_ascii_case("websocket")) + }) + .unwrap_or(false) + { + return Err(Error::BadWebsocketConnection( + "unexpected protocol for upgrade", + )); + } + if request + .headers() + .get(header::SEC_WEBSOCKET_VERSION) + .map(|v| v.as_bytes()) + != Some(b"13") + { + return Err(Error::BadWebsocketConnection( + "missing or invalid websocket version", + )); + } + let accept_key = request + .headers() + .get(header::SEC_WEBSOCKET_KEY) + .map(|hv| hv.as_bytes()) + .map(|key| handshake::derive_accept_key(key)) + .ok_or(Error::BadWebsocketConnection("missing websocket key"))?; + + let console = sp_comms.serial_console_attach(sp).await?; + let upgrade_fut = upgrade::on(request); + tokio::spawn(async move { + let upgraded = match upgrade_fut.await { + Ok(u) => u, + Err(e) => { + error!(log, "serial task failed"; "err" => %e); + return; + } + }; + let config = WebSocketConfig { + max_send_queue: Some(4096), + ..Default::default() + }; + let ws_stream = WebSocketStream::from_raw_socket( + upgraded, + Role::Server, + Some(config), + ) + .await; + + let task = SerialConsoleTask { console, component, ws_stream }; + match task.run(&log).await { + Ok(()) => debug!(log, "serial task complete"), + Err(e) => { + error!(log, "serial task failed"; "err" => %e) + } + } + }); + + // `.body()` only fails if our headers are bad, which they aren't + // (unless `hyper::handshake` gives us a bogus accept key?), so we're + // safe to unwrap this + Ok(http::Response::builder() + .status(http::StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, "Upgrade") + .header(header::UPGRADE, "websocket") + .header(header::SEC_WEBSOCKET_ACCEPT, accept_key) + .body(Body::empty()) + .unwrap()) +} + +#[derive(Debug, thiserror::Error)] +enum SerialTaskError { + #[error(transparent)] + Error(#[from] Error), + #[error(transparent)] + TungsteniteError(#[from] tokio_tungstenite::tungstenite::Error), +} + +struct SerialConsoleTask { + console: AttachedSerialConsole, + component: SpComponent, + ws_stream: WebSocketStream, +} + +impl SerialConsoleTask { + async fn run(self, log: &Logger) -> Result<(), SerialTaskError> { + let (mut ws_sink, mut ws_stream) = self.ws_stream.split(); + let (console_tx, mut console_rx) = self.console.split(); + let console_tx = DetachOnDrop::new(console_tx); + + // TODO Currently we do not apply any backpressure on the SP and are + // willing to buffer up an arbitrary amount of data in memory. Is it + // reasonable to apply backpressure to the SP over UDP? Should we have + // caps on memory and start discarding data if we exceed them? We _do_ + // apply backpressure to the websocket, delaying reading from it if we + // still have data waiting to be sent to the SP. + let mut data_from_sp: VecDeque> = VecDeque::new(); + let mut data_to_sp: Vec = Vec::new(); + let mut packetizer_to_sp = SerialConsolePacketizer::new(self.component); + + loop { + let ws_send = if let Some(data) = data_from_sp.pop_front() { + ws_sink.send(Message::Binary(data)).fuse() + } else { + Fuse::terminated() + }; + + let ws_recv; + let sp_send; + if data_to_sp.is_empty() { + sp_send = Fuse::terminated(); + ws_recv = ws_stream.next().fuse(); + } else { + ws_recv = Fuse::terminated(); + + let (packet, _remaining) = + packetizer_to_sp.first_packet(data_to_sp.as_slice()); + let packet_data_len = usize::from(packet.len); + + sp_send = console_tx + .write(packet) + .map_ok(move |()| packet_data_len) + .fuse(); + } + + tokio::select! { + // Send a UDP packet to the SP + send_success = sp_send => { + let n = send_success + .map_err(gateway_sp_comms::error::Error::from) + .map_err(Error::from)?; + data_to_sp.drain(..n); + } + + // Receive a UDP packet from the SP. + packet = console_rx.recv() => { + match packet.as_ref() { + Some(packet) => { + let data = &packet.data[..usize::from(packet.len)]; + data_from_sp.push_back(data.to_vec()); + } + None => { + // Sender is closed; i.e., we've been detached. + // Close the websocket. + info!(log, "detaching from serial console"); + let close = CloseFrame { + code: CloseCode::Policy, + reason: Cow::Borrowed("serial console was detached"), + }; + ws_sink.send(Message::Close(Some(close))).await?; + return Ok(()); + } + } + } + + // Send a previously-received UDP packet of data to the websocket + // client + write_success = ws_send => { + write_success?; + } + + // Receive from the websocket to send to the SP. + msg = ws_recv => { + match msg { + Some(Ok(Message::Binary(mut data))) => { + // we only populate ws_recv when we have no data + // currently queued up; sanity check that here + assert!(data_to_sp.is_empty()); + data_to_sp.append(&mut data); + } + Some(Ok(Message::Close(_))) | None => { + info!( + log, + "remote end closed websocket; terminating task", + ); + return Ok(()); + } + Some(other) => { + let wrong_message = other?; + error!( + log, + "bogus websocket message; terminating task"; + "message" => ?wrong_message, + ); + return Ok(()); + } + } + } + } + } + } +} + +struct DetachOnDrop(Option); + +impl DetachOnDrop { + fn new(console: AttachedSerialConsoleSend) -> Self { + Self(Some(console)) + } +} + +impl Drop for DetachOnDrop { + fn drop(&mut self) { + // We can't `.await` within `drop()`, so we'll spawn a task to detach + // the console. `detach()` only does anything if the current connection + // is still attached, so it's fine if this runs after a new connection + // has been attached (at which point it won't do anything). + let console = self.0.take().unwrap(); + tokio::spawn(async move { console.detach().await }); + } +} + +impl Deref for DetachOnDrop { + type Target = AttachedSerialConsoleSend; + + fn deref(&self) -> &Self::Target { + // We know from `new()` that we're created with `Some(console)`, and we + // don't remove it until our `Drop` impl + self.0.as_ref().unwrap() + } +} diff --git a/gateway/tests/config.test.toml b/gateway/tests/config.test.toml index 11de801d20..56e6f7adbd 100644 --- a/gateway/tests/config.test.toml +++ b/gateway/tests/config.test.toml @@ -11,6 +11,13 @@ id = "8afcb12d-f625-4df9-bdf2-f495c3bbd323" # our contact to the ignition controller) local_ignition_controller_port = 0 +# when sending UDP RPC packets to an SP, how many total attempts do we make +# before giving up? +rpc_max_attempts = 5 + +# sleep time between UDP RPC resends (up to `rpc_max_attempts`) +rpc_per_attempt_timeout_millis = 5000 + [switch.location] # possible locations where MGS could be running; these names appear in logs and # are used in the remainder of the `[switch.*]` configuration to define port @@ -76,8 +83,6 @@ switch1 = ["sled", 1] [timeouts] discovery_millis = 10_000 -ignition_controller_millis = 10_000 -sp_request_millis = 10_000 bulk_request_default_millis = 10_000 bulk_request_max_millis = 40_000 bulk_request_page_millis = 2_000 diff --git a/openapi/gateway.json b/openapi/gateway.json index 5855993315..1de9f40436 100644 --- a/openapi/gateway.json +++ b/openapi/gateway.json @@ -150,7 +150,6 @@ "/sp/{type}/{slot}": { "get": { "summary": "Get info on an SP", - "description": "As communication with SPs may be unreliable, consumers may specify an optional timeout to override the default.", "operationId": "sp_get", "parameters": [ { @@ -172,17 +171,6 @@ "$ref": "#/components/schemas/SpType" }, "style": "simple" - }, - { - "in": "query", - "name": "timeout_millis", - "schema": { - "nullable": true, - "type": "integer", - "format": "uint32", - "minimum": 0 - }, - "style": "form" } ], "responses": {