fix(storage): Updates scalers to do a single data lookup per event (#211)

This works by looking up all the data once per event (whenever scalers
are called). This reduces lookups of links by at least 40% and should
reduce data usage as well (though not as much as doing some sort of
caching would)

Fixes #203

Signed-off-by: Taylor Thomas <[email protected]>
thomastaylor312 authored Nov 2, 2023
1 parent 0354619 commit 54ee1a4
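
To illustrate the approach described in the commit message, here is a minimal, self-contained sketch of the pattern (not wadm code; every name and type below is an illustrative stand-in, and it assumes the tokio runtime): refresh a shared in-memory snapshot once per event, then let every scaler read from that snapshot instead of issuing its own store and link lookups.

use std::sync::Arc;
use tokio::sync::RwLock;

// Illustrative stand-ins for wadm's LinkDefinition and lattice state.
type Link = String;
type HostId = String;

/// One shared snapshot per lattice, refreshed once per event.
#[derive(Clone, Default)]
struct Snapshot {
    links: Arc<RwLock<Vec<Link>>>,
    hosts: Arc<RwLock<Vec<HostId>>>,
}

impl Snapshot {
    /// The single round of lookups for an event (faked here with constants).
    async fn refresh(&self) {
        *self.links.write().await = vec!["actor -> provider".to_string()];
        *self.hosts.write().await = vec!["host-1".to_string()];
    }
}

/// A scaler that only reads the snapshot, never the backing store.
struct CountingScaler {
    data: Snapshot,
}

impl CountingScaler {
    async fn handle_event(&self) -> usize {
        // No KV or control-interface round trips here: both reads hit memory.
        self.data.links.read().await.len() + self.data.hosts.read().await.len()
    }
}

#[tokio::main]
async fn main() {
    let snapshot = Snapshot::default();
    let scalers = vec![
        CountingScaler { data: snapshot.clone() },
        CountingScaler { data: snapshot.clone() },
    ];

    // Per event: refresh once, then let every scaler consume the same data.
    snapshot.refresh().await;
    for scaler in &scalers {
        let _ = scaler.handle_event().await;
    }
}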
Showing 4 changed files with 204 additions and 23 deletions.
58 changes: 36 additions & 22 deletions src/scaler/manager.rs
@@ -27,7 +27,7 @@ use crate::{
publisher::Publisher,
scaler::{spreadscaler::ActorSpreadScaler, Command, Scaler},
server::StatusInfo,
storage::ReadStore,
storage::{snapshot::SnapshotStore, ReadStore},
workers::{CommandPublisher, LinkSource, StatusPublisher},
DEFAULT_LINK_NAME,
};
@@ -115,10 +115,9 @@ pub struct ScalerManager<StateStore, P: Clone, L: Clone> {
client: P,
subject: String,
lattice_id: String,
state_store: StateStore,
command_publisher: CommandPublisher<P>,
status_publisher: StatusPublisher<P>,
link_getter: L,
snapshot_data: SnapshotStore<StateStore, L>,
}

impl<StateStore, P: Clone, L: Clone> Drop for ScalerManager<StateStore, P, L> {
@@ -187,35 +186,39 @@ where
.filter_map(|manifest| manifest.transpose())
.map(|res| res.map(|(manifest, _)| manifest))
.collect::<Result<Vec<_>>>()?;
let snapshot_data = SnapshotStore::new(
state_store.clone(),
link_getter.clone(),
lattice_id.to_owned(),
);
let scalers: HashMap<String, ScalerList> = all_manifests
.into_iter()
.filter_map(|manifest| {
let data = manifest.get_deployed()?;
let name = manifest.name().to_owned();
let scalers = components_to_scalers(
&data.spec.components,
&state_store,
lattice_id,
&client,
&name,
&subject,
&link_getter,
&snapshot_data,
);
Some((name, scalers))
})
.collect();

let scalers = Arc::new(RwLock::new(scalers));

let mut manager = ScalerManager {
handle: None,
scalers,
client,
subject,
lattice_id: lattice_id.to_owned(),
state_store,
command_publisher,
status_publisher,
link_getter,
snapshot_data,
};
let cloned = manager.clone();
let handle = tokio::spawn(async move { cloned.notify(messages).await });
@@ -234,19 +237,29 @@
status_publisher: StatusPublisher<P>,
link_getter: L,
) -> ScalerManager<StateStore, P, L> {
let snapshot_data = SnapshotStore::new(
state_store.clone(),
link_getter.clone(),
lattice_id.to_owned(),
);
ScalerManager {
handle: None,
scalers: Arc::new(RwLock::new(HashMap::new())),
client,
subject: format!("{WADM_NOTIFY_PREFIX}.{lattice_id}"),
lattice_id: lattice_id.to_owned(),
state_store,
command_publisher,
status_publisher,
link_getter,
snapshot_data,
}
}

/// Refreshes the snapshot data consumed by all scalers. This is a temporary workaround until we
/// start caching data
pub(crate) async fn refresh_data(&self) -> Result<()> {
self.snapshot_data.refresh().await
}

/// Adds scalers for the given manifest. Emitting an event to notify other wadm processes that
/// they should create them as well. Only returns an error if it can't notify. Returns the
/// scaler list for immediate use in reconciliation
@@ -275,12 +288,11 @@ where
pub fn scalers_for_manifest<'a>(&'a self, manifest: &'a Manifest) -> ScalerList {
components_to_scalers(
&manifest.spec.components,
&self.state_store,
&self.lattice_id,
&self.client,
&manifest.metadata.name,
&self.subject,
&self.link_getter,
&self.snapshot_data,
)
}

@@ -370,6 +382,10 @@ where
/// Does everything except sending the notification
#[instrument(level = "debug", skip(self), fields(lattice_id = %self.lattice_id))]
async fn remove_scalers_internal(&self, name: &str) -> Option<Result<ScalerList>> {
// Always refresh data before removing
if let Err(e) = self.refresh_data().await {
return Some(Err(e));
}
let scalers = self.remove_raw_scalers(name).await?;
let commands = match futures::future::join_all(
scalers.iter().map(|scaler| scaler.cleanup()),
@@ -417,12 +433,11 @@ where
// We don't want to trigger the notification, so just create the scalers and then insert
let scalers = components_to_scalers(
&manifest.spec.components,
&self.state_store,
&self.lattice_id,
&self.client,
&manifest.metadata.name,
&self.subject,
&self.link_getter,
&self.snapshot_data,
);
let num_scalers = scalers.len();
self.add_raw_scalers(&manifest.metadata.name, scalers).await;
@@ -547,12 +562,11 @@ const EMPTY_TRAIT_VEC: Vec<Trait> = Vec::new();
/// * `name` - The name of the manifest that the scalers are being created for
pub(crate) fn components_to_scalers<S, P, L>(
components: &[Component],
store: &S,
lattice_id: &str,
notifier: &P,
name: &str,
notifier_subject: &str,
link_getter: &L,
snapshot_data: &SnapshotStore<S, L>,
) -> ScalerList
where
S: ReadStore + Send + Sync + Clone + 'static,
@@ -569,7 +583,7 @@
(SPREADSCALER_TRAIT, TraitProperty::SpreadScaler(p)) => {
Some(Box::new(BackoffAwareScaler::new(
ActorSpreadScaler::new(
store.clone(),
snapshot_data.clone(),
props.image.to_owned(),
lattice_id.to_owned(),
name.to_owned(),
@@ -585,7 +599,7 @@
(DAEMONSCALER_TRAIT, TraitProperty::SpreadScaler(p)) => {
Some(Box::new(BackoffAwareScaler::new(
ActorDaemonScaler::new(
store.clone(),
snapshot_data.clone(),
props.image.to_owned(),
lattice_id.to_owned(),
name.to_owned(),
@@ -607,15 +621,15 @@
{
Some(Box::new(BackoffAwareScaler::new(
LinkScaler::new(
store.clone(),
snapshot_data.clone(),
props.image.to_owned(),
cappy.image.to_owned(),
cappy.contract.to_owned(),
cappy.link_name.to_owned(),
lattice_id.to_owned(),
name.to_owned(),
p.values.to_owned(),
link_getter.clone(),
snapshot_data.clone(),
),
notifier.to_owned(),
notifier_subject,
@@ -638,7 +652,7 @@
(SPREADSCALER_TRAIT, TraitProperty::SpreadScaler(p)) => {
Some(Box::new(BackoffAwareScaler::new(
ProviderSpreadScaler::new(
store.clone(),
snapshot_data.clone(),
ProviderSpreadConfig {
lattice_id: lattice_id.to_owned(),
provider_reference: props.image.to_owned(),
@@ -664,7 +678,7 @@
(DAEMONSCALER_TRAIT, TraitProperty::SpreadScaler(p)) => {
Some(Box::new(BackoffAwareScaler::new(
ProviderDaemonScaler::new(
store.clone(),
snapshot_data.clone(),
ProviderSpreadConfig {
lattice_id: lattice_id.to_owned(),
provider_reference: props.image.to_owned(),
@@ -695,7 +709,7 @@ where
// Allow providers to omit the scaler entirely for simplicity
scalers.push(Box::new(BackoffAwareScaler::new(
ProviderSpreadScaler::new(
store.clone(),
snapshot_data.clone(),
ProviderSpreadConfig {
lattice_id: lattice_id.to_owned(),
provider_reference: props.image.to_owned(),
1 change: 1 addition & 0 deletions src/storage/mod.rs
@@ -4,6 +4,7 @@ use std::{collections::HashMap, ops::Deref};

pub mod nats_kv;
pub mod reaper;
pub(crate) mod snapshot;
mod state;

pub use state::{Actor, Host, Provider, ProviderStatus, WadmActorInfo};
161 changes: 161 additions & 0 deletions src/storage/snapshot.rs
@@ -0,0 +1,161 @@
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::RwLock;
use wasmcloud_control_interface::LinkDefinition;

use crate::storage::{Actor, Host, Provider, ReadStore, StateKind};
use crate::workers::LinkSource;

// NOTE(thomastaylor312): This type is real ugly and we should probably find a better way to
// structure the ReadStore trait so it doesn't have the generic T we have to work around here. This
// is essentially a map of "state kind" -> map of ID to partially serialized state. I did try to
// implement some sort of getter trait but it has to be generic across T
type InMemoryData = HashMap<String, HashMap<String, serde_json::Value>>;

/// A store and claims/links source implementation that contains a static snapshot of the data that
/// can be refreshed periodically. Please note that this is scoped to a specific lattice ID and
/// should be constructed separately for each lattice ID.
///
/// NOTE: This is a temporary workaround until we get a proper caching store in place
pub struct SnapshotStore<S, L> {
store: S,
link_source: L,
lattice_id: String,
stored_state: Arc<RwLock<InMemoryData>>,
links: Arc<RwLock<Vec<LinkDefinition>>>,
}

impl<S, L> Clone for SnapshotStore<S, L>
where
S: Clone,
L: Clone,
{
fn clone(&self) -> Self {
Self {
store: self.store.clone(),
link_source: self.link_source.clone(),
lattice_id: self.lattice_id.clone(),
stored_state: self.stored_state.clone(),
links: self.links.clone(),
}
}
}

impl<S, L> SnapshotStore<S, L>
where
S: ReadStore,
L: LinkSource,
{
/// Creates a new snapshot store that is scoped to the given lattice ID
pub fn new(store: S, link_source: L, lattice_id: String) -> Self {
Self {
store,
link_source,
lattice_id,
stored_state: Default::default(),
links: Arc::new(RwLock::new(Vec::new())),
}
}

/// Refreshes the snapshotted data, returning an error if it couldn't update the data
pub async fn refresh(&self) -> anyhow::Result<()> {
// SAFETY: All of these unwraps are safe because we _just_ deserialized from JSON
let providers = self
.store
.list::<Provider>(&self.lattice_id)
.await?
.into_iter()
.map(|(key, val)| (key, serde_json::to_value(val).unwrap()))
.collect::<HashMap<_, _>>();
let actors = self
.store
.list::<Actor>(&self.lattice_id)
.await?
.into_iter()
.map(|(key, val)| (key, serde_json::to_value(val).unwrap()))
.collect::<HashMap<_, _>>();
let hosts = self
.store
.list::<Host>(&self.lattice_id)
.await?
.into_iter()
.map(|(key, val)| (key, serde_json::to_value(val).unwrap()))
.collect::<HashMap<_, _>>();
let links = self.link_source.get_links().await?;

{
let mut stored_state = self.stored_state.write().await;
stored_state.insert(Provider::KIND.to_owned(), providers);
stored_state.insert(Actor::KIND.to_owned(), actors);
stored_state.insert(Host::KIND.to_owned(), hosts);
}

*self.links.write().await = links;

Ok(())
}
}

#[async_trait::async_trait]
impl<S, L> ReadStore for SnapshotStore<S, L>
where
// NOTE(thomastaylor312): We need this bound so we can pass through the error type.
S: ReadStore + Send + Sync,
L: Send + Sync,
{
type Error = S::Error;

// NOTE(thomastaylor312): See other note about the generic T above, but this is hardcore lolsob
async fn get<T>(&self, _lattice_id: &str, id: &str) -> Result<Option<T>, Self::Error>
where
T: serde::de::DeserializeOwned + StateKind,
{
Ok(self
.stored_state
.read()
.await
.get(T::KIND)
.and_then(|data| {
data.get(id).map(|data| {
serde_json::from_value::<T>(data.clone()).expect(
"Failed to deserialize data from snapshot, this is programmer error",
)
})
}))
}

async fn list<T>(&self, _lattice_id: &str) -> Result<HashMap<String, T>, Self::Error>
where
T: serde::de::DeserializeOwned + StateKind,
{
Ok(self
.stored_state
.read()
.await
.get(T::KIND)
.cloned()
.unwrap_or_default()
.into_iter()
.map(|(key, val)| {
(
key,
serde_json::from_value::<T>(val).expect(
"Failed to deserialize data from snapshot, this is programmer error",
),
)
})
.collect())
}
}

#[async_trait::async_trait]
impl<S, L> LinkSource for SnapshotStore<S, L>
where
S: Send + Sync,
L: Send + Sync,
{
async fn get_links(&self) -> anyhow::Result<Vec<LinkDefinition>> {
Ok(self.links.read().await.clone())
}
}
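
Because SnapshotStore implements both ReadStore and LinkSource, it can be handed to anything that previously took the raw store and the link getter separately, which is how the manager.rs changes above use it. A rough in-crate usage sketch (the module is pub(crate), so this would live inside wadm itself; kv and ctl_client are assumed to be existing ReadStore and LinkSource implementations, and "default" is a hypothetical lattice ID):

use crate::storage::{snapshot::SnapshotStore, Actor, ReadStore};
use crate::workers::LinkSource;

async fn snapshot_example<S, L>(kv: S, ctl_client: L) -> anyhow::Result<()>
where
    S: ReadStore + Send + Sync,
    S::Error: std::error::Error + Send + Sync + 'static,
    L: LinkSource + Send + Sync,
{
    let snapshot = SnapshotStore::new(kv, ctl_client, "default".to_owned());

    // One round of lookups against the real store and the lattice...
    snapshot.refresh().await?;

    // ...after which these reads are served from the in-memory copy.
    let actors = snapshot.list::<Actor>("default").await?;
    let links = snapshot.get_links().await?;
    println!("snapshot has {} actors and {} links", actors.len(), links.len());
    Ok(())
}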