Skip to content

Commit

Permalink
Plumb --samples_per_plugin to Rust data server (#4689)
Browse files Browse the repository at this point in the history
Introduces a `--samples-per-plugin` flag on the Rust data server, which
receives the value of the `--samples_per_plugin` flag from the standard TensorBoard CLI.
At the moment, this change causes a value of `0` to be interpreted as
"create a reservoir with capacity 0", whereas traditionally that value is
interpreted as "create a reservoir with unbounded capacity". This
discrepancy will be addressed separately.

Test plan:
Added unit tests and ran with `cargo test`.
Ran `bazel run tensorboard:dev --define=link_data_server=true -- --load_fast --logdir <logdir> --bind_all --samples_per_plugin=scalars=5,images=0` and observed that scalar
charts that normally show lots of points now only show 5, while images do
not appear at all.

Associated issue: "samples per plugin" item in the Rustboard task list #4422.
  • Loading branch information
psybuzz authored Feb 22, 2021
1 parent 143c3a0 commit 319db49
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 29 deletions.
10 changes: 8 additions & 2 deletions tensorboard/data/server/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ limitations under the License.
use clap::Clap;
use log::info;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;

use rustboard_core::cli::dynamic_logdir::DynLogdir;
use rustboard_core::commit::Commit;
use rustboard_core::logdir::LogdirLoader;
use rustboard_core::{cli::dynamic_logdir::DynLogdir, types::PluginSamplingHint};

#[derive(Clap)]
struct Opts {
Expand All @@ -46,7 +47,12 @@ fn main() {

let commit = Commit::new();
let logdir = DynLogdir::new(opts.logdir).expect("DynLogdir::new");
let mut loader = LogdirLoader::new(&commit, logdir, opts.reload_threads.unwrap_or(0));
let mut loader = LogdirLoader::new(
&commit,
logdir,
opts.reload_threads.unwrap_or(0),
Arc::new(PluginSamplingHint::default()),
);
loader.checksum(opts.checksum); // if neither `--[no-]checksum` given, defaults to false

info!("Starting load cycle");
Expand Down
16 changes: 14 additions & 2 deletions tensorboard/data/server/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use std::io::{Read, Write};
use std::net::{IpAddr, SocketAddr};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use tokio::net::TcpListener;
Expand All @@ -32,6 +33,7 @@ use crate::commit::Commit;
use crate::logdir::LogdirLoader;
use crate::proto::tensorboard::data;
use crate::server::DataProviderHandler;
use crate::types::PluginSamplingHint;

use data::tensor_board_data_provider_server::TensorBoardDataProviderServer;

Expand Down Expand Up @@ -114,6 +116,16 @@ struct Opts {
)]
#[allow(unused)]
no_checksum: bool,

/// Set explicit series sampling
///
/// A comma separated list of `plugin_name=num_samples` pairs to explicitly specify how many
/// samples to keep per tag for the specified plugin. For unspecified plugins, series are
/// randomly downsampled to reasonable values to prevent out-of-memory errors in long running
/// jobs. For instance, `--samples_per_plugin=scalars=500,images=0` keeps 500 events in each
/// scalar series and keeps none of the images.
#[clap(long, default_value = "", setting(clap::ArgSettings::AllowEmptyValues))]
samples_per_plugin: PluginSamplingHint,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
Expand Down Expand Up @@ -174,7 +186,7 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Leak the commit object, since the Tonic server must have only 'static references. This only
// leaks the outer commit structure (of constant size), not the pointers to the actual data.
let commit: &'static Commit = Box::leak(Box::new(Commit::new()));

let psh_ref = Arc::new(opts.samples_per_plugin);
thread::Builder::new()
.name("Reloader".to_string())
.spawn({
Expand All @@ -185,7 +197,7 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create the logdir in the child thread, where no async runtime is active (see
// docs for `DynLogdir::new`).
let logdir = DynLogdir::new(raw_logdir).unwrap_or_else(|| std::process::exit(1));
let mut loader = LogdirLoader::new(commit, logdir, 0);
let mut loader = LogdirLoader::new(commit, logdir, 0, psh_ref);
// Checksum only if `--checksum` given (i.e., off by default).
loader.checksum(checksum);
loop {
Expand Down
28 changes: 21 additions & 7 deletions tensorboard/data/server/logdir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ use rayon::prelude::{IntoParallelIterator, ParallelIterator};
use std::collections::HashMap;
use std::io::{self, Read};
use std::path::PathBuf;
use std::sync::Arc;

use crate::commit::Commit;
use crate::run::RunLoader;
use crate::types::Run;
use crate::types::{PluginSamplingHint, Run};

/// A TensorBoard log directory, with event files organized into runs.
pub trait Logdir {
Expand Down Expand Up @@ -75,6 +76,8 @@ pub struct LogdirLoader<'a, L: Logdir> {
runs: HashMap<Run, RunLoader<<L as Logdir>::File>>,
/// Whether new run loaders should unconditionally verify CRCs (see [`RunLoader::checksum`]).
checksum: bool,
/// A map defining how many samples per plugin to keep.
plugin_sampling_hint: Arc<PluginSamplingHint>,
}

type Discoveries = HashMap<Run, Vec<EventFileBuf>>;
Expand All @@ -94,7 +97,12 @@ where
///
/// If [`rayon::ThreadPoolBuilder::build`] returns an error; should only happen if there is a
/// failure to create threads at the OS level.
pub fn new(commit: &'a Commit, logdir: L, reload_threads: usize) -> Self {
pub fn new(
commit: &'a Commit,
logdir: L,
reload_threads: usize,
plugin_sampling_hint: Arc<PluginSamplingHint>,
) -> Self {
let thread_pool = rayon::ThreadPoolBuilder::new()
.num_threads(reload_threads)
.thread_name(|i| format!("Reloader-{:03}", i))
Expand All @@ -106,6 +114,7 @@ where
logdir,
runs: HashMap::new(),
checksum: true,
plugin_sampling_hint,
}
}

Expand Down Expand Up @@ -176,8 +185,9 @@ where
// Add new runs.
for run_name in discoveries.keys() {
let checksum = self.checksum;
let plugin_sampling_hint = self.plugin_sampling_hint.clone();
self.runs.entry(run_name.clone()).or_insert_with(|| {
let mut loader = RunLoader::new(run_name.clone());
let mut loader = RunLoader::new(run_name.clone(), plugin_sampling_hint);
loader.checksum(checksum);
loader
});
Expand Down Expand Up @@ -275,7 +285,8 @@ mod tests {

let commit = Commit::new();
let logdir = DiskLogdir::new(logdir.path().to_path_buf());
let mut loader = LogdirLoader::new(&commit, logdir, 1);
let mut loader =
LogdirLoader::new(&commit, logdir, 1, Arc::new(PluginSamplingHint::default()));

// Check that we persist the right run states in the loader.
loader.reload();
Expand Down Expand Up @@ -330,7 +341,8 @@ mod tests {

let commit = Commit::new();
let logdir = DiskLogdir::new(logdir.path().to_path_buf());
let mut loader = LogdirLoader::new(&commit, logdir, 1);
let mut loader =
LogdirLoader::new(&commit, logdir, 1, Arc::new(PluginSamplingHint::default()));

let get_run_names = || {
let runs_store = commit.runs.read().unwrap();
Expand Down Expand Up @@ -381,7 +393,8 @@ mod tests {

let commit = Commit::new();
let logdir = DiskLogdir::new(logdir.path().to_path_buf());
let mut loader = LogdirLoader::new(&commit, logdir, 1);
let mut loader =
LogdirLoader::new(&commit, logdir, 1, Arc::new(PluginSamplingHint::default()));
loader.reload();

assert_eq!(
Expand All @@ -404,7 +417,8 @@ mod tests {

let commit = Commit::new();
let logdir = DiskLogdir::new(logdir.path().to_path_buf());
let mut loader = LogdirLoader::new(&commit, logdir, 1);
let mut loader =
LogdirLoader::new(&commit, logdir, 1, Arc::new(PluginSamplingHint::default()));
loader.reload(); // should not hang
Ok(())
}
Expand Down
Loading

0 comments on commit 319db49

Please sign in to comment.