diff --git a/Cargo.lock b/Cargo.lock index 45f8fa87c8..bf23617244 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1167,6 +1167,7 @@ dependencies = [ "futures", "gateway-messages", "omicron-test-utils", + "once_cell", "ringbuffer", "serde", "slog", diff --git a/gateway-messages/src/lib.rs b/gateway-messages/src/lib.rs index 688994a64e..e4153f9751 100644 --- a/gateway-messages/src/lib.rs +++ b/gateway-messages/src/lib.rs @@ -129,6 +129,12 @@ pub struct IgnitionState { pub flags: IgnitionFlags, } +impl IgnitionState { + pub fn is_powered_on(self) -> bool { + self.flags.intersects(IgnitionFlags::POWER) + } +} + bitflags! { #[derive(Default, SerializedSize, Serialize, Deserialize)] pub struct IgnitionFlags: u8 { @@ -305,6 +311,7 @@ impl fmt::Debug for SpComponent { /// Error type returned from `TryFrom<&str> for SpComponent` if the provided ID /// is too long. +#[derive(Debug)] pub struct SpComponentIdTooLong; impl TryFrom<&str> for SpComponent { diff --git a/gateway-sp-comms/Cargo.toml b/gateway-sp-comms/Cargo.toml index 114db70570..494fb9dae1 100644 --- a/gateway-sp-comms/Cargo.toml +++ b/gateway-sp-comms/Cargo.toml @@ -22,4 +22,5 @@ version = "1.16" features = [ "full" ] [dev-dependencies] +once_cell = "1.9" omicron-test-utils = { path = "../test-utils" } diff --git a/gateway-sp-comms/src/communicator.rs b/gateway-sp-comms/src/communicator.rs new file mode 100644 index 0000000000..4ad030d925 --- /dev/null +++ b/gateway-sp-comms/src/communicator.rs @@ -0,0 +1,462 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +// Copyright 2022 Oxide Computer Company + +use crate::error::Error; +use crate::error::SpCommunicationError; +use crate::error::StartupError; +use crate::management_switch::ManagementSwitch; +use crate::management_switch::ManagementSwitchDiscovery; +use crate::management_switch::SpSocket; +use crate::management_switch::SwitchPort; +use crate::recv_handler::RecvHandler; +use crate::KnownSps; +use crate::SerialConsoleContents; +use crate::SpIdentifier; +use futures::stream::FuturesUnordered; +use futures::Future; +use futures::Stream; +use gateway_messages::sp_impl::SerialConsolePacketizer; +use gateway_messages::version; +use gateway_messages::BulkIgnitionState; +use gateway_messages::IgnitionCommand; +use gateway_messages::IgnitionState; +use gateway_messages::Request; +use gateway_messages::RequestKind; +use gateway_messages::ResponseKind; +use gateway_messages::SerializedSize; +use gateway_messages::SpComponent; +use gateway_messages::SpState; +use slog::debug; +use slog::info; +use slog::o; +use slog::Logger; +use std::sync::atomic::AtomicU32; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use tokio::time::Instant; + +/// Helper trait that allows us to return an `impl FuturesUnordered<_>` where +/// the caller can call `.is_empty()` without knowing the type of the future +/// inside the collection. 
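// Illustrative sketch of why this trait exists (not part of the patch; names
// below are hypothetical): a helper can build a `FuturesUnordered` internally
// and return it opaquely, while the caller can still ask whether anything is
// left to poll, which a bare `impl Stream` would not allow.
//
//     fn pending_work(n: u32) -> impl FuturesUnorderedImpl<Item = u32> {
//         (0..n).map(|i| async move { i }).collect::<FuturesUnordered<_>>()
//     }
//
//     // a caller (with `futures::StreamExt` in scope) can then do:
//     //     while let Some(v) = stream.next().await {
//     //         if stream.is_empty() { /* last item just finished */ }
//     //     }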
+pub trait FuturesUnorderedImpl: Stream + Unpin { + fn is_empty(&self) -> bool; +} + +impl FuturesUnorderedImpl for FuturesUnordered +where + Fut: Future, +{ + fn is_empty(&self) -> bool { + self.is_empty() + } +} + +#[derive(Debug)] +pub struct Communicator { + log: Logger, + switch: ManagementSwitch, + request_id: AtomicU32, + recv_handler: Arc, +} + +impl Communicator { + pub async fn new( + known_sps: KnownSps, + log: &Logger, + ) -> Result { + let log = log.new(o!("component" => "SpCommunicator")); + let discovery = ManagementSwitchDiscovery::placeholder_start( + known_sps, + log.clone(), + ) + .await?; + + let (switch, recv_handler) = RecvHandler::new(discovery, log.clone()); + + info!(&log, "started SP communicator"); + Ok(Self { log, switch, request_id: AtomicU32::new(0), recv_handler }) + } + + // convert an identifier to a port number; this is fallible because + // identifiers can be constructed arbiatrarily, in contrast to `port_to_id` + // below. + fn id_to_port(&self, sp: SpIdentifier) -> Result { + self.switch.switch_port(sp).ok_or(Error::SpDoesNotExist(sp)) + } + + // convert a port to an identifier; this is infallible because we construct + // `SwitchPort`s and know they map to valid IDs + fn port_to_id(&self, port: SwitchPort) -> SpIdentifier { + self.switch.switch_port_to_id(port) + } + + /// Ask the local ignition controller for the ignition state of a given SP. + pub async fn get_ignition_state( + &self, + sp: SpIdentifier, + timeout: Instant, + ) -> Result { + let controller = self.switch.ignition_controller(); + let port = self.id_to_port(sp)?; + let request = + RequestKind::IgnitionState { target: port.as_ignition_target() }; + + self.request_response( + &controller, + request, + Some(timeout), + ResponseKindExt::try_into_ignition_state, + ) + .await + } + + /// Ask the local ignition controller for the ignition state of all SPs. + pub async fn get_ignition_state_all( + &self, + timeout: Instant, + ) -> Result, Error> { + let controller = self.switch.ignition_controller(); + let request = RequestKind::BulkIgnitionState; + + let bulk_state = self + .request_response( + &controller, + request, + Some(timeout), + ResponseKindExt::try_into_bulk_ignition_state, + ) + .await?; + + // deserializing checks that `num_targets` is reasonably sized, so we + // don't need to guard that here + let targets = + &bulk_state.targets[..usize::from(bulk_state.num_targets)]; + + // map ignition target indices back to `SpIdentifier`s for our caller + targets + .iter() + .copied() + .enumerate() + .map(|(target, state)| { + let port = self + .switch + .switch_port_from_ignition_target(target) + .ok_or(SpCommunicationError::BadIgnitionTarget(target))?; + let id = self.port_to_id(port); + Ok((id, state)) + }) + .collect() + } + + /// Instruct the local ignition controller to perform the given `command` on + /// `target_sp`. + pub async fn send_ignition_command( + &self, + target_sp: SpIdentifier, + command: IgnitionCommand, + timeout: Instant, + ) -> Result<(), Error> { + let controller = self.switch.ignition_controller(); + let target = self.id_to_port(target_sp)?.as_ignition_target(); + let request = RequestKind::IgnitionCommand { target, command }; + + self.request_response( + &controller, + request, + Some(timeout), + ResponseKindExt::try_into_ignition_command_ack, + ) + .await + } + + /// Get our current serial console contents for the given SP component. 
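    // Illustrative caller sketch (not part of the patch): `Ok(None)` means the
    // SP identifier was valid but we have not yet buffered any console data
    // for that component.
    //
    //     match communicator.serial_console_contents(sp, &component)? {
    //         Some(contents) => { /* contents.start, contents.chunks */ }
    //         None => { /* nothing received from this component yet */ }
    //     }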
+ pub fn serial_console_contents( + &self, + sp: SpIdentifier, + component: &SpComponent, + ) -> Result, Error> { + let port = self.id_to_port(sp)?; + Ok(self.recv_handler.serial_console_contents(port, component)) + } + + /// Send `data` to the given SP component's serial console. + pub async fn serial_console_send( + &self, + sp: SpIdentifier, + component: &SpComponent, + data: &[u8], + timeout: Instant, + ) -> Result<(), Error> { + let port = self.id_to_port(sp)?; + let sp = + self.switch.sp_socket(port).ok_or(Error::SpAddressUnknown(sp))?; + + // TODO how do we handle multiple serial console sends to the same SP at + // the same time? in a previous iteration of this code I had a mutex + // here, but that only protects against multiple posts going through the + // same MGS instance - we have two, and local mutexes won't help if both + // are trying to send to the same SP simultaneously. maybe we need some + // kind of time-limited lock from the SP side where it discards any + // incoming data from other sources? + let mut packetizer = SerialConsolePacketizer::new(*component); + for packet in packetizer.packetize(data) { + self.request_response( + &sp, + RequestKind::SerialConsoleWrite(packet), + Some(timeout), + ResponseKindExt::try_into_serial_console_write_ack, + ) + .await?; + } + + Ok(()) + } + + /// Get the state of a given SP. + pub async fn get_state( + &self, + sp: SpIdentifier, + timeout: Instant, + ) -> Result { + self.get_state_maybe_timeout(sp, Some(timeout)).await + } + + /// Get the state of a given SP without a timeout; it is the caller's + /// responsibility to ensure a reasonable timeout is applied higher up in + /// the chain. + // TODO we could have one method that takes `Option` for a timeout, + // and/or apply that to _all_ the methods in this class. I don't want to + // make it easy to accidentally call a method without providing a timeout, + // though, so went with the current design. + pub async fn get_state_without_timeout( + &self, + sp: SpIdentifier, + ) -> Result { + self.get_state_maybe_timeout(sp, None).await + } + + async fn get_state_maybe_timeout( + &self, + sp: SpIdentifier, + timeout: Option, + ) -> Result { + let port = self.id_to_port(sp)?; + let sp = + self.switch.sp_socket(port).ok_or(Error::SpAddressUnknown(sp))?; + let request = RequestKind::SpState; + + self.request_response( + &sp, + request, + timeout, + ResponseKindExt::try_into_sp_state, + ) + .await + } + + /// Query all online SPs. + /// + /// `ignition_state` should be the state returned by a (recent) call to + /// [`get_ignition_state_all()`]. + /// + /// All SPs included in `ignition_state` will be yielded by the returned + /// stream. The order in which they are yielded is undefined; the offline + /// SPs are likely to be first, but even that is not guaranteed. The item + /// yielded by offline SPs will be `None`; the item yielded by online SPs + /// will be `Some(Ok(_))` if the future returned by `f` for that item + /// completed before `timeout` or `Some(Err(_))` if not. + /// + /// Note that the timeout is be applied to each _element_ of the returned + /// stream rather than the stream as a whole, allowing easy access to which + /// SPs timed out based on the yielded value associated with those SPs. 
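    // Illustrative usage sketch (not part of the patch; assumes `communicator`,
    // a recent `ignition_state` vector, and a `deadline: Instant` are in
    // scope, and that `futures::StreamExt` is imported to drive the stream):
    //
    //     let mut stream = communicator.query_all_online_sps(
    //         &ignition_state,
    //         deadline,
    //         |sp| communicator.get_state_without_timeout(sp),
    //     );
    //     while let Some((sp, ignition, result)) = stream.next().await {
    //         match result {
    //             None => { /* ignition reports this SP as powered off */ }
    //             Some(Err(_elapsed)) => { /* this SP hit the per-SP deadline */ }
    //             Some(Ok(state)) => { /* whatever the closure returned */ }
    //         }
    //     }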
+ pub fn query_all_online_sps( + &self, + ignition_state: &[(SpIdentifier, IgnitionState)], + timeout: Instant, + f: F, + ) -> impl FuturesUnorderedImpl< + Item = ( + SpIdentifier, + IgnitionState, + Option>, + ), + > + where + F: FnMut(SpIdentifier) -> Fut + Clone, + Fut: Future, + { + ignition_state + .iter() + .copied() + .map(move |(id, state)| { + let mut f = f.clone(); + async move { + let val = if state.is_powered_on() { + Some(tokio::time::timeout_at(timeout, f(id)).await) + } else { + None + }; + (id, state, val) + } + }) + .collect::>() + } + + async fn request_response( + &self, + sp: &SpSocket<'_>, + request: RequestKind, + timeout: Option, + map_response_kind: F, + ) -> Result + where + F: FnOnce(ResponseKind) -> Result, + { + // request IDs will eventually roll over; since we enforce timeouts + // this should be a non-issue in practice. does this need testing? + let request_id = self.request_id.fetch_add(1, Ordering::Relaxed); + + // update our recv_handler to expect a response for this request ID + let response_fut = + self.recv_handler.register_request_id(sp.port(), request_id); + + // Serialize and send our request. We know `buf` is large enough for any + // `Request`, so unwrapping here is fine. + let request = + Request { version: version::V1, request_id, kind: request }; + let mut buf = [0; Request::MAX_SIZE]; + let n = gateway_messages::serialize(&mut buf, &request).unwrap(); + + let serialized_request = &buf[..n]; + + let fut = async { + debug!(&self.log, "sending {:?} to SP {:?}", request, sp); + sp.send(serialized_request).await.map_err(|err| { + SpCommunicationError::UdpSend { addr: sp.addr(), err } + })?; + + // confirm we can convert the response into the expected type + let response = response_fut.await?; + map_response_kind(response) + }; + + let result = match timeout { + Some(t) => tokio::time::timeout_at(t, fut).await?, + None => fut.await, + }; + + Ok(result?) + } +} + +// When we send a request we expect a specific kind of response; the boilerplate +// for confirming that is a little noisy, so it lives in this extension trait. 
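// For example (illustrative sketch, not part of the patch), converting a
// response to the wrong kind yields a typed error rather than a panic:
//
//     let err = ResponseKind::Pong.try_into_sp_state().unwrap_err();
//     assert!(matches!(err, SpCommunicationError::BadResponseType { .. }));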
+trait ResponseKindExt {
+    fn name(&self) -> &'static str;
+
+    fn try_into_ignition_state(
+        self,
+    ) -> Result<IgnitionState, SpCommunicationError>;
+
+    fn try_into_bulk_ignition_state(
+        self,
+    ) -> Result<BulkIgnitionState, SpCommunicationError>;
+
+    fn try_into_ignition_command_ack(self) -> Result<(), SpCommunicationError>;
+
+    fn try_into_sp_state(self) -> Result<SpState, SpCommunicationError>;
+
+    fn try_into_serial_console_write_ack(
+        self,
+    ) -> Result<(), SpCommunicationError>;
+}
+
+impl ResponseKindExt for ResponseKind {
+    fn name(&self) -> &'static str {
+        match self {
+            ResponseKind::Pong => response_kind_names::PONG,
+            ResponseKind::IgnitionState(_) => {
+                response_kind_names::IGNITION_STATE
+            }
+            ResponseKind::BulkIgnitionState(_) => {
+                response_kind_names::BULK_IGNITION_STATE
+            }
+            ResponseKind::IgnitionCommandAck => {
+                response_kind_names::IGNITION_COMMAND_ACK
+            }
+            ResponseKind::SpState(_) => response_kind_names::SP_STATE,
+            ResponseKind::SerialConsoleWriteAck => {
+                response_kind_names::SERIAL_CONSOLE_WRITE_ACK
+            }
+        }
+    }
+
+    fn try_into_ignition_state(
+        self,
+    ) -> Result<IgnitionState, SpCommunicationError> {
+        match self {
+            ResponseKind::IgnitionState(state) => Ok(state),
+            other => Err(SpCommunicationError::BadResponseType {
+                expected: response_kind_names::IGNITION_STATE,
+                got: other.name(),
+            }),
+        }
+    }
+
+    fn try_into_bulk_ignition_state(
+        self,
+    ) -> Result<BulkIgnitionState, SpCommunicationError> {
+        match self {
+            ResponseKind::BulkIgnitionState(state) => Ok(state),
+            other => Err(SpCommunicationError::BadResponseType {
+                expected: response_kind_names::BULK_IGNITION_STATE,
+                got: other.name(),
+            }),
+        }
+    }
+
+    fn try_into_ignition_command_ack(self) -> Result<(), SpCommunicationError> {
+        match self {
+            ResponseKind::IgnitionCommandAck => Ok(()),
+            other => Err(SpCommunicationError::BadResponseType {
+                expected: response_kind_names::IGNITION_COMMAND_ACK,
+                got: other.name(),
+            }),
+        }
+    }
+
+    fn try_into_sp_state(self) -> Result<SpState, SpCommunicationError> {
+        match self {
+            ResponseKind::SpState(state) => Ok(state),
+            other => Err(SpCommunicationError::BadResponseType {
+                expected: response_kind_names::SP_STATE,
+                got: other.name(),
+            }),
+        }
+    }
+
+    fn try_into_serial_console_write_ack(
+        self,
+    ) -> Result<(), SpCommunicationError> {
+        match self {
+            ResponseKind::SerialConsoleWriteAck => Ok(()),
+            other => Err(SpCommunicationError::BadResponseType {
+                expected: response_kind_names::SERIAL_CONSOLE_WRITE_ACK,
+                got: other.name(),
+            }),
+        }
+    }
+}
+
+mod response_kind_names {
+    pub(super) const PONG: &str = "pong";
+    pub(super) const IGNITION_STATE: &str = "ignition_state";
+    pub(super) const BULK_IGNITION_STATE: &str = "bulk_ignition_state";
+    pub(super) const IGNITION_COMMAND_ACK: &str = "ignition_command_ack";
+    pub(super) const SP_STATE: &str = "sp_state";
+    pub(super) const SERIAL_CONSOLE_WRITE_ACK: &str =
+        "serial_console_write_ack";
+}
diff --git a/gateway-sp-comms/src/error.rs b/gateway-sp-comms/src/error.rs
index 14d5ec8b95..2adedf690c 100644
--- a/gateway-sp-comms/src/error.rs
+++ b/gateway-sp-comms/src/error.rs
@@ -4,6 +4,8 @@
 
 // Copyright 2022 Oxide Computer Company
 
+use crate::SpIdentifier;
+use gateway_messages::ResponseError;
 use std::io;
 use std::net::SocketAddr;
 use thiserror::Error;
@@ -13,3 +15,37 @@ pub enum StartupError {
     #[error("error binding to UDP address {addr}: {err}")]
     UdpBind { addr: SocketAddr, err: io::Error },
 }
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("nonexistent SP (type {:?}, slot {})", .0.typ, .0.slot)]
+    SpDoesNotExist(SpIdentifier),
+    #[error(
+        "unknown socket address for SP (type {:?}, slot {})",
+        .0.typ,
+        .0.slot,
+    )]
+    SpAddressUnknown(SpIdentifier),
+    #[error("timeout elapsed")]
+    Timeout,
+    #[error("error
communicating with SP: {0}")] + SpCommunicationFailed(#[from] SpCommunicationError), +} + +#[derive(Debug, Error)] +pub enum SpCommunicationError { + #[error("failed to send UDP packet to {addr}: {err}")] + UdpSend { addr: SocketAddr, err: io::Error }, + #[error("error reported by SP: {0}")] + SpError(#[from] ResponseError), + #[error("bogus SP response type: expected {expected:?} but got {got:?}")] + BadResponseType { expected: &'static str, got: &'static str }, + #[error("bogus SP response: specified unknown ignition target {0}")] + BadIgnitionTarget(usize), +} + +impl From for Error { + fn from(_: tokio::time::error::Elapsed) -> Self { + Self::Timeout + } +} diff --git a/gateway-sp-comms/src/lib.rs b/gateway-sp-comms/src/lib.rs index 22b004002c..c878a8f9f6 100644 --- a/gateway-sp-comms/src/lib.rs +++ b/gateway-sp-comms/src/lib.rs @@ -4,18 +4,23 @@ // Copyright 2022 Oxide Computer Company -pub mod error; +//! This crate provides UDP-based communication across the Oxide management +//! switch to a collection of SPs. +//! +//! The primary entry point is [`Communicator`]. + +mod communicator; mod management_switch; +mod recv_handler; +pub mod error; + +pub use communicator::Communicator; +pub use communicator::FuturesUnorderedImpl; pub use management_switch::SpIdentifier; pub use management_switch::SpType; - -// TODO the following should probably not be pub once this crate is more -// complete; for now make them pub so `gateway` can use them directly -pub use management_switch::ManagementSwitch; -pub use management_switch::ManagementSwitchDiscovery; -pub use management_switch::SpSocket; -pub use management_switch::SwitchPort; +pub use recv_handler::SerialConsoleChunk; +pub use recv_handler::SerialConsoleContents; // TODO these will remain public for a while, but eventually will be removed // altogther; currently these provide a way to hard-code the rack topology, diff --git a/gateway-sp-comms/src/management_switch.rs b/gateway-sp-comms/src/management_switch.rs index afc5772c1f..c1ca71e1ea 100644 --- a/gateway-sp-comms/src/management_switch.rs +++ b/gateway-sp-comms/src/management_switch.rs @@ -48,13 +48,13 @@ pub struct KnownSps { pub power_controllers: Vec, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SpIdentifier { pub typ: SpType, pub slot: usize, } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SpType { Switch, Sled, @@ -62,10 +62,10 @@ pub enum SpType { } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct SwitchPort(usize); +pub(crate) struct SwitchPort(usize); impl SwitchPort { - pub fn as_ignition_target(self) -> u8 { + pub(crate) fn as_ignition_target(self) -> u8 { // TODO should we use a u16 to describe ignition targets instead? rack // v1 is limited to 36, unclear what ignition will look like in future // products @@ -78,7 +78,7 @@ impl SwitchPort { } #[derive(Debug)] -pub struct ManagementSwitchDiscovery { +pub(crate) struct ManagementSwitchDiscovery { inner: Arc, } @@ -94,7 +94,7 @@ impl ManagementSwitchDiscovery { // ``` // // and any leftover bind addresses are ignored. - pub async fn placeholder_start( + pub(crate) async fn placeholder_start( known_sps: KnownSps, log: Logger, ) -> Result { @@ -126,7 +126,7 @@ impl ManagementSwitchDiscovery { // // TODO 2 should we attach the SP type to each port? For now just return a // flat list. 
- pub fn all_ports( + pub(crate) fn all_ports( &self, ) -> impl ExactSizeIterator + 'static { (0..self.inner.ports.len()).map(SwitchPort) @@ -134,7 +134,7 @@ impl ManagementSwitchDiscovery { /// Consume `self` and start a long-running task to receive packets on all /// ports, calling `recv_callback` for each. - pub fn start_recv_task(self, recv_callback: F) -> ManagementSwitch + pub(crate) fn start_recv_task(self, recv_callback: F) -> ManagementSwitch where F: Fn(SwitchPort, &'_ [u8]) + Send + 'static, { @@ -147,7 +147,7 @@ impl ManagementSwitchDiscovery { } #[derive(Debug)] -pub struct ManagementSwitch { +pub(crate) struct ManagementSwitch { inner: Arc, // handle to the running task that calls recv on all `switch_ports` sockets; @@ -163,7 +163,7 @@ impl Drop for ManagementSwitch { impl ManagementSwitch { /// Get the socket connected to the local ignition controller. - pub fn ignition_controller(&self) -> SpSocket { + pub(crate) fn ignition_controller(&self) -> SpSocket { // TODO for now this is guaranteed to exist based on the assertions in // `placeholder_start`; once that's replaced by a non-placeholder // implementation, revisit this. @@ -171,7 +171,7 @@ impl ManagementSwitch { self.inner.sp_socket(port).unwrap() } - pub fn switch_port_from_ignition_target( + pub(crate) fn switch_port_from_ignition_target( &self, target: usize, ) -> Option { @@ -184,15 +184,15 @@ impl ManagementSwitch { } } - pub fn switch_port(&self, id: SpIdentifier) -> Option { + pub(crate) fn switch_port(&self, id: SpIdentifier) -> Option { self.inner.switch_port(id.typ, id.slot) } - pub fn switch_port_to_id(&self, port: SwitchPort) -> SpIdentifier { + pub(crate) fn switch_port_to_id(&self, port: SwitchPort) -> SpIdentifier { self.inner.port_to_id(port) } - pub fn sp_socket(&self, port: SwitchPort) -> Option> { + pub(crate) fn sp_socket(&self, port: SwitchPort) -> Option> { self.inner.sp_socket(port) } } @@ -200,24 +200,24 @@ impl ManagementSwitch { /// Wrapper for a UDP socket on one of our switch ports that knows the address /// of the SP connected to this port. #[derive(Debug)] -pub struct SpSocket<'a> { +pub(crate) struct SpSocket<'a> { socket: &'a UdpSocket, addr: SocketAddr, port: SwitchPort, } impl SpSocket<'_> { - pub fn addr(&self) -> SocketAddr { + pub(crate) fn addr(&self) -> SocketAddr { self.addr } - pub fn port(&self) -> SwitchPort { + pub(crate) fn port(&self) -> SwitchPort { self.port } /// Wrapper around `send_to` that uses the SP address stored in `self` as /// the destination address. - pub async fn send(&self, buf: &[u8]) -> io::Result { + pub(crate) async fn send(&self, buf: &[u8]) -> io::Result { self.socket.send_to(buf, self.addr).await } } diff --git a/gateway-sp-comms/src/recv_handler/mod.rs b/gateway-sp-comms/src/recv_handler/mod.rs new file mode 100644 index 0000000000..fd365e4a6a --- /dev/null +++ b/gateway-sp-comms/src/recv_handler/mod.rs @@ -0,0 +1,200 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. 
+ +// Copyright 2022 Oxide Computer Company + +use crate::management_switch::ManagementSwitch; +use crate::management_switch::ManagementSwitchDiscovery; +use crate::management_switch::SwitchPort; +use gateway_messages::version; +use gateway_messages::ResponseError; +use gateway_messages::ResponseKind; +use gateway_messages::SerialConsole; +use gateway_messages::SpComponent; +use gateway_messages::SpMessage; +use gateway_messages::SpMessageKind; +use slog::debug; +use slog::error; +use slog::trace; +use slog::Logger; +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::Mutex; + +mod request_response_map; +mod serial_console_history; + +pub use self::serial_console_history::SerialConsoleChunk; +pub use self::serial_console_history::SerialConsoleContents; + +use self::request_response_map::RequestResponseMap; +use self::request_response_map::ResponseIngestResult; +use self::serial_console_history::SerialConsoleHistory; + +/// Handler for incoming packets received on management switch ports. +/// +/// When [`RecvHandler::new()`] is created, it starts an indefinite tokio task +/// that calls [`RecvHandler::handle_incoming_packet()`] for each UDP packet +/// received. Those packets come in two flavors: responses to requests that we +/// sent to an SP, and unprompted messages generated by an SP. The flow for +/// request / response is: +/// +/// 1. [`RecvHandler::register_request_id()`] is called with the 32-bit request +/// ID associated with the request. This returns a future that will be +/// fulfilled with the response once it arrives. This method should be called +/// after determining the request ID but before actually sending the request +/// to avoid a race window where the response could be received before the +/// receive handler task knows to expect it. Internally, this future holds +/// the receiving half of a [`tokio::oneshot`] channel. +/// 2. The request packet is sent to the SP. +/// 3. The requester `.await`s the future returned by `register_request_id()`. +/// 4. If the future is dropped before a response is received (e.g., due to a +/// timeout), dropping the future will unregister the request ID provided in +/// step 1 (see [`RequestResponseMap::wait_for_response()`] for details). +/// 5. Assuming the future has not been dropped when the response arrives, +/// [`RecvHandler::handle_incoming_packet()`] will look up the request ID and +/// send the response on the sending half of the [`tokio::oneshot`] channel +/// corresponding to the future returned in step 1, fulfilling it. +#[derive(Debug)] +pub(crate) struct RecvHandler { + sp_state: HashMap, + log: Logger, +} + +impl RecvHandler { + /// Create a new `RecvHandler` that is aware of all ports described by + /// `switch`. 
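    // Illustrative sketch of the request/response ordering described in the
    // struct-level comment above (not part of the patch; `recv_handler`, `sp`,
    // `request_id`, `serialized`, and `deadline` are hypothetical):
    //
    //     // 1. register interest *before* sending, so a fast response cannot
    //     //    race the registration
    //     let response_fut =
    //         recv_handler.register_request_id(sp.port(), request_id);
    //     // 2. send the request packet to the SP
    //     sp.send(serialized).await?;
    //     // 3. await the response, typically under a timeout; dropping the
    //     //    future (e.g. on timeout) unregisters the request ID
    //     let result = tokio::time::timeout_at(deadline, response_fut).await;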
+ pub(crate) fn new( + switch_discovery: ManagementSwitchDiscovery, + log: Logger, + ) -> (ManagementSwitch, Arc) { + // prime `sp_state` with all known ports of the switch + let all_ports = switch_discovery.all_ports(); + let mut sp_state = HashMap::with_capacity(all_ports.len()); + for port in all_ports { + sp_state.insert(port, SingleSpState::default()); + } + + // configure a `ManagementSwitch` that notifies us of every incoming + // packet + let handler = Arc::new(Self { sp_state, log }); + let switch = { + let handler = Arc::clone(&handler); + switch_discovery.start_recv_task(move |port, buf| { + handler.handle_incoming_packet(port, buf) + }) + }; + + (switch, handler) + } + + // SwitchPort instances can only be created by `ManagementSwitch`, so we + // should never be able to instantiate a port that we don't have in + // `self.sp_state` (which we initialize with all ports declared by the + // switch we were given). + fn sp_state(&self, port: SwitchPort) -> &SingleSpState { + self.sp_state.get(&port).expect("invalid switch port") + } + + /// Get our current serial console contents for the given SP component. + pub(crate) fn serial_console_contents( + &self, + port: SwitchPort, + component: &SpComponent, + ) -> Option { + self.sp_state(port).serial_console.lock().unwrap().contents(component) + } + + /// Returns a future that will complete when we receive a response on the + /// given `port` with the corresponding `request_id`. + /// + /// Panics if `port` is not one of the ports defined by the `switch` given + /// to this `RecvHandler` when it was constructed. + pub(crate) async fn register_request_id( + &self, + port: SwitchPort, + request_id: u32, + ) -> Result { + self.sp_state(port).requests.wait_for_response(request_id).await + } + + fn handle_incoming_packet(&self, port: SwitchPort, buf: &[u8]) { + trace!(&self.log, "received {} bytes from {:?}", buf.len(), port); + + // parse into an `SpMessage` + let sp_msg = match gateway_messages::deserialize::(buf) { + Ok((msg, _extra)) => { + // TODO should we check that `extra` is empty? if the + // response is maximal size any extra data is silently + // discarded anyway, so probably not? + msg + } + Err(err) => { + error!(&self.log, "discarding malformed message ({})", err); + return; + } + }; + debug!(&self.log, "received {:?} from {:?}", sp_msg, port); + + // `version` is intentionally the first 4 bytes of the packet; we + // could check it before trying to deserialize? 
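        // (Illustrative note, not part of the patch: because the version is a
        // fixed-size prefix, a cheaper pre-check could decode just the first
        // four bytes of `buf`, using whatever byte order the wire format
        // defines, and bail out before paying for a full `SpMessage` parse.)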
+ if sp_msg.version != version::V1 { + error!( + &self.log, + "discarding message with unsupported version {}", + sp_msg.version + ); + return; + } + + // decide whether this is a response to an outstanding request or an + // unprompted message + match sp_msg.kind { + SpMessageKind::Response { request_id, result } => { + self.handle_response(port, request_id, result); + } + SpMessageKind::SerialConsole(serial_console) => { + self.handle_serial_console(port, serial_console); + } + } + } + + fn handle_response( + &self, + port: SwitchPort, + request_id: u32, + result: Result, + ) { + match self.sp_state(port).requests.ingest_response(&request_id, result) + { + ResponseIngestResult::Ok => (), + ResponseIngestResult::UnknownRequestId => { + error!( + &self.log, + "discarding unexpected response {} from {:?} (possibly past timeout?)", + request_id, + port, + ); + } + } + } + + fn handle_serial_console(&self, port: SwitchPort, packet: SerialConsole) { + debug!( + &self.log, + "received serial console data from {:?}: {:?}", port, packet + ); + self.sp_state(port) + .serial_console + .lock() + .unwrap() + .push(packet, &self.log); + } +} + +#[derive(Debug, Default)] +struct SingleSpState { + requests: RequestResponseMap>, + serial_console: Mutex, +} diff --git a/gateway-sp-comms/src/recv_handler/request_response_map.rs b/gateway-sp-comms/src/recv_handler/request_response_map.rs new file mode 100644 index 0000000000..f1400d55ba --- /dev/null +++ b/gateway-sp-comms/src/recv_handler/request_response_map.rs @@ -0,0 +1,170 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +// Copyright 2022 Oxide Computer Company + +use futures::Future; +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::Mutex; +use tokio::sync::oneshot; +use tokio::sync::oneshot::Receiver; +use tokio::sync::oneshot::Sender; + +/// Possible outcomes of ingesting a response. +#[derive(Debug, Clone, Copy, PartialEq)] +pub(super) enum ResponseIngestResult { + Ok, + UnknownRequestId, +} + +/// An in-memory map providing paired request/response functionality to async +/// tasks. Both tasks must know the key of the request. +#[derive(Debug)] +pub(super) struct RequestResponseMap { + requests: Mutex>>, +} + +// this can't be derived for T: !Default, but we don't need T: Default +impl Default for RequestResponseMap { + fn default() -> Self { + Self { requests: Mutex::default() } + } +} + +impl RequestResponseMap +where + K: Hash + Eq + Clone, +{ + /// Register a key to be paired with a future call to `ingest_response()`. + /// Returns a [`Future`] that will complete once that call is made. + /// + /// Panics if `key` is currently in use by another + /// `wait_for_response()` call on `self`. + pub(super) fn wait_for_response( + &self, + key: K, + ) -> impl Future + '_ { + // construct a oneshot channel, and wrap the receiving half in a type + // that will clean up `self.requests` if the future we return is dropped + // (typically due to timeout/cancellation) + let (tx, rx) = oneshot::channel(); + let mut wrapped_rx = + RemoveOnDrop { rx, requests: &self.requests, key: key.clone() }; + + // insert the sending half into `self.requests` for use by + // `ingest_response()` later. 
+ let old = self.requests.lock().unwrap().insert(key, tx); + + // we always clean up `self.requests` after receiving a response (or + // timing out), so we should never see request ID reuse even if they + // roll over. assert that here to ensure we don't get mysterious + // misbehavior if that turns out to be incorrect + assert!(old.is_none(), "request ID reuse"); + + async move { + // wait for someone to call `ingest_response()` with our request id + match (&mut wrapped_rx.rx).await { + Ok(inner_result) => inner_result, + Err(_recv_error) => { + // we hold the sending half in `self.requests` until either + // `wrapped_rx`'s `Drop` impl removes it (in which case + // we're not running anymore), or it's consumed to send the + // result to us. receiving therefore can't fail + unreachable!() + } + } + } + } + + /// Ingest a response, which will cause the corresponding + /// `wait_for_response()` future to be fulfilled (if it exists). + pub(super) fn ingest_response( + &self, + key: &K, + response: T, + ) -> ResponseIngestResult { + // get the sending half of the channel created in `wait_for_response()`, + // if it exists (i.e., `wait_for_response()` was actually called with + // `key`, and the returned future from it hasn't been dropped) + let tx = match self.requests.lock().unwrap().remove(key) { + Some(tx) => tx, + None => return ResponseIngestResult::UnknownRequestId, + }; + + // we got `tx`, so the receiving end existed a moment ago, but there's a + // race here where it could be dropped before we're able to send + // `result` through, so we can't unwrap this send; we treat this failure + // the same as not finding `tx`, because if we had tried to get `tx` a + // moment later it would not have been there. + match tx.send(response) { + Ok(()) => ResponseIngestResult::Ok, + Err(_) => ResponseIngestResult::UnknownRequestId, + } + } +} + +/// `RemoveOnDrop` is a light wrapper around a [`Receiver`] that removes the +/// associated key from the `HashMap` it's given once the response has been +/// received (or the `RemoveOnDrop` itself is dropped). 
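// Illustrative sketch of the cancellation path this guard exists for (not
// part of the patch; assumes `tokio::time::timeout` and `std::time::Duration`
// are in scope):
//
//     let map: RequestResponseMap<u32, &str> = RequestResponseMap::default();
//     let waited =
//         timeout(Duration::from_millis(10), map.wait_for_response(7)).await;
//     // `waited` is `Err(Elapsed)`; dropping the inner future removed key 7,
//     // so a late response is now rejected rather than leaking an entry:
//     assert_eq!(
//         map.ingest_response(&7, "late"),
//         ResponseIngestResult::UnknownRequestId
//     );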
+#[derive(Debug)] +struct RemoveOnDrop<'a, K, T> +where + K: Hash + Eq, +{ + rx: Receiver, + requests: &'a Mutex>>, + key: K, +} + +impl Drop for RemoveOnDrop<'_, K, T> +where + K: Hash + Eq, +{ + fn drop(&mut self) { + // we don't care to check the return value here; this will be `Some(_)` + // if we're being dropped before a response has been received and + // forwarded on to our caller, and `None` if not (because a caller will + // have already extracted the sending channel) + let _ = self.requests.lock().unwrap().remove(&self.key); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem; + + #[tokio::test] + async fn basic_usage() { + let os = RequestResponseMap::default(); + + // ingesting a response before waiting for it doesn't work + assert_eq!( + os.ingest_response(&1, "hi"), + ResponseIngestResult::UnknownRequestId + ); + + // ingesting a response after waiting for it works, and the receiving + // half gets the ingested response + let resp = os.wait_for_response(1); + assert_eq!(os.ingest_response(&1, "hello"), ResponseIngestResult::Ok); + assert_eq!(resp.await, "hello"); + } + + #[tokio::test] + async fn dropping_future_cleans_up_key() { + let os = RequestResponseMap::default(); + + // register interest in a response, but then drop the returned future + let resp = os.wait_for_response(1); + mem::drop(resp); + + // attempting to ingest the corresponding response should now fail + assert_eq!( + os.ingest_response(&1, "hi"), + ResponseIngestResult::UnknownRequestId + ); + } +} diff --git a/gateway/src/sp_comms/serial_console_history.rs b/gateway-sp-comms/src/recv_handler/serial_console_history.rs similarity index 65% rename from gateway/src/sp_comms/serial_console_history.rs rename to gateway-sp-comms/src/recv_handler/serial_console_history.rs index a92fb42e5a..db4d0285b0 100644 --- a/gateway/src/sp_comms/serial_console_history.rs +++ b/gateway-sp-comms/src/recv_handler/serial_console_history.rs @@ -2,27 +2,31 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. -use gateway_messages::{SerialConsole, SpComponent}; -use ringbuffer::{AllocRingBuffer, RingBufferExt, RingBufferWrite}; -use schemars::JsonSchema; -use serde::Serialize; -use slog::{warn, Logger}; -use std::{cmp::Ordering, collections::HashMap, mem}; +// Copyright 2022 Oxide Computer Company + +use gateway_messages::SerialConsole; +use gateway_messages::SpComponent; +use ringbuffer::AllocRingBuffer; +use ringbuffer::RingBufferExt; +use ringbuffer::RingBufferWrite; +use slog::warn; +use slog::Logger; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::mem; /// Current in-memory contents of an SP component's serial console. /// /// If we have not received any serial console data from this SP component, /// `start` will be `0` and `chunks` will be empty. -// TODO is it a code smell that this type impls Serialize and JsonSchema but -// isn't defined in `http_entrypoints.rs`? -#[derive(Debug, Default, Serialize, JsonSchema)] -pub(crate) struct SerialConsoleContents { +#[derive(Debug, PartialEq)] +pub struct SerialConsoleContents { /// Position since SP component start of the first byte of the first element /// of `chunks`. /// /// This is equal to the number of bytes that we've discarded since the SP /// started. - pub(crate) start: u64, + pub start: u64, /// Chunks of serial console data. 
/// @@ -31,14 +35,13 @@ pub(crate) struct SerialConsoleContents { /// be different (i.e., one will be [`SerialConsoleChunk::Missing`] and the /// other will be [`SerialConsoleChunk::Data`]). If `chunks` is not empty, /// its final element is guaranteed to be a [`SerialConsoleChunk::Data`]. - pub(crate) chunks: Vec, + pub chunks: Vec, } /// A chunk of serial console data: either actual data, or an amount of data /// we missed (presumably due to dropped packets or something similar). -#[derive(Debug, Serialize, JsonSchema)] -#[serde(tag = "kind", rename_all = "lowercase")] -pub(crate) enum SerialConsoleChunk { +#[derive(Debug, PartialEq)] +pub enum SerialConsoleChunk { Data { bytes: Vec }, Missing { len: u64 }, } @@ -170,3 +173,91 @@ impl Slot { } } } + +#[cfg(test)] +mod tests { + use super::*; + use once_cell::sync::Lazy; + use std::convert::TryFrom; + + static COMPONENT: Lazy = + Lazy::new(|| SpComponent::try_from("test").unwrap()); + + fn make_packet(offset: usize, data: &str) -> SerialConsole { + let mut packet = SerialConsole { + component: *COMPONENT, + offset: u64::try_from(offset).unwrap(), + len: u16::try_from(data.len()).unwrap(), + data: [0; SerialConsole::MAX_DATA_PER_PACKET], + }; + packet.data[..data.len()].copy_from_slice(data.as_bytes()); + packet + } + + #[test] + fn consecutive_packets_are_coalesced() { + let mut hist = SerialConsoleHistory::default(); + let log = Logger::root(slog::Discard, slog::o!()); + + // push first packet + hist.push(make_packet(0, "hello "), &log); + let contents = hist.contents(&COMPONENT).unwrap(); + assert_eq!( + contents, + SerialConsoleContents { + start: 0, + chunks: vec![SerialConsoleChunk::Data { + bytes: b"hello ".to_vec() + }] + } + ); + + // push second packet with an offset that ends at first + hist.push(make_packet(6, "world"), &log); + let contents = hist.contents(&COMPONENT).unwrap(); + assert_eq!( + contents, + SerialConsoleContents { + start: 0, + chunks: vec![SerialConsoleChunk::Data { + bytes: b"hello world".to_vec() + }] + } + ); + } + + #[test] + fn skipped_data_leaves_gaps() { + let mut hist = SerialConsoleHistory::default(); + let log = Logger::root(slog::Discard, slog::o!()); + + // push first packet + hist.push(make_packet(0, "hello "), &log); + let contents = hist.contents(&COMPONENT).unwrap(); + assert_eq!( + contents, + SerialConsoleContents { + start: 0, + chunks: vec![SerialConsoleChunk::Data { + bytes: b"hello ".to_vec() + }] + } + ); + + // push second packet with an offset that is 5 bytes past the end of the + // first packet (i.e., 5 bytes were lost) + hist.push(make_packet(6 + 5, "world"), &log); + let contents = hist.contents(&COMPONENT).unwrap(); + assert_eq!( + contents, + SerialConsoleContents { + start: 0, + chunks: vec![ + SerialConsoleChunk::Data { bytes: b"hello ".to_vec() }, + SerialConsoleChunk::Missing { len: 5 }, + SerialConsoleChunk::Data { bytes: b"world".to_vec() }, + ] + } + ); + } +} diff --git a/gateway/src/sp_comms/bulk_state_get.rs b/gateway/src/bulk_state_get.rs similarity index 77% rename from gateway/src/sp_comms/bulk_state_get.rs rename to gateway/src/bulk_state_get.rs index 139f35ef85..f79bda357c 100644 --- a/gateway/src/sp_comms/bulk_state_get.rs +++ b/gateway/src/bulk_state_get.rs @@ -9,37 +9,36 @@ //! state. //! //! The flow of this process, starting from the client(s), is below. This -//! modules comes into play at step 3. +//! modules comes into play at step 2. //! //! 1. A client requests the state of all SPs. The client provides a timeout for //! 
the overall request (or we provide a default timeout on their behalf),
 //! which specifies the point at which any SPs we believe are on but which we
 //! haven't heard from are classified as "unresponsive".
 //!
-//! 2. We ask our ignition controller for the ignition state of all SPs.
-//!
-//! 3. We enter `OutstandingSpStateRequests::start()`, which creates a
-//! `ResponseCollector` (identified by an `SpStateRequestId`) and spawns a
-//! tokio task responsible for populating it.
-//!    a. The background task will immediately populate results for any SPs the
-//!       ignition controller reported were off. For any SPs the ignition
-//!       controller reported were on, it will ask them for their state and wait
-//!       for a response (up to the timeout from step 1).
-//!    b. Once the background task has reported a result for every SP, it enters
+//! 2. We enter [`BulkSpStateRequests::start()`].
+//!    a. Ask our ignition controller for the ignition state of all SPs.
+//!    b. We create a [`SpStateRequestId`] (internally a UUID) to identify this
+//!       request.
+//!    c. We create a set of futures that will be fulfilled with the state of
+//!       all SPs. Any offline SPs will be fulfilled immediately. Any online SPs
+//!       will have their states retrieved via [`Communicator::get_state()`].
+//!       Each of these futures is bounded by the timeout from step 1.
+//!    d. Once the background task has reported a result for every SP, it enters
 //!       a grace period in which the request ID from step 3 is still alive in
 //!       memory and can be queried by clients, but there is no more work
 //!       happening.
-//!    c. Once the grace period ends, the background task purges the data
+//!    e. Once the grace period ends, the background task purges the data
 //!       associated with the request ID and exits.
 //!
-//! 4. We enter `OutstandingSpStateRequests::get()` with the request ID from
-//!    step 3. This allows us to look up the corresponding `ResponseCollector`
+//! 3. We enter [`BulkSpStateRequests::get()`] with the request ID from
+//!    step 2. This allows us to look up the corresponding `ResponseCollector`
 //!    and wait for responses. The client endpoint is paginated, and we have the
 //!    option of how (or even if) we want to return partial progress early;
 //!    currently we choose to implement this via a timeout duration specified in
 //!    the gateway configuration.
 //!
-//! 5. We wait in `OutstandingSpStateRequests::get()` until we hit one of three
+//! 4. We wait in [`BulkSpStateRequests::get()`] until we hit one of three
 //!    cases:
 //!
 //!    a. We've collected a partial set of responses of the size specified by
@@ -50,25 +49,22 @@
 //!
 //!    In any of these three cases, we return a page token to the client that
 //!    includes the request ID, allowing them to fetch the next page (which
-//!    enters this process at step 4). In case 5c we should _not_ send a page
+//!    enters this process at step 3). In case 4c we should _not_ send a page
 //!    token, since we know we've reported the final page, but currently there's
 //!    no clean way to handle this with dropshot.
 //!
 //!    The mechanics for this step are a little messy; see the comments in
-//!    `OutstandingSpStateRequests::get()` for details.
+//!    [`BulkSpStateRequests::get()`] for details.
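// Illustrative client-side sketch of the flow above (not part of the patch;
// assumes `communicator: Arc<Communicator>`, a `log`, and the various
// deadlines/limits are in scope):
//
//     let bulk = BulkSpStateRequests::new(Arc::clone(&communicator), &log);
//     // step 2: kick off collection and get an ID the client can page with
//     let id = bulk.start(deadline, retain_grace_period).await?;
//     // steps 3 and 4: fetch pages until `BulkStateProgress` reports the
//     // request is complete
//     let progress = bulk
//         .get(&id, /* last_seen */ None, page_deadline, /* limit */ 32)
//         .await?;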
-use super::Error; -use super::SpCommunicator; -use crate::http_entrypoints::SpState; -use futures::future; -use futures::future::Either; -use futures::stream::FuturesUnordered; -use futures::Future; -use futures::FutureExt; +use crate::error::Error; +use crate::error::InvalidPageToken; use futures::StreamExt; -use gateway_messages::{IgnitionFlags, IgnitionState}; -use gateway_sp_comms::SwitchPort; -use serde::{Deserialize, Serialize}; +use gateway_messages::IgnitionState; +use gateway_sp_comms::Communicator; +use gateway_sp_comms::FuturesUnorderedImpl; +use gateway_sp_comms::SpIdentifier; +use serde::Deserialize; +use serde::Serialize; use slog::debug; use slog::error; use slog::trace; @@ -82,6 +78,37 @@ use tokio::sync::Notify; use tokio::time::Instant; use uuid::Uuid; +use crate::http_entrypoints::SpState; + +/// Newtype wrapper around [`Uuid`] for long-running state requests. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(transparent)] +pub(crate) struct SpStateRequestId(pub(crate) Uuid); + +impl slog::Value for SpStateRequestId { + fn serialize( + &self, + _record: &slog::Record, + key: slog::Key, + serializer: &mut dyn slog::Serializer, + ) -> slog::Result { + serializer.emit_arguments(key, &format_args!("{}", self.0)) + } +} + +impl SpStateRequestId { + fn new() -> Self { + Self(Uuid::new_v4()) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct BulkSpStateSingleResult { + pub(crate) sp: SpIdentifier, + pub(crate) state: IgnitionState, + pub(crate) result: Result>, +} + /// Set of results for a single page of collecting state from all SPs. #[derive(Debug)] pub(crate) enum BulkStateProgress { @@ -99,51 +126,45 @@ pub(crate) enum BulkStateProgress { type RequestsMap = HashMap>>; -#[derive(Debug, Default)] -pub(super) struct OutstandingSpStateRequests { +#[derive(Debug)] +pub struct BulkSpStateRequests { + communicator: Arc, requests: Arc>, + log: Logger, } -impl OutstandingSpStateRequests { - pub(super) fn start( +impl BulkSpStateRequests { + pub(crate) fn new(communicator: Arc, log: &Logger) -> Self { + Self { + communicator, + requests: Arc::default(), + log: log.new(slog::o!("component" => "BulkSpStateRequests")), + } + } + + pub(crate) async fn start( &self, timeout: Instant, retain_grace_period: Duration, - sps: impl Iterator, - communicator: &Arc, - ) -> SpStateRequestId { + ) -> Result { // set up the receiving end of all SP responses let collector = Arc::new(RwLock::new(ResponseCollector::default())); let id = SpStateRequestId::new(); self.requests.lock().unwrap().insert(id, Arc::clone(&collector)); + // query ignition controller to find out which SPs are powered on + let all_sps = self.communicator.get_ignition_state_all(timeout).await?; + // build collection of futures to contact all SPs - let futures = sps - .map(move |(port, state)| { - let communicator = Arc::clone(communicator); - async move { - // only query the SP if it's powered on - let fut = if state.flags.intersects(IgnitionFlags::POWER) { - Either::Left( - communicator.state_get_by_port(port, timeout).map( - move |result| BulkSpStateSingleResult { - port, - state, - result, - }, - ), - ) - } else { - Either::Right(future::ready(BulkSpStateSingleResult { - port, - state, - result: Ok(SpState::Disabled), - })) - }; - fut.await - } - }) - .collect::>(); + let communicator = Arc::clone(&self.communicator); + let response_stream = self.communicator.query_all_online_sps( + &all_sps, + timeout, + move |sp| { + let communicator = Arc::clone(&communicator); + async move { 
communicator.get_state(sp, timeout).await } + }, + ); // spawn the background task. we don't keep a handle to this; we // attached timeouts to all the individual requests above, and after @@ -151,28 +172,27 @@ impl OutstandingSpStateRequests { // itself. tokio::spawn(wait_for_sp_responses( id, - futures, + response_stream, Arc::clone(&self.requests), collector, retain_grace_period, - communicator.log.new(slog::o!( + self.log.new(slog::o!( "request_kind" => "bulk_sp_state_start", "request_id" => id, )), )); - id + Ok(id) } pub(super) async fn get( &self, id: &SpStateRequestId, - last_seen: Option, - timeout: Duration, + last_seen: Option, + timeout: Instant, limit: usize, - log: &Logger, ) -> Result { - let log = log.new(slog::o!( + let log = self.log.new(slog::o!( "request_kind" => "bulk_sp_state_get", "request_id" => *id, )); @@ -181,16 +201,16 @@ impl OutstandingSpStateRequests { // Go ahead and create (and pin) the timeout, but we don't actually // await it until the loop at the bottom of this function. - let timeout = tokio::time::sleep(timeout); + let timeout = tokio::time::sleep_until(timeout); tokio::pin!(timeout); - let collector = self - .requests - .lock() - .unwrap() - .get(id) - .map(Arc::clone) - .ok_or(Error::NoSuchRequest)?; + let collector = + self.requests + .lock() + .unwrap() + .get(id) + .map(Arc::clone) + .ok_or(Error::InvalidPageToken(InvalidPageToken::NoSuchId))?; // TODO The locking and notification in this method is a little // precarious. We really want something like an async condition @@ -215,7 +235,7 @@ impl OutstandingSpStateRequests { if let Some(last_seen) = last_seen { for (i, result) in collector.received_states.iter().enumerate() { - if result.port == last_seen { + if result.sp == last_seen { skip_results = i + 1; trace!( log, @@ -234,7 +254,9 @@ impl OutstandingSpStateRequests { "client reported last seeing {:?}, but it isn't in our list of collected responses", last_seen ); - return Err(Error::InvalidLastSpSeen); + return Err(Error::InvalidPageToken( + InvalidPageToken::InvalidLastSeenItem, + )); } // go ahead and check to see if we already have enough info to @@ -344,6 +366,84 @@ impl OutstandingSpStateRequests { } } +type SpStateResult = + Result; + +async fn wait_for_sp_responses( + id: SpStateRequestId, + mut response_stream: S, + requests: Arc>, + collector: Arc>, + retain_grace_period: Duration, + log: Logger, +) where + S: FuturesUnorderedImpl< + Item = ( + SpIdentifier, + IgnitionState, + Option>, + ), + >, +{ + while let Some((sp, state, result)) = response_stream.next().await { + let mut collector = collector.write().unwrap(); + + // Unpack the nasty nested type: + // 1. None => ignition indicated power was off; treat that as success + // (with state = disabled) + // 2. Outer err => timeout; treat that as "success" (with state = + // unresponsive) + // 3. Inner success => true success + // 4. Inner error => wrap in an `Arc` so we can clone it + let result = match result { + None => Ok(SpState::Disabled), + Some(Err(_)) => Ok(SpState::Unresponsive), + Some(Ok(result)) => match result { + Ok(state) => Ok(SpState::from(state)), + Err(err) => Err(Arc::new(err)), + }, + }; + collector.push(BulkSpStateSingleResult { sp, state, result }); + + // this is a little goofy, but we don't want to put it after the loop + // and have to reacquire the write lock; we want clients of `collector` + // to know if the list is done as soon as they wake up. 
+ if response_stream.is_empty() { + collector.done = true; + } + } + + // At this point we've reported the status for every SP, either because we + // know the actual status or because it timed out. We now enter the grace + // period where we keep this request in memory so clients can continue to + // ask for additional pages. + // + // In the event that we get here _earlier_ than the client-requested + // timeout (i.e., because we heard back from all the SPs and none timed + // out), we ignore the client timeout and jump immediately to staying alive + // for our internal post-completion grace period. + + debug!(log, "all responses collected; starting grace period"); + tokio::time::sleep(retain_grace_period).await; + + debug!(log, "grace period elapsed; dropping request"); + requests.lock().unwrap().remove(&id); +} + +#[derive(Debug, Default)] +struct ResponseCollector { + received_states: Vec, + notify: Arc, + done: bool, +} + +impl ResponseCollector { + fn push(&mut self, result: BulkSpStateSingleResult) { + self.received_states.push(result); + self.notify.notify_waiters(); + } +} + enum AccumulationStatus { Complete, PageLimitReached, @@ -386,84 +486,3 @@ impl AccumulationStatus { } } } - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[serde(transparent)] -pub(crate) struct SpStateRequestId(pub(crate) Uuid); - -impl slog::Value for SpStateRequestId { - fn serialize( - &self, - _record: &slog::Record, - key: slog::Key, - serializer: &mut dyn slog::Serializer, - ) -> slog::Result { - serializer.emit_arguments(key, &format_args!("{}", self.0)) - } -} - -impl SpStateRequestId { - fn new() -> Self { - Self(Uuid::new_v4()) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct BulkSpStateSingleResult { - pub(crate) port: SwitchPort, - pub(crate) state: IgnitionState, - pub(crate) result: Result, -} - -#[derive(Debug, Default)] -struct ResponseCollector { - received_states: Vec, - notify: Arc, - done: bool, -} - -impl ResponseCollector { - fn push(&mut self, result: BulkSpStateSingleResult) { - self.received_states.push(result); - self.notify.notify_waiters(); - } -} - -async fn wait_for_sp_responses( - id: SpStateRequestId, - mut futures: FuturesUnordered, - requests: Arc>, - collector: Arc>, - retain_grace_period: Duration, - log: Logger, -) where - Fut: Future, -{ - while let Some(result) = futures.next().await { - let mut collector = collector.write().unwrap(); - collector.push(result); - - // this is a little goofy, but we don't want to put it after the loop - // and have to reacquire the write lock; we want clients of `collector` - // to know if the list is done as soon as they wake up. - if futures.is_empty() { - collector.done = true; - } - } - - // At this point we've reported the status for every SP, either because we - // know the actual status or because it timed out. We now enter the grace - // period where we keep this request in memory so clients can continue to - // ask for additional pages. - // - // In the event that we get here _earlier_ than the client-requested - // timeout (i.e., because we heard back from all the SPs and none timed - // out), we ignore the client timeout and jump immediately to staying alive - // for our internal post-completion grace period. 
- - debug!(log, "all responses collected; starting grace period"); - tokio::time::sleep(retain_grace_period).await; - - debug!(log, "grace period elapsed; dropping request"); - requests.lock().unwrap().remove(&id); -} diff --git a/gateway/src/context.rs b/gateway/src/context.rs index 648ad63f03..75830e4f04 100644 --- a/gateway/src/context.rs +++ b/gateway/src/context.rs @@ -2,15 +2,16 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. -use crate::sp_comms::SpCommunicator; -use crate::Config; +use crate::{bulk_state_get::BulkSpStateRequests, Config}; use gateway_sp_comms::error::StartupError; +use gateway_sp_comms::Communicator; use slog::Logger; use std::{sync::Arc, time::Duration}; /// Shared state used by API request handlers pub struct ServerContext { - pub sp_comms: Arc, + pub sp_comms: Arc, + pub bulk_sp_state_requests: BulkSpStateRequests, pub timeouts: Timeouts, } @@ -51,10 +52,11 @@ impl ServerContext { config: &Config, log: &Logger, ) -> Result, StartupError> { - let sp_comms = - Arc::new(SpCommunicator::new(config.known_sps.clone(), log).await?); + let comms = + Arc::new(Communicator::new(config.known_sps.clone(), log).await?); Ok(Arc::new(ServerContext { - sp_comms, + sp_comms: Arc::clone(&comms), + bulk_sp_state_requests: BulkSpStateRequests::new(comms, log), timeouts: Timeouts::from(&config.timeouts), })) } diff --git a/gateway/src/error.rs b/gateway/src/error.rs index 48d300dfff..78729add1a 100644 --- a/gateway/src/error.rs +++ b/gateway/src/error.rs @@ -4,44 +4,53 @@ //! Error handling facilities for the management gateway. -use crate::http_entrypoints::SpIdentifier; +use std::borrow::Borrow; + use dropshot::HttpError; -use serde::{Deserialize, Serialize}; +use gateway_sp_comms::error::Error as SpCommsError; -/// An error that can be generated within the gateway. -#[derive(Debug, Deserialize, thiserror::Error, PartialEq, Serialize)] +#[derive(Debug, thiserror::Error)] pub(crate) enum Error { - /// A requested SP does not exist. - /// - /// This is not the same as the requested SP target being offline; this - /// error indicates a fatal, invalid request (e.g., asing for the SP on the - /// 17th switch when there are only two switches). - #[error("SP {} (of type {:?}) does not exist", .0.slot, .0.typ)] - SpDoesNotExist(SpIdentifier), - - /// The requested SP component ID is invalid (i.e., too long). - #[error("invalid SP component ID `{0}`")] - InvalidSpComponentId(String), + #[error("invalid page token ({0})")] + InvalidPageToken(InvalidPageToken), + #[error(transparent)] + CommunicationsError(#[from] SpCommsError), +} - /// The system encountered an unhandled operational error. 
- #[error("internal error: {internal_message}")] - InternalError { internal_message: String }, +#[derive(Debug, thiserror::Error)] +pub(crate) enum InvalidPageToken { + #[error("no such ID")] + NoSuchId, + #[error("invalid value for last seen item")] + InvalidLastSeenItem, } impl From for HttpError { fn from(err: Error) -> Self { match err { - Error::SpDoesNotExist(_) => HttpError::for_bad_request( - Some(String::from("SpDoesNotExist")), - err.to_string(), - ), - Error::InvalidSpComponentId(_) => HttpError::for_bad_request( - Some(String::from("InvalidSpComponentId")), + Error::InvalidPageToken(_) => HttpError::for_bad_request( + Some("InvalidPageToken".to_string()), err.to_string(), ), - Error::InternalError { internal_message } => { - HttpError::for_internal_error(internal_message) - } + Error::CommunicationsError(err) => http_err_from_comms_err(err), + } + } +} + +pub(crate) fn http_err_from_comms_err(err: E) -> HttpError +where + E: Borrow, +{ + let err = err.borrow(); + match err { + SpCommsError::SpDoesNotExist(_) => HttpError::for_bad_request( + Some("InvalidSp".to_string()), + err.to_string(), + ), + SpCommsError::SpAddressUnknown(_) + | SpCommsError::Timeout + | SpCommsError::SpCommunicationFailed(_) => { + HttpError::for_internal_error(err.to_string()) } } } diff --git a/gateway/src/http_entrypoints.rs b/gateway/src/http_entrypoints.rs index 1c82817020..b3777abcee 100644 --- a/gateway/src/http_entrypoints.rs +++ b/gateway/src/http_entrypoints.rs @@ -6,12 +6,13 @@ //! HTTP entrypoint functions for the gateway service -use crate::error::Error; -use crate::sp_comms::BulkSpStateSingleResult; -use crate::sp_comms::BulkStateProgress; -use crate::sp_comms::Error as SpCommsError; -use crate::sp_comms::SerialConsoleContents; -use crate::sp_comms::SpStateRequestId; +mod conversions; + +use self::conversions::component_from_str; +use crate::bulk_state_get::BulkSpStateSingleResult; +use crate::bulk_state_get::BulkStateProgress; +use crate::bulk_state_get::SpStateRequestId; +use crate::error::http_err_from_comms_err; use crate::ServerContext; use dropshot::endpoint; use dropshot::ApiDescription; @@ -26,10 +27,10 @@ use dropshot::ResultsPage; use dropshot::TypedBody; use dropshot::UntypedBody; use dropshot::WhichPage; -use gateway_messages::{IgnitionFlags, SpComponent}; +use gateway_messages::IgnitionCommand; +use gateway_sp_comms::error::Error as SpCommsError; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::convert::TryFrom; use std::sync::Arc; use std::time::Duration; use tokio::time::Instant; @@ -116,22 +117,6 @@ pub enum SpIgnition { }, } -impl From for SpIgnition { - fn from(state: gateway_messages::IgnitionState) -> Self { - // if we have a state, the SP was present - Self::Present { - id: state.id, - power: state.flags.intersects(IgnitionFlags::POWER), - ctrl_detect_0: state.flags.intersects(IgnitionFlags::CTRL_DETECT_0), - ctrl_detect_1: state.flags.intersects(IgnitionFlags::CTRL_DETECT_1), - flt_a3: state.flags.intersects(IgnitionFlags::FLT_A3), - flt_a2: state.flags.intersects(IgnitionFlags::FLT_A2), - flt_rot: state.flags.intersects(IgnitionFlags::FLT_ROT), - flt_sp: state.flags.intersects(IgnitionFlags::FLT_SP), - } - } -} - #[derive(Serialize, JsonSchema)] struct SpComponentInfo {} @@ -171,26 +156,6 @@ pub enum SpType { Switch, } -impl From for gateway_sp_comms::SpType { - fn from(typ: SpType) -> Self { - match typ { - SpType::Sled => Self::Sled, - SpType::Power => Self::Power, - SpType::Switch => Self::Switch, - } - } -} - -impl From for SpType { - fn 
from(typ: gateway_sp_comms::SpType) -> Self { - match typ { - gateway_sp_comms::SpType::Sled => Self::Sled, - gateway_sp_comms::SpType::Power => Self::Power, - gateway_sp_comms::SpType::Switch => Self::Switch, - } - } -} - #[derive( Debug, Clone, @@ -210,27 +175,6 @@ pub struct SpIdentifier { pub slot: u32, } -impl From for gateway_sp_comms::SpIdentifier { - fn from(id: SpIdentifier) -> Self { - Self { - typ: id.typ.into(), - // id.slot may come from an untrusted source, but usize >= 32 bits - // on any platform that will run this code, so unwrap is fine - slot: usize::try_from(id.slot).unwrap(), - } - } -} - -impl From for SpIdentifier { - fn from(id: gateway_sp_comms::SpIdentifier) -> Self { - Self { - typ: id.typ.into(), - // id.slot comes from a trusted source and will not exceed u32::MAX - slot: u32::try_from(id.slot).unwrap(), - } - } -} - // We can't use the default `Deserialize` derivation for `SpIdentifier::slot` // because it's embedded in other structs via `serde(flatten)`, which does not // play well with the way dropshot parses HTTP queries/paths. serde ends up @@ -274,14 +218,46 @@ struct PathSp { } #[derive(Serialize, Deserialize, JsonSchema)] -pub(crate) struct PathSpComponent { +struct PathSpComponent { /// ID for the SP that the gateway service translates into the appropriate /// port for communicating with the given SP. #[serde(flatten)] - pub(crate) sp: SpIdentifier, + sp: SpIdentifier, /// ID for the component of the SP; this is the internal identifier used by /// the SP itself to identify its components. - pub(crate) component: String, + component: String, +} + +/// Current in-memory contents of an SP component's serial console. +/// +/// If we have not received any serial console data from this SP component, +/// `start` will be `0` and `chunks` will be empty. +#[derive(Debug, Default, PartialEq, Serialize, JsonSchema)] +struct SerialConsoleContents { + /// Position since SP component start of the first byte of the first element + /// of `chunks`. + /// + /// This is equal to the number of bytes that we've discarded since the SP + /// started. + start: u64, + + /// Chunks of serial console data. + /// + /// We collapse contiguous regions of data (present or missing) into a + /// single chunk: any two consecutive elements of `chunks` are guaranteed to + /// be different (i.e., one will be [`SerialConsoleChunk::Missing`] and the + /// other will be [`SerialConsoleChunk::Data`]). If `chunks` is not empty, + /// its final element is guaranteed to be a [`SerialConsoleChunk::Data`]. + chunks: Vec, +} + +/// A chunk of serial console data: either actual data, or an amount of data +/// we missed (presumably due to dropped packets or something similar). 
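As a concrete illustration of the invariants documented above (the values are made up, and the `SerialConsoleChunk` variants used here are the ones defined immediately below): if the SP has emitted 115 bytes, the first 100 have already been discarded, and 8 bytes in the middle of the retained window were lost, the in-memory contents could look like this:

// Illustrative value only; it satisfies the documented invariants.
let contents = SerialConsoleContents {
    start: 100, // bytes 0..100 have been discarded since the SP started
    chunks: vec![
        SerialConsoleChunk::Data { bytes: b"abc".to_vec() },   // bytes 100..103
        SerialConsoleChunk::Missing { len: 8 },                // bytes 103..111 were lost
        SerialConsoleChunk::Data { bytes: b"defg".to_vec() },  // bytes 111..115
        // consecutive chunks always alternate, and the final one is `Data`
    ],
};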
+#[derive(Debug, PartialEq, Serialize, JsonSchema)] +#[serde(tag = "kind", rename_all = "lowercase")] +enum SerialConsoleChunk { + Data { bytes: Vec }, + Missing { len: u64 }, } /// List SPs @@ -313,7 +289,6 @@ async fn sp_list( query: Query>, ) -> Result>, HttpError> { let apictx = rqctx.context(); - let sp_comms = &apictx.sp_comms; let page_params = query.into_inner(); let page_limit = rqctx.page_limit(&page_params)?.get() as usize; @@ -328,18 +303,14 @@ async fn sp_list( .min(apictx.timeouts.bulk_request_max); let timeout = Instant::now() + timeout; - // query ignition to find out which SPs are on (and should therefore - // be queried for their state) - let ignition_state = sp_comms - .bulk_ignition_get(apictx.timeouts.ignition_controller) + let request_id = apictx + .bulk_sp_state_requests + .start( + timeout, + apictx.timeouts.bulk_request_retain_grace_period, + ) .await?; - // actually kick off the state collection process - let request_id = sp_comms.bulk_state_start( - timeout, - apictx.timeouts.bulk_request_retain_grace_period, - ignition_state.into_iter(), - ); (request_id, None) } WhichPage::Next(page_selector) => { @@ -347,11 +318,12 @@ async fn sp_list( } }; - let progress = sp_comms - .bulk_state_get_progress( + let progress = apictx + .bulk_sp_state_requests + .get( &request_id, last_seen_target.map(Into::into), - apictx.timeouts.bulk_request_page, + Instant::now() + apictx.timeouts.bulk_request_page, page_limit, ) .await?; @@ -370,28 +342,30 @@ async fn sp_list( let items = items .into_iter() - .map(|BulkSpStateSingleResult { port, state, result }| { + .map(|BulkSpStateSingleResult { sp, state, result }| { let details = match result { Ok(details) => details, - Err(SpCommsError::Timeout) => SpState::Unresponsive, - // TODO we're dropping the error on the floor here - how should - // we handle it? This is an SP that we actively failed to - // communicate with somehow, which isn't the same as - // "unresponsive". Should we fail the entire request? That's how - // we handle this kind of an error in the "get the state of a - // specific SP" endpoint. We could alternatively add an error - // variant to `SpState`? - Err(_) => SpState::Unresponsive, + Err(err) => match &*err { + // TODO Treating "communication failed" and "we don't know + // the IP address" as "unresponsive" may not be right. Do we + // need more refined errors? + SpCommsError::Timeout + | SpCommsError::SpCommunicationFailed(_) + | SpCommsError::SpAddressUnknown(_) => { + SpState::Unresponsive + } + // This error shouldn't be possible since we're generating + // SP ids internally; fail the request if we hit it. + SpCommsError::SpDoesNotExist(_) => return Err(err), + }, }; Ok(SpInfo { - info: SpIgnitionInfo { - id: sp_comms.port_to_id(port).into(), - details: state.into(), - }, + info: SpIgnitionInfo { id: sp.into(), details: state.into() }, details, }) }) - .collect::, Error>>()?; + .collect::, Arc>>() + .map_err(http_err_from_comms_err)?; Ok(HttpResponseOk(ResultsPage::new( items, @@ -434,15 +408,16 @@ async fn sp_get( // ping the ignition controller first; if it says the SP is off or otherwise // unavailable, we're done. 
let state = comms - .ignition_get(sp.into(), apictx.timeouts.ignition_controller) - .await?; + .get_ignition_state(sp.into(), timeout) + .await + .map_err(http_err_from_comms_err)?; - let details = if state.flags.intersects(IgnitionFlags::POWER) { + let details = if state.is_powered_on() { // ignition indicates the SP is on; ask it for its state - match comms.state_get(sp.into(), timeout).await { - Ok(state) => state, + match comms.get_state(sp.into(), timeout).await { + Ok(state) => SpState::from(state), Err(SpCommsError::Timeout) => SpState::Unresponsive, - Err(other) => return Err(other.into()), + Err(other) => return Err(http_err_from_comms_err(other)), } } else { SpState::Disabled @@ -515,9 +490,11 @@ async fn sp_component_serial_console_get( let comms = &rqctx.context().sp_comms; let PathSpComponent { sp, component } = path.into_inner(); - let component = SpComponent::try_from(component.as_str()) - .map_err(|_| Error::InvalidSpComponentId(component))?; - let contents = comms.serial_console_get(sp.into(), &component)?; + let component = component_from_str(&component)?; + let contents = comms + .serial_console_contents(sp.into(), &component) + .map_err(http_err_from_comms_err)? + .map(SerialConsoleContents::from); // TODO With `unwrap_or_default()`, our caller can't tell the difference // between "this component hasn't sent us any console information yet" and @@ -546,25 +523,26 @@ async fn sp_component_serial_console_post( let comms = &apictx.sp_comms; let PathSpComponent { sp, component } = path.into_inner(); - let component = SpComponent::try_from(component.as_str()) - .map_err(|_| Error::InvalidSpComponentId(component))?; - - // TODO What is our recourse if we hit a timeout here? We don't know whether - // the SP received none, some, or all of the data we sent, only that it - // failed to ack (at least) the last packet in time. Hopefully the user can - // manually detect this by inspecting the serial console output from this - // component (if whatever they sent triggers some kind of output)? But maybe - // we should try to do a little better - if we had to packetize `data`, for - // example, we could at least report how much data was ack'd and how much we - // sent that hasn't been ack'd yet? + let component = component_from_str(&component)?; + + // TODO What is our recourse if we hit a timeout here (or some other kind of + // error partway through)? We don't know whether the SP received none, some, + // or all of the data we sent, only that it failed to ack (at least) the + // last packet in time. Hopefully the user can manually detect this by + // inspecting the serial console output from this component (if whatever + // they sent triggers some kind of output)? But maybe we should try to do a + // little better - if we had to packetize `data`, for example, we could at + // least report how much data was ack'd and how much we sent that hasn't + // been ack'd yet? 
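One possible shape for the improvement the TODO above floats, i.e. reporting how much of a packetized serial console write was acknowledged before the failure. This is purely illustrative; no such type exists in this diff:

use gateway_sp_comms::error::Error as SpCommsError;

// Hypothetical richer error for partially-acknowledged serial console writes.
#[derive(Debug)]
struct SerialConsoleWriteProgress {
    bytes_acked: u64,    // data the SP acknowledged before the failure
    bytes_unacked: u64,  // data we sent but never saw an ack for
    error: SpCommsError, // the timeout (or other failure) that stopped us
}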
comms - .serial_console_post( + .serial_console_send( sp.into(), - component, + &component, data.as_bytes(), - apictx.timeouts.sp_request, + Instant::now() + apictx.timeouts.sp_request, ) - .await?; + .await + .map_err(http_err_from_comms_err)?; Ok(HttpResponseUpdatedNoContent {}) } @@ -641,15 +619,16 @@ async fn ignition_list( let apictx = rqctx.context(); let sp_comms = &apictx.sp_comms; - let all_state = - sp_comms.bulk_ignition_get(apictx.timeouts.ignition_controller).await?; + let all_state = sp_comms + .get_ignition_state_all( + Instant::now() + apictx.timeouts.ignition_controller, + ) + .await + .map_err(http_err_from_comms_err)?; let mut out = Vec::with_capacity(all_state.len()); - for (port, state) in all_state { - out.push(SpIgnitionInfo { - id: sp_comms.port_to_id(port).into(), - details: state.into(), - }); + for (id, state) in all_state { + out.push(SpIgnitionInfo { id: id.into(), details: state.into() }); } Ok(HttpResponseOk(out)) } @@ -672,8 +651,12 @@ async fn ignition_get( let state = apictx .sp_comms - .ignition_get(sp.into(), apictx.timeouts.ignition_controller) - .await?; + .get_ignition_state( + sp.into(), + Instant::now() + apictx.timeouts.ignition_controller, + ) + .await + .map_err(http_err_from_comms_err)?; let info = SpIgnitionInfo { id: sp, details: state.into() }; Ok(HttpResponseOk(info)) @@ -693,8 +676,13 @@ async fn ignition_power_on( apictx .sp_comms - .ignition_power_on(sp.into(), apictx.timeouts.ignition_controller) - .await?; + .send_ignition_command( + sp.into(), + IgnitionCommand::PowerOn, + Instant::now() + apictx.timeouts.ignition_controller, + ) + .await + .map_err(http_err_from_comms_err)?; Ok(HttpResponseUpdatedNoContent {}) } @@ -713,8 +701,13 @@ async fn ignition_power_off( apictx .sp_comms - .ignition_power_off(sp.into(), apictx.timeouts.ignition_controller) - .await?; + .send_ignition_command( + sp.into(), + IgnitionCommand::PowerOff, + Instant::now() + apictx.timeouts.ignition_controller, + ) + .await + .map_err(http_err_from_comms_err)?; Ok(HttpResponseUpdatedNoContent {}) } diff --git a/gateway/src/http_entrypoints/conversions.rs b/gateway/src/http_entrypoints/conversions.rs new file mode 100644 index 0000000000..6f664a6f3c --- /dev/null +++ b/gateway/src/http_entrypoints/conversions.rs @@ -0,0 +1,120 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +// Copyright 2022 Oxide Computer Company + +//! Conversions between externally-defined types and HTTP / JsonSchema types. 
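A pattern repeated across the handlers above: a configured per-request budget is converted into an absolute `tokio::time::Instant` deadline at the top of the handler, and the comms layer enforces it with `tokio::time::timeout_at`. A minimal sketch; the helper names are made up, the diff simply inlines `Instant::now() + some_timeout` at each call site:

use std::future::Future;
use std::time::Duration;
use tokio::time::{self, Instant};

// Turn a relative budget into an absolute deadline once, up front.
fn deadline(budget: Duration) -> Instant {
    Instant::now() + budget
}

// Enforce the deadline around any future (request/response, retries, ...).
async fn with_deadline<T>(
    deadline: Instant,
    fut: impl Future<Output = T>,
) -> Result<T, time::error::Elapsed> {
    time::timeout_at(deadline, fut).await
}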
+ +use super::SerialConsoleChunk; +use super::SerialConsoleContents; +use super::SpIdentifier; +use super::SpIgnition; +use super::SpState; +use super::SpType; +use dropshot::HttpError; +use gateway_messages::IgnitionFlags; +use gateway_messages::SpComponent; + +// wrap `SpComponent::try_from(&str)` into a usable form for dropshot endpoints +pub(super) fn component_from_str(s: &str) -> Result { + SpComponent::try_from(s).map_err(|_| { + HttpError::for_bad_request( + Some("InvalidSpComponent".to_string()), + "invalid SP component name".to_string(), + ) + }) +} + +impl From for SpState { + fn from(state: gateway_messages::SpState) -> Self { + Self::Enabled { serial_number: hex::encode(&state.serial_number[..]) } + } +} + +impl From for SpIgnition { + fn from(state: gateway_messages::IgnitionState) -> Self { + // if we have a state, the SP was present + Self::Present { + id: state.id, + power: state.flags.intersects(IgnitionFlags::POWER), + ctrl_detect_0: state.flags.intersects(IgnitionFlags::CTRL_DETECT_0), + ctrl_detect_1: state.flags.intersects(IgnitionFlags::CTRL_DETECT_1), + flt_a3: state.flags.intersects(IgnitionFlags::FLT_A3), + flt_a2: state.flags.intersects(IgnitionFlags::FLT_A2), + flt_rot: state.flags.intersects(IgnitionFlags::FLT_ROT), + flt_sp: state.flags.intersects(IgnitionFlags::FLT_SP), + } + } +} + +impl From for gateway_sp_comms::SpType { + fn from(typ: SpType) -> Self { + match typ { + SpType::Sled => Self::Sled, + SpType::Power => Self::Power, + SpType::Switch => Self::Switch, + } + } +} + +impl From for SpType { + fn from(typ: gateway_sp_comms::SpType) -> Self { + match typ { + gateway_sp_comms::SpType::Sled => Self::Sled, + gateway_sp_comms::SpType::Power => Self::Power, + gateway_sp_comms::SpType::Switch => Self::Switch, + } + } +} + +impl From for gateway_sp_comms::SpIdentifier { + fn from(id: SpIdentifier) -> Self { + Self { + typ: id.typ.into(), + // id.slot may come from an untrusted source, but usize >= 32 bits + // on any platform that will run this code, so unwrap is fine + slot: usize::try_from(id.slot).unwrap(), + } + } +} + +impl From for SpIdentifier { + fn from(id: gateway_sp_comms::SpIdentifier) -> Self { + Self { + typ: id.typ.into(), + // id.slot comes from a trusted source (gateway_sp_comms) and will + // not exceed u32::MAX + slot: u32::try_from(id.slot).unwrap(), + } + } +} + +impl From for SerialConsoleChunk { + fn from(chunk: gateway_sp_comms::SerialConsoleChunk) -> Self { + match chunk { + gateway_sp_comms::SerialConsoleChunk::Data { bytes } => { + Self::Data { bytes } + } + gateway_sp_comms::SerialConsoleChunk::Missing { len } => { + Self::Missing { len } + } + } + } +} + +impl From for SerialConsoleContents { + fn from(contents: gateway_sp_comms::SerialConsoleContents) -> Self { + // TODO this is awkward and (probably?) not free; I tried avoiding this + // with serde's remote derive support, but couldn't get something that + // satisfied all three of serde, schemars, and openapi. + Self { + start: contents.start, + chunks: contents + .chunks + .into_iter() + .map(SerialConsoleChunk::from) + .collect(), + } + } +} diff --git a/gateway/src/lib.rs b/gateway/src/lib.rs index 604aa13717..0977d6ccf8 100644 --- a/gateway/src/lib.rs +++ b/gateway/src/lib.rs @@ -2,14 +2,16 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at https://mozilla.org/MPL/2.0/. +mod bulk_state_get; mod config; mod context; mod error; + pub mod http_entrypoints; // TODO pub only for testing - is this right? 
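A hedged sketch of how a handler can lean on the conversions above. Only `component_from_str` and the `From` impls come from this diff; the surrounding function and its name are illustrative:

use dropshot::HttpError;
use gateway_messages::SpComponent;

// Illustrative glue between path parameters and the comms-layer types.
fn parse_path(
    sp: SpIdentifier, // the HTTP-layer identifier from http_entrypoints.rs
    component: &str,  // raw component name from the URL path
) -> Result<(gateway_sp_comms::SpIdentifier, SpComponent), HttpError> {
    let component = component_from_str(component)?; // 400 Bad Request on invalid names
    Ok((gateway_sp_comms::SpIdentifier::from(sp), component))
}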
-mod sp_comms; pub use config::Config; pub use context::ServerContext; + use slog::{debug, error, info, o, Logger}; use std::sync::Arc; use uuid::Uuid; diff --git a/gateway/src/sp_comms.rs b/gateway/src/sp_comms.rs deleted file mode 100644 index 4ae57e872a..0000000000 --- a/gateway/src/sp_comms.rs +++ /dev/null @@ -1,737 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at https://mozilla.org/MPL/2.0/. - -// TODO: This may need to move in part or in its entirety to a separate crate -// for reuse by RSS, once that exists. - -//! Inteface for communicating with SPs over UDP on the management network. - -mod bulk_state_get; -mod serial_console_history; - -pub(crate) use self::bulk_state_get::BulkSpStateSingleResult; -pub(crate) use self::bulk_state_get::BulkStateProgress; -pub(crate) use self::bulk_state_get::SpStateRequestId; -pub(crate) use self::serial_console_history::SerialConsoleContents; - -use self::bulk_state_get::OutstandingSpStateRequests; -use self::serial_console_history::SerialConsoleHistory; - -use crate::http_entrypoints::SpState; -use dropshot::HttpError; -use gateway_messages::sp_impl::SerialConsolePacketizer; -use gateway_messages::version; -use gateway_messages::IgnitionCommand; -use gateway_messages::IgnitionState; -use gateway_messages::Request; -use gateway_messages::RequestKind; -use gateway_messages::ResponseError; -use gateway_messages::ResponseKind; -use gateway_messages::SerialConsole; -use gateway_messages::SerializedSize; -use gateway_messages::SpComponent; -use gateway_messages::SpMessage; -use gateway_messages::SpMessageKind; -use gateway_sp_comms::error::StartupError; -use gateway_sp_comms::KnownSps; -use gateway_sp_comms::ManagementSwitch; -use gateway_sp_comms::ManagementSwitchDiscovery; -use gateway_sp_comms::SpIdentifier; -use gateway_sp_comms::SpSocket; -use gateway_sp_comms::SwitchPort; -use slog::debug; -use slog::error; -use slog::info; -use slog::o; -use slog::Logger; -use std::collections::HashMap; -use std::io; -use std::net::SocketAddr; -use std::sync::atomic::AtomicU32; -use std::sync::Arc; -use std::sync::Mutex; -use std::time::Duration; -use thiserror::Error; -use tokio::sync::oneshot; -use tokio::sync::oneshot::error::RecvError; -use tokio::sync::oneshot::Receiver; -use tokio::sync::oneshot::Sender; -use tokio::time::Instant; - -// TODO This has some duplication with `gateway::error::Error`. This might be -// right if these comms are going to move to their own crate, but at the moment -// it's confusing. For now we'll keep them separate, but maybe we should split -// this module out into its own crate sooner rather than later. -#[derive(Debug, Clone, Error)] -pub enum Error { - // ---- - // internal errors - // ---- - // `Error` needs to be `Clone` because we hold onto `Result` - // in memory during bulk "get all SP state" requests, and then we clone - // those results to give to any clients that want info for that request. - // `io::Error` isn't `Clone`, so we wrap it in an `Arc`. 
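The comment above explains why the deleted error type was `Clone` (bulk results are held in memory and handed out to multiple clients, and `io::Error` is not `Clone`). The replacement code takes the other option and shares the whole error as `Arc<SpCommsError>`, as seen in the `sp_list` handler earlier in this diff. A sketch of the two approaches, with made-up error types:

use std::io;
use std::sync::Arc;

// Option taken by the deleted code: the error itself is Clone, with the
// one non-Clone payload wrapped in an Arc.
#[derive(Debug, Clone)]
enum CloneableError {
    UdpSend { err: Arc<io::Error> },
    Timeout,
}

// Option taken by the new code: the error stays non-Clone, and callers
// share the entire thing behind an Arc when they need multiple owners.
#[derive(Debug)]
enum PlainError {
    UdpSend { err: io::Error },
    Timeout,
}
type SharedError = Arc<PlainError>;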
- #[error("error sending to UDP address {addr}: {err}")] - UdpSend { addr: SocketAddr, err: Arc }, - #[error( - "SP sent a bogus response type (got `{got}`; expected `{expected}`)" - )] - BogusResponseType { got: &'static str, expected: &'static str }, - #[error("error from SP: {0}")] - SpError(#[from] ResponseError), - #[error("timeout")] - Timeout, - #[error("received ignition target for an unknown port ({0})")] - UnknownIgnitionTargetPort(usize), - #[error("unknown SP destination address for {0:?}")] - UnknownSpAddress(SwitchPort), - #[error("nonexistent SP {0:?}")] - SpDoesNotExist(SpIdentifier), - - // ---- - // client errors - // ---- - #[error("invalid page token (no such request)")] - NoSuchRequest, - #[error("invalid page token (invalid last SP seen)")] - InvalidLastSpSeen, -} - -impl From for Error { - fn from(_: tokio::time::error::Elapsed) -> Self { - Self::Timeout - } -} - -impl From for HttpError { - fn from(err: Error) -> Self { - match err { - Error::NoSuchRequest => HttpError::for_bad_request( - Some(String::from("NoSuchRequest")), - err.to_string(), - ), - Error::InvalidLastSpSeen => HttpError::for_bad_request( - Some(String::from("InvalidLastSpSeen")), - err.to_string(), - ), - Error::UdpSend { .. } - | Error::BogusResponseType { .. } - | Error::SpError(_) - | Error::Timeout - | Error::UnknownIgnitionTargetPort(_) - | Error::UnknownSpAddress(_) - | Error::SpDoesNotExist(_) => { - HttpError::for_internal_error(err.to_string()) - } - } - } -} - -impl Error { - fn from_unhandled_response_kind( - kind: &ResponseKind, - expected: &'static str, - ) -> Self { - let got = match kind { - ResponseKind::Pong => response_kind_names::PONG, - ResponseKind::IgnitionState(_) => { - response_kind_names::IGNITION_STATE - } - ResponseKind::BulkIgnitionState(_) => { - response_kind_names::BULK_IGNITION_STATE - } - ResponseKind::IgnitionCommandAck => { - response_kind_names::IGNITION_COMMAND_ACK - } - ResponseKind::SpState(_) => response_kind_names::SP_STATE, - ResponseKind::SerialConsoleWriteAck => { - response_kind_names::SERIAL_CONSOLE_WRITE_ACK - } - }; - Self::BogusResponseType { got, expected } - } -} - -// helper constants mapping SP response kinds to stringy names for error -// messages -mod response_kind_names { - pub(super) const PONG: &str = "pong"; - pub(super) const IGNITION_STATE: &str = "ignition_state"; - pub(super) const BULK_IGNITION_STATE: &str = "bulk_ignition_state"; - pub(super) const IGNITION_COMMAND_ACK: &str = "ignition_command_ack"; - pub(super) const SP_STATE: &str = "sp_state"; - pub(super) const SERIAL_CONSOLE_WRITE_ACK: &str = - "serial_console_write_ack"; -} - -#[derive(Debug)] -pub struct SpCommunicator { - log: Logger, - switch: ManagementSwitch, - sp_state: Arc, - bulk_state_requests: OutstandingSpStateRequests, - request_id: AtomicU32, -} - -impl SpCommunicator { - pub async fn new( - known_sps: KnownSps, - log: &Logger, - ) -> Result { - let log = log.new(o!("componennt" => "SpCommunicator")); - let discovery = ManagementSwitchDiscovery::placeholder_start( - known_sps, - log.clone(), - ) - .await?; - - // build our map of switch ports to state-of-the-SP-on-that-port - let sp_state = Arc::new(AllSpState::new(&discovery)); - - let recv_handler = RecvHandler::new(Arc::clone(&sp_state), log.clone()); - let switch = discovery - .start_recv_task(move |port, data| recv_handler.handle(port, data)); - - info!(&log, "started SP communicator"); - Ok(Self { - log, - switch, - sp_state, - bulk_state_requests: OutstandingSpStateRequests::default(), - request_id: 
AtomicU32::new(0), - }) - } - - fn id_to_port(&self, sp: SpIdentifier) -> Result { - self.switch.switch_port(sp).ok_or(Error::SpDoesNotExist(sp)) - } - - pub(crate) fn port_to_id(&self, port: SwitchPort) -> SpIdentifier { - self.switch.switch_port_to_id(port) - } - - pub(crate) async fn state_get( - &self, - sp: SpIdentifier, - timeout: Instant, - ) -> Result { - self.state_get_by_port(self.id_to_port(sp)?, timeout).await - } - - async fn state_get_by_port( - &self, - port: SwitchPort, - timeout: Instant, - ) -> Result { - tokio::time::timeout_at(timeout, self.state_get_impl(port)).await? - } - - pub(crate) fn bulk_state_start( - self: &Arc, - timeout: Instant, - retain_grace_period: Duration, - sps: impl Iterator, - ) -> SpStateRequestId { - self.bulk_state_requests.start(timeout, retain_grace_period, sps, self) - } - - pub(crate) async fn bulk_state_get_progress( - &self, - id: &SpStateRequestId, - last_seen: Option, - timeout: Duration, - limit: usize, - ) -> Result { - let last_seen = match last_seen { - Some(sp) => Some(self.id_to_port(sp)?), - None => None, - }; - self.bulk_state_requests - .get(id, last_seen, timeout, limit, &self.log) - .await - } - - async fn state_get_impl(&self, port: SwitchPort) -> Result { - let sp = - self.switch.sp_socket(port).ok_or(Error::UnknownSpAddress(port))?; - - match self.request_response(&sp, RequestKind::SpState).await? { - ResponseKind::SpState(state) => Ok(SpState::Enabled { - serial_number: hex::encode(&state.serial_number[..]), - }), - other => Err(Error::from_unhandled_response_kind( - &other, - response_kind_names::SP_STATE, - )), - } - } - - pub(crate) fn serial_console_get( - &self, - sp: SpIdentifier, - component: &SpComponent, - ) -> Result, Error> { - let state = self.sp_state.get(self.id_to_port(sp)?); - Ok(state.serial_console_from_sp.lock().unwrap().contents(component)) - } - - pub(crate) async fn serial_console_post( - &self, - sp: SpIdentifier, - component: SpComponent, - data: &[u8], - timeout: Duration, - ) -> Result<(), Error> { - tokio::time::timeout( - timeout, - self.serial_console_post_impl(sp, component, data), - ) - .await? - } - - async fn serial_console_post_impl( - &self, - sp: SpIdentifier, - component: SpComponent, - data: &[u8], - ) -> Result<(), Error> { - let sp_port = self.id_to_port(sp)?; - let sp_state = self.sp_state.get(self.id_to_port(sp)?); - - let sp_sock = self - .switch - .sp_socket(sp_port) - .ok_or(Error::UnknownSpAddress(sp_port))?; - - let mut packetizers = sp_state.serial_console_to_sp.lock().await; - let packetizer = packetizers - .entry(component) - .or_insert_with(|| SerialConsolePacketizer::new(component)); - - for packet in packetizer.packetize(data) { - let request = RequestKind::SerialConsoleWrite(packet); - match self.request_response(&sp_sock, request).await? { - ResponseKind::SerialConsoleWriteAck => (), - other => { - return Err(Error::from_unhandled_response_kind( - &other, - response_kind_names::SERIAL_CONSOLE_WRITE_ACK, - )) - } - } - } - - Ok(()) - } - - pub async fn bulk_ignition_get( - &self, - timeout: Duration, - ) -> Result, Error> { - tokio::time::timeout(timeout, self.bulk_ignition_get_impl()).await? - } - - async fn bulk_ignition_get_impl( - &self, - ) -> Result, Error> { - let controller = self.switch.ignition_controller(); - let request = RequestKind::BulkIgnitionState; - - match self.request_response(&controller, request).await? 
{ - ResponseKind::BulkIgnitionState(state) => { - let mut results = Vec::new(); - for (i, state) in state.targets - [..usize::from(state.num_targets)] - .iter() - .copied() - .enumerate() - { - let port = self - .switch - .switch_port_from_ignition_target(i) - .ok_or_else(|| Error::UnknownIgnitionTargetPort(i))?; - results.push((port, state)); - } - Ok(results) - } - other => { - return Err(Error::from_unhandled_response_kind( - &other, - response_kind_names::BULK_IGNITION_STATE, - )) - } - } - } - - // How do we want to describe ignition targets? Currently we want to - // send a u8 in the UDP message, so just take that for now. - pub async fn ignition_get( - &self, - sp: SpIdentifier, - timeout: Duration, - ) -> Result { - tokio::time::timeout(timeout, self.ignition_get_impl(sp)).await? - } - - async fn ignition_get_impl( - &self, - sp: SpIdentifier, - ) -> Result { - let controller = self.switch.ignition_controller(); - let port = self.id_to_port(sp)?; - let request = - RequestKind::IgnitionState { target: port.as_ignition_target() }; - - match self.request_response(&controller, request).await? { - ResponseKind::IgnitionState(state) => Ok(state), - other => { - return Err(Error::from_unhandled_response_kind( - &other, - response_kind_names::IGNITION_STATE, - )) - } - } - } - - pub async fn ignition_power_on( - &self, - sp: SpIdentifier, - timeout: Duration, - ) -> Result<(), Error> { - tokio::time::timeout( - timeout, - self.ignition_command(sp, IgnitionCommand::PowerOn), - ) - .await? - } - - pub async fn ignition_power_off( - &self, - sp: SpIdentifier, - timeout: Duration, - ) -> Result<(), Error> { - tokio::time::timeout( - timeout, - self.ignition_command(sp, IgnitionCommand::PowerOff), - ) - .await? - } - - async fn ignition_command( - &self, - sp: SpIdentifier, - command: IgnitionCommand, - ) -> Result<(), Error> { - let controller = self.switch.ignition_controller(); - let port = self.id_to_port(sp)?; - let request = RequestKind::IgnitionCommand { - target: port.as_ignition_target(), - command, - }; - - match self.request_response(&controller, request).await? { - ResponseKind::IgnitionCommandAck => Ok(()), - other => { - return Err(Error::from_unhandled_response_kind( - &other, - response_kind_names::IGNITION_COMMAND_ACK, - )) - } - } - } - - async fn request_response( - &self, - sp: &SpSocket<'_>, - request: RequestKind, - ) -> Result { - // request IDs will eventually roll over; since we enforce timeouts - // this should be a non-issue in practice. does this need testing? - let request_id = - self.request_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - - // tell our background receiver to expect a response to this request - let response = - self.sp_state.insert_expected_response(sp.port(), request_id); - - // Serialize and send our request. We know `buf` is large enough for any - // `Request`, so unwrapping here is fine. - let request = - Request { version: version::V1, request_id, kind: request }; - let mut buf = [0; Request::MAX_SIZE]; - let n = gateway_messages::serialize(&mut buf, &request).unwrap(); - - let serialized_request = &buf[..n]; - debug!(&self.log, "sending {:?} to SP {:?}", request, sp); - sp.send(serialized_request).await.map_err(|err| Error::UdpSend { - addr: sp.addr(), - err: Arc::new(err), - })?; - - // recv() can only fail if the sender is dropped, but we're holding it - // in `self.outstanding_requests`; unwrap() is fine. - Ok(response.recv().await.unwrap()?) - } -} - -/// Handler for incoming packets received on management switch ports. 
-/// -/// When the communicator wants to send a request on behalf of an HTTP request: -/// -/// 1. `SpCommunicator` creates a tokio oneshot channel for this handler to use -/// to send the response. -/// 2. `SpCommunicator` inserts the sending half of that channel into -/// `outstanding_requests`, which is keyed by both the switch port and the -/// u32 ID attached to the request. -/// 3. `SpCommunicator` sends the UDP packet containing the request to the -/// target SP, and waits for a response on the channel it created in 1. -/// 4. When we receive a packet, we check: -/// a. Does it parse as a `Response`? -/// b. Is there a corresponding entry in `outstanding_requests` for this port -/// + request ID? -/// If so, we send the response on the channel, which unblocks -/// `SpCommunicator`, who can now return the response to its caller. -/// -/// If a timeout or other error occurs between step 2 and the end of step 4, the -/// `ResponseReceiver` wrapper below is responsible for cleaning up the entry in -/// `outstanding_requests` (via its `Drop` impl). -/// -/// We can also receive messages from SPs that are not responses to oustanding -/// requests. These are handled on a case-by-case basis; e.g., serial console -/// data is pushed into the in-memory ringbuffer corresponding to the source. -struct RecvHandler { - sp_state: Arc, - log: Logger, -} - -impl RecvHandler { - fn new(sp_state: Arc, log: Logger) -> Self { - Self { sp_state, log } - } - - fn handle(&self, port: SwitchPort, buf: &[u8]) { - debug!(&self.log, "received {} bytes from {:?}", buf.len(), port); - - // parse into an `SpMessage` - let sp_msg = match gateway_messages::deserialize::(buf) { - Ok((msg, _extra)) => { - // TODO should we check that `extra` is empty? if the - // response is maximal size any extra data is silently - // discarded anyway, so probably not? - msg - } - Err(err) => { - error!(&self.log, "discarding malformed message ({})", err); - return; - } - }; - debug!(&self.log, "received {:?} from {:?}", sp_msg, port); - - // `version` is intentionally the first 4 bytes of the packet; we - // could check it before trying to deserialize? - if sp_msg.version != version::V1 { - error!( - &self.log, - "discarding message with unsupported version {}", - sp_msg.version - ); - return; - } - - // decide whether this is a response to an outstanding request or an - // unprompted message - match sp_msg.kind { - SpMessageKind::Response { request_id, result } => { - self.handle_response(port, request_id, result); - } - SpMessageKind::SerialConsole(serial_console) => { - self.handle_serial_console(port, serial_console); - } - } - } - - fn handle_response( - &self, - port: SwitchPort, - request_id: u32, - result: Result, - ) { - // see if we know who to send the response to - let tx = match self.sp_state.remove_expected_response(port, request_id) - { - Some(tx) => tx, - None => { - error!( - &self.log, - "discarding unexpected response {} from {:?} (possibly past timeout?)", - request_id, - port, - ); - return; - } - }; - - // actually send it - if tx.send(result).is_err() { - // This can only fail if the receiving half has been dropped. 
- // That's held in the relevant `SpCommunicator` method above - // that initiated this request; they should only have dropped - // the rx half if they've been dropped (in which case we've been - // aborted and can't get here) or if we landed in a race where - // the `SpCommunicator` task was cancelled (presumably by - // timeout) in between us pulling `tx` out of - // `outstanding_requests` and actually sending the response on - // it. But that window does exist, so log when we fail to send. - // I believe these should be interpreted as timeout failures; - // most of the time failing to get a `tx` at all (above) is also - // caused by a timeout, but that path is also invoked if we get - // a garbage response somehow. - error!( - &self.log, - "discarding unexpected response {} from {:?} (receiver gone)", - request_id, - port, - ); - } - } - - fn handle_serial_console(&self, port: SwitchPort, packet: SerialConsole) { - debug!( - &self.log, - "received serial console data from {:?}: {:?}", port, packet - ); - self.sp_state.push_serial_console(port, packet, &self.log); - } -} - -#[derive(Debug)] -struct AllSpState { - all_sps: HashMap, -} - -impl AllSpState { - fn new(switch: &ManagementSwitchDiscovery) -> Self { - let all_ports = switch.all_ports(); - let mut all_sps = HashMap::with_capacity(all_ports.len()); - for port in all_ports { - all_sps.insert(port, SingleSpState::default()); - } - Self { all_sps } - } - - fn get(&self, port: SwitchPort) -> &SingleSpState { - // we initialize `all_sps` with state for every switch port, so the only - // way to panic here is to construct a port that didn't exist when we - // were created - self.all_sps - .get(&port) - .expect("sp state doesn't contain an entry for a valid switch port") - } - - fn push_serial_console( - &self, - port: SwitchPort, - packet: SerialConsole, - log: &Logger, - ) { - let state = self.get(port); - state.serial_console_from_sp.lock().unwrap().push(packet, log); - } - - fn insert_expected_response( - &self, - port: SwitchPort, - request_id: u32, - ) -> ResponseReceiver { - let state = self.get(port); - state.outstanding_requests.insert(request_id) - } - - fn remove_expected_response( - &self, - port: SwitchPort, - request_id: u32, - ) -> Option>> { - let state = self.get(port); - state.outstanding_requests.remove(request_id) - } -} - -#[derive(Debug, Default)] -struct SingleSpState { - // map of requests we're waiting to receive - outstanding_requests: Arc, - // ringbuffer of serial console data from the SP - serial_console_from_sp: Mutex, - // counter of bytes we've sent per SP component; we want to hold this mutex - // across await points as we packetize data, so we have to use a tokio mutex - // here instead of a `std::sync::Mutex` - serial_console_to_sp: - tokio::sync::Mutex>, -} - -#[derive(Debug, Default)] -struct OutstandingRequests { - // map of request ID -> receiving oneshot channel - requests: Mutex>>>, -} - -impl OutstandingRequests { - fn insert(self: &Arc, request_id: u32) -> ResponseReceiver { - let (tx, rx) = oneshot::channel(); - self.requests.lock().unwrap().insert(request_id, tx); - - ResponseReceiver { - parent: Arc::clone(self), - request_id, - rx, - removed_from_parent: false, - } - } - - fn remove( - &self, - request_id: u32, - ) -> Option>> { - self.requests.lock().unwrap().remove(&request_id) - } -} - -// Wrapper around a tokio oneshot receiver that removes itself from its parent -// `OutstandingRequests` either when the message is received (happy path) or -// we're dropped (due to either a timeout or 
some other kind of -// error/cancellation) -struct ResponseReceiver { - parent: Arc, - request_id: u32, - rx: Receiver>, - removed_from_parent: bool, -} - -impl Drop for ResponseReceiver { - fn drop(&mut self) { - self.remove_from_parent(); - } -} - -impl ResponseReceiver { - async fn recv( - mut self, - ) -> Result, RecvError> { - let result = (&mut self.rx).await; - self.remove_from_parent(); - result - } - - fn remove_from_parent(&mut self) { - // we're unconditionally called from our `Drop` impl, but if we already - // successfully received a response we already removed ourselves - if self.removed_from_parent { - return; - } - - self.parent.requests.lock().unwrap().remove(&self.request_id); - self.removed_from_parent = true; - } -}
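The deleted module ends here. The oneshot-per-request correlation scheme it documents (and which presumably survives the move into `gateway-sp-comms`, given the new `RecvHandler` there) reduces to a mutex-protected map of pending senders keyed by request ID. A minimal, self-contained sketch of the idea; the names are illustrative, not the crate's API:

use std::collections::HashMap;
use std::sync::Mutex;
use tokio::sync::oneshot;

// Register a oneshot sender under the request ID before transmitting, and
// have the receive loop route each response to whoever is still waiting.
#[derive(Default)]
struct Outstanding {
    waiting: Mutex<HashMap<u32, oneshot::Sender<Vec<u8>>>>,
}

impl Outstanding {
    // Request path: call this before sending the UDP packet, then await `rx`.
    fn expect(&self, request_id: u32) -> oneshot::Receiver<Vec<u8>> {
        let (tx, rx) = oneshot::channel();
        self.waiting.lock().unwrap().insert(request_id, tx);
        rx
    }

    // Receive loop: a response arrived for `request_id`.
    fn fulfill(&self, request_id: u32, response: Vec<u8>) {
        // If no one is waiting (e.g. the request already timed out), drop it.
        if let Some(tx) = self.waiting.lock().unwrap().remove(&request_id) {
            let _ = tx.send(response);
        }
    }
}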