Add psql download progress
hlinander committed Jan 18, 2024
1 parent 98a9da1 commit ef04a46
142 changes: 109 additions & 33 deletions rust/vis/src/main.rs
@@ -100,8 +100,19 @@ enum NPYArray {
Error(String),
}

#[derive(Clone)]
struct DownloadProgressStatus {
downloaded: usize,
size: usize,
}

struct DownloadProgress {
rx_update: Receiver<DownloadProgressStatus>,
status: DownloadProgressStatus,
}

enum BinaryArtifact {
Loading(JoinHandle<Result<Vec<u8>, String>>),
Loading((JoinHandle<Result<Vec<u8>, String>>, DownloadProgress)),
Loaded(Vec<u8>),
Error(String),
}
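
Note: the new Loading variant pairs the download task's JoinHandle with a DownloadProgress handle, so the UI can read byte counts while the task runs. A minimal sketch of that pattern, with illustrative names that are not part of the commit (assumes a running tokio runtime and std::sync::mpsc):

use std::sync::mpsc::{self, Receiver};
use tokio::task::JoinHandle;

// Illustrative only: the spawned task keeps the Sender and reports
// (downloaded, total) byte counts; the caller keeps the Receiver paired
// with the JoinHandle, mirroring BinaryArtifact::Loading above.
fn spawn_with_progress() -> (JoinHandle<Result<Vec<u8>, String>>, Receiver<(usize, usize)>) {
    let (tx, rx) = mpsc::channel::<(usize, usize)>();
    let handle = tokio::spawn(async move {
        let _ = tx.send((0, 0)); // send updates as bytes arrive
        Ok(Vec::new())
    });
    (handle, rx)
}
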
@@ -111,9 +122,10 @@ fn download_artifact(
name: String,
path: String,
tx_path_mutex: Arc<Mutex<Sender<(String, String, String)>>>,
rx_artifact_mutex: Arc<Mutex<Receiver<Result<Vec<u8>, String>>>>,
rx_artifact_mutex: Arc<Mutex<Receiver<ArtifactTransfer>>>,
) -> BinaryArtifact {
BinaryArtifact::Loading(tokio::spawn(async move {
let (tx_update, rx_update) = mpsc::channel::<DownloadProgressStatus>();
let join_handle = tokio::spawn(async move {
let uri = if !path.starts_with("/") {
println!("path: {}", path);
let end = path.split("results/").last().unwrap();
@@ -138,32 +150,42 @@
// name.clone()
// );
let rx_db_artifact = rx_artifact_mutex.lock_owned().await;
let rx_res = rx_db_artifact.recv();
// println!(
// "[db] Download complete {} {}",
// train_id.clone(),
// name.clone()
// );
match rx_res {
Ok(artifact_binary_res) => match artifact_binary_res {
Ok(artifact_binary) => {
return Ok(artifact_binary);
}
Err(artifact_binary_err) => {
return Err(artifact_binary_err.to_string());
loop {
let rx_res = rx_db_artifact.recv();
match rx_res {
Ok(artifact_binary_res) => match artifact_binary_res {
ArtifactTransfer::Done(artifact_binary) => {
return Ok(artifact_binary);
}
ArtifactTransfer::Err(artifact_binary_err) => {
return Err(artifact_binary_err.to_string());
}
ArtifactTransfer::Loading(downloaded, size) => {
tx_update.send(DownloadProgressStatus { downloaded, size });
}
},
Err(err) => {
return Err(err.to_string());
}
},
Err(err) => {
return Err(err.to_string());
}
}
}))
});
BinaryArtifact::Loading((
join_handle,
DownloadProgress {
rx_update,
status: DownloadProgressStatus {
downloaded: 0,
size: 0,
},
},
))
}
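
Note: the receive loop above blocks on a std::sync::mpsc recv() inside an async task, which parks a tokio worker thread until the next message arrives. That is workable for a dedicated transfer task; the same loop can also be written against tokio::sync::mpsc so the task yields while waiting. A hedged sketch of that alternative (illustrative names, not the commit's code):

use tokio::sync::mpsc;

// Illustrative alternative receive loop using tokio's async channel;
// Transfer mirrors the commit's ArtifactTransfer states.
enum Transfer {
    Done(Vec<u8>),
    Loading(usize, usize),
    Err(String),
}

async fn receive_artifact(mut rx: mpsc::Receiver<Transfer>) -> Result<Vec<u8>, String> {
    while let Some(msg) = rx.recv().await {
        match msg {
            Transfer::Done(bytes) => return Ok(bytes),
            Transfer::Err(e) => return Err(e),
            Transfer::Loading(done, total) => {
                // forward (done, total) to a progress channel here
                let _ = (done, total);
            }
        }
    }
    Err("artifact channel closed".to_string())
}
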

fn poll_artifact_download(binary_artifact: &mut BinaryArtifact) {
let mut new_binary_artifact: Option<BinaryArtifact> = None;
match binary_artifact {
BinaryArtifact::Loading(join_handle) => {
BinaryArtifact::Loading((join_handle, download_progress)) => {
// println!("[ARTIFACTS] Loading ....");
if join_handle.is_finished() {
let data_res = tokio::runtime::Handle::current()
@@ -175,6 +197,8 @@ fn poll_artifact_download(binary_artifact: &mut BinaryArtifact) {
}
Err(err) => new_binary_artifact = Some(BinaryArtifact::Error(err.to_string())),
}
} else if let Ok(download_status) = download_progress.rx_update.try_recv() {
download_progress.status = download_status;
}
}
BinaryArtifact::Loaded(_) => {}
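
Note: because poll_artifact_download runs on the UI side, progress is read with a non-blocking try_recv, and the commit consumes at most one status message per poll. A common variation is to drain the queue so only the newest status survives a frame; a minimal sketch (helper name is hypothetical):

use std::sync::mpsc::Receiver;

// Keep only the most recent status; never blocks the UI thread.
fn poll_latest<T>(rx: &Receiver<T>, current: &mut T) {
    while let Ok(next) = rx.try_recv() {
        *current = next;
    }
}
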
@@ -204,7 +228,7 @@ fn add_artifact(
path: &str,
args: &Args,
tx_path_mutex: &mut Arc<Mutex<Sender<(String, String, String)>>>,
rx_artifact_mutex: &Arc<Mutex<Receiver<Result<Vec<u8>, String>>>>,
rx_artifact_mutex: &Arc<Mutex<Receiver<ArtifactTransfer>>>,
) {
let artifact_id = ArtifactId {
train_id: run_id.to_string(),
@@ -329,7 +353,7 @@ struct GuiRuns {
recomputed_reciever: Receiver<HashMap<String, Run>>,
dirty_sender: Sender<(GuiParams, HashMap<String, Run>)>,
tx_db_artifact_path: Arc<Mutex<Sender<(String, String, String)>>>,
rx_db_artifact: Arc<Mutex<Receiver<Result<Vec<u8>, String>>>>,
rx_db_artifact: Arc<Mutex<Receiver<ArtifactTransfer>>>,
initialized: bool,
data_status: DataStatus,
gui_params: GuiParams,
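
Note: GuiRuns keeps both artifact channel endpoints behind Arc<tokio::sync::Mutex<...>> so they can be handed to spawned tasks; lock_owned() in download_artifact then gives one task exclusive use of the receiver for a whole transfer. A small sketch of the construction, assuming tokio's Mutex as in the fields above:

use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::Arc;
use tokio::sync::Mutex;

// Wrap a std channel endpoint so async tasks can share it; a task calls
// endpoint.clone().lock_owned().await to take exclusive ownership.
fn shared_endpoint<T>() -> (Sender<T>, Arc<Mutex<Receiver<T>>>) {
    let (tx, rx) = mpsc::channel();
    (tx, Arc::new(Mutex::new(rx)))
}
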
@@ -665,8 +689,22 @@ fn show_artifacts(
ui.allocate_space(egui::Vec2::new(ui.available_width(), 0.0));
for (artifact_id, array) in array_group {
match array {
NPYArray::Loading(_) => {
ui.label("loading...");
NPYArray::Loading(binary_artifact) => {
if let BinaryArtifact::Loading((_, status)) =
binary_artifact
{
if status.status.size > 0 {
ui.label(format!(
"{:.1}/{:.1}",
status.status.downloaded as f32 / 1e6,
status.status.size as f32 / 1e6
));
ui.add(egui::ProgressBar::new(
status.status.downloaded as f32
/ status.status.size as f32,
));
}
}
}
NPYArray::Loaded(array) => {
// ui.allocate_ui()
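
Note: the loading branch above renders downloaded/size in megabytes next to an egui::ProgressBar once the total is known (the size > 0 check guards the first updates, before the size has arrived). The counts can also be drawn on the bar itself with ProgressBar::text; a sketch assuming a (downloaded, size) byte pair (helper name is hypothetical):

// Hypothetical helper: one progress row with the MB counts drawn on the bar.
fn progress_row(ui: &mut egui::Ui, downloaded: usize, size: usize) {
    if size > 0 {
        ui.add(egui::ProgressBar::new(downloaded as f32 / size as f32).text(
            format!("{:.1}/{:.1} MB", downloaded as f32 / 1e6, size as f32 / 1e6),
        ));
    }
}
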
@@ -1587,6 +1625,12 @@ fn parse_metric_rows(
}
}
}

enum ArtifactTransfer {
Done(Vec<u8>),
Loading(usize, usize),
Err(String),
}
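
Note: ArtifactTransfer is the wire protocol between the database task and the downloader: zero or more Loading(downloaded, size) messages, then exactly one Done or Err. A sender-side sketch of that contract, reusing the enum above (function name is illustrative):

use std::sync::mpsc::Sender;

// Illustrative: stream a byte buffer as Loading updates, then finish
// with Done, matching the order the receive loop expects.
fn send_in_chunks(tx: &Sender<ArtifactTransfer>, data: Vec<u8>) {
    let total = data.len();
    let mut sent = 0;
    for chunk in data.chunks(1_000_000) {
        sent += chunk.len();
        let _ = tx.send(ArtifactTransfer::Loading(sent, total));
    }
    let _ = tx.send(ArtifactTransfer::Done(data));
}
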
// #[tokio::main(flavor = "current_thread")]
fn main() -> Result<(), sqlx::Error> {
// Load environment variables
@@ -1597,7 +1641,7 @@ fn main() -> Result<(), sqlx::Error> {
let (tx_gui_dirty, rx_gui_dirty) = mpsc::channel();
let (tx_gui_recomputed, rx_gui_recomputed) = mpsc::channel();
let (tx_db_filters, rx_db_filters) = mpsc::channel();
let (tx_db_artifact, rx_db_artifact) = mpsc::channel::<Result<Vec<u8>, String>>();
let (tx_db_artifact, rx_db_artifact) = mpsc::channel::<ArtifactTransfer>();
let (tx_db_artifact_path, rx_db_artifact_path) = mpsc::channel::<(String, String, String)>();
let rt = tokio::runtime::Runtime::new().unwrap();
let rt_handle = rt.handle().clone();
@@ -1659,18 +1703,50 @@
loop {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
if let Ok((train_id, name, path)) = rx_db_artifact_path.try_recv() {
let blobs_rows = sqlx::query(
let size_res = sqlx::query(
r#"
SELECT pg_read_binary_file($1)
SELECT (pg_stat_file($1)).size as size
"#,
)
.bind(path)
.bind(&path)
.fetch_one(&pool)
.await;
if let Ok(row) = blobs_rows {
tx_db_artifact.send(Ok(row.get::<Vec<u8>, _>("pg_read_binary_file")));
} else if let Err(error) = blobs_rows {
tx_db_artifact.send(Err(error.to_string()));
if let Ok(row) = size_res {
let filesize = row.get::<i64, _>("size");
println!("[f] File {} size {}", &path, &filesize);
let mut offset = 0;
let chunk_size = 1_000_000;
let mut buffer: Vec<u8> = vec![0; filesize as usize];
while offset < filesize {
let length = chunk_size.min(filesize - offset);
let blobs_rows = sqlx::query(
r#"
SELECT pg_read_binary_file($1, $2, $3)
"#,
)
.bind(&path)
.bind(&offset)
.bind(&length)
.fetch_one(&pool)
.await;
if let Ok(row) = blobs_rows {
let dst =
&mut buffer[offset as usize..offset as usize + length as usize];
dst.copy_from_slice(
row.get::<Vec<u8>, _>("pg_read_binary_file").as_slice(),
);
offset += length;
println!("[f] Read chunk {} at {}", length, offset);
tx_db_artifact.send(ArtifactTransfer::Loading(
offset as usize,
filesize as usize,
));
} else if let Err(error) = blobs_rows {
tx_db_artifact.send(ArtifactTransfer::Err(error.to_string()));
break;
}
}
tx_db_artifact.send(ArtifactTransfer::Done(buffer));
}
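
Note: the database task first sizes the file with pg_stat_file, then pulls it in 1 MB slices via pg_read_binary_file(filename, offset, length), emitting a Loading update after each slice and Done once the buffer is complete. A single chunk read in isolation looks roughly like this (hedged sketch; assumes an sqlx PgPool and the same Postgres file-access functions):

use sqlx::{PgPool, Row};

// Illustrative single-chunk read against the Postgres file-access API.
async fn read_chunk(
    pool: &PgPool,
    path: &str,
    offset: i64,
    length: i64,
) -> Result<Vec<u8>, sqlx::Error> {
    let row = sqlx::query("SELECT pg_read_binary_file($1, $2, $3) AS chunk")
        .bind(path)
        .bind(offset)
        .bind(length)
        .fetch_one(pool)
        .await?;
    Ok(row.get::<Vec<u8>, _>("chunk"))
}
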
// let blobs_rows = sqlx::query(
// r#"
