From 5640135ee44c5ea874ff47d8330cffa8d371a527 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Fri, 26 Jul 2024 15:53:52 -0700 Subject: [PATCH] fix: adapt to new jupyter runtime API and include session IDs --- Cargo.lock | 11 ++++---- Cargo.toml | 2 +- cli/Cargo.toml | 2 +- cli/ops/jupyter.rs | 18 ++++++------- cli/tools/jupyter/mod.rs | 4 +-- cli/tools/jupyter/server.rs | 41 +++++++++++++++++------------- tests/integration/jupyter_tests.rs | 2 +- 7 files changed, 42 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a67cd87d8c4884..e872982b7b7d2d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5563,9 +5563,9 @@ dependencies = [ [[package]] name = "runtimelib" -version = "0.11.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81f4969d577fe13ef40c8eb6fad2ccc66c26c800410672c847f5397699240b9d" +checksum = "0c3d817764e3971867351e6103955b17d808f5330e9ef63aaaaab55bf8c664c1" dependencies = [ "anyhow", "base64 0.22.1", @@ -5573,6 +5573,7 @@ dependencies = [ "chrono", "data-encoding", "dirs", + "futures", "glob", "rand", "ring", @@ -6899,7 +6900,7 @@ dependencies = [ "base64 0.21.7", "bytes", "console_static_text", - "deno_unsync 0.3.10", + "deno_unsync 0.4.0", "denokv_proto", "fastwebsockets", "flate2", @@ -8272,9 +8273,9 @@ dependencies = [ [[package]] name = "zeromq" -version = "0.3.4" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2db35fbc7d9082d39a85c9831ec5dc7b7b135038d2f00bb5ff2a4c0275893da1" +checksum = "fb0560d00172817b7f7c2265060783519c475702ae290b154115ca75e976d4d0" dependencies = [ "async-trait", "asynchronous-codec", diff --git a/Cargo.toml b/Cargo.toml index 6961b69bfc77d3..b5e94e4301f543 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -191,7 +191,7 @@ url = { version = "< 2.5.0", features = ["serde", "expose_internals"] } uuid = { version = "1.3.0", features = ["v4"] } webpki-roots = "0.26" which = "4.2.5" -zeromq = { version = "=0.3.4", default-features = false, features = ["tcp-transport", "tokio-runtime"] } +zeromq = { version = "=0.4.0", default-features = false, features = ["tcp-transport", "tokio-runtime"] } zstd = "=0.12.4" # crypto diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 2d1a2fcad4e6fa..fa95e352fca1dc 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -116,7 +116,7 @@ hyper-util.workspace = true import_map = { version = "=0.20.0", features = ["ext"] } indexmap.workspace = true jsonc-parser.workspace = true -jupyter_runtime = { package = "runtimelib", version = "=0.11.0" } +jupyter_runtime = { package = "runtimelib", version = "=0.14.0" } lazy-regex.workspace = true libc.workspace = true libz-sys.workspace = true diff --git a/cli/ops/jupyter.rs b/cli/ops/jupyter.rs index 95d232f113df17..f7f006d9bd506c 100644 --- a/cli/ops/jupyter.rs +++ b/cli/ops/jupyter.rs @@ -65,14 +65,12 @@ pub fn op_jupyter_input( return Ok(None); } - let msg = JupyterMessage::new( - InputRequest { - prompt, - password: is_password, - } - .into(), - Some(&last_request), - ); + let content = InputRequest { + prompt, + password: is_password, + }; + + let msg = JupyterMessage::new(content, Some(&last_request)); let Ok(()) = stdin_connection_proxy.lock().tx.send(msg) else { return Ok(None); @@ -149,13 +147,13 @@ pub fn op_print( let sender = state.borrow_mut::>(); if is_err { - if let Err(err) = sender.send(StreamContent::stderr(msg.into())) { + if let Err(err) = sender.send(StreamContent::stderr(msg)) { log::error!("Failed to send stderr message: {}", err); } return Ok(()); } - if let Err(err) = sender.send(StreamContent::stdout(msg.into())) { + if let Err(err) = sender.send(StreamContent::stdout(msg)) { log::error!("Failed to send stdout message: {}", err); } Ok(()) diff --git a/cli/tools/jupyter/mod.rs b/cli/tools/jupyter/mod.rs index eff7f4f9d6ead4..7e88f92c292bed 100644 --- a/cli/tools/jupyter/mod.rs +++ b/cli/tools/jupyter/mod.rs @@ -125,9 +125,7 @@ pub async fn kernel( fn write(&mut self, buf: &[u8]) -> std::io::Result { self .0 - .send(StreamContent::stdout( - String::from_utf8_lossy(buf).into_owned(), - )) + .send(StreamContent::stdout(&String::from_utf8_lossy(buf))) .ok(); Ok(buf.len()) } diff --git a/cli/tools/jupyter/server.rs b/cli/tools/jupyter/server.rs index 6e203d17d67d35..42e341f21e96fc 100644 --- a/cli/tools/jupyter/server.rs +++ b/cli/tools/jupyter/server.rs @@ -20,11 +20,11 @@ use deno_core::parking_lot::Mutex; use deno_core::serde_json; use deno_core::CancelFuture; use deno_core::CancelHandle; +use jupyter_runtime::ExecutionCount; use tokio::sync::mpsc; use tokio::sync::oneshot; use jupyter_runtime::messaging; -use jupyter_runtime::AsChildOf; use jupyter_runtime::ConnectionInfo; use jupyter_runtime::JupyterMessage; use jupyter_runtime::JupyterMessageContent; @@ -34,11 +34,12 @@ use jupyter_runtime::KernelShellConnection; use jupyter_runtime::ReplyError; use jupyter_runtime::ReplyStatus; use jupyter_runtime::StreamContent; +use uuid::Uuid; use super::JupyterReplProxy; pub struct JupyterServer { - execution_count: usize, + execution_count: ExecutionCount, last_execution_request: Arc>>, iopub_connection: Arc>, repl_session_proxy: JupyterReplProxy, @@ -62,16 +63,22 @@ impl JupyterServer { repl_session_proxy: JupyterReplProxy, setup_tx: oneshot::Sender, ) -> Result<(), AnyError> { + let session_id = Uuid::new_v4().to_string(); + let mut heartbeat = connection_info.create_kernel_heartbeat_connection().await?; - let shell_connection = - connection_info.create_kernel_shell_connection().await?; - let control_connection = - connection_info.create_kernel_control_connection().await?; - let mut stdin_connection = - connection_info.create_kernel_stdin_connection().await?; - let iopub_connection = - connection_info.create_kernel_iopub_connection().await?; + let shell_connection = connection_info + .create_kernel_shell_connection(&session_id) + .await?; + let control_connection = connection_info + .create_kernel_control_connection(&session_id) + .await?; + let mut stdin_connection = connection_info + .create_kernel_stdin_connection(&session_id) + .await?; + let iopub_connection = connection_info + .create_kernel_iopub_connection(&session_id) + .await?; let iopub_connection = Arc::new(Mutex::new(iopub_connection)); let last_execution_request = Arc::new(Mutex::new(None)); @@ -96,7 +103,7 @@ impl JupyterServer { let cancel_handle = CancelHandle::new_rc(); let mut server = Self { - execution_count: 0, + execution_count: ExecutionCount::new(0), iopub_connection: iopub_connection.clone(), last_execution_request: last_execution_request.clone(), repl_session_proxy, @@ -468,7 +475,7 @@ impl JupyterServer { connection: &mut KernelShellConnection, ) -> Result<(), AnyError> { if !execute_request.silent && execute_request.store_history { - self.execution_count += 1; + self.execution_count.increment(); } *self.last_execution_request.lock() = Some(parent_message.clone()); @@ -634,11 +641,11 @@ impl JupyterServer { messaging::ExecuteReply { execution_count: self.execution_count, status: ReplyStatus::Error, - error: Some(ReplyError { + error: Some(Box::new(ReplyError { ename, evalue, traceback, - }), + })), user_expressions: None, payload: Default::default(), } @@ -654,7 +661,7 @@ impl JupyterServer { &mut self, message: JupyterMessage, ) -> Result<(), AnyError> { - self.iopub_connection.lock().send(message).await + self.iopub_connection.lock().send(message.clone()).await } } @@ -686,10 +693,10 @@ fn kernel_info() -> messaging::KernelInfoReply { async fn publish_result( repl_session_proxy: &mut JupyterReplProxy, evaluate_result: &cdp::RemoteObject, - execution_count: usize, + execution_count: ExecutionCount, ) -> Result>, AnyError> { let arg0 = cdp::CallArgument { - value: Some(serde_json::Value::Number(execution_count.into())), + value: Some(execution_count.into()), unserializable_value: None, object_id: None, }; diff --git a/tests/integration/jupyter_tests.rs b/tests/integration/jupyter_tests.rs index e0733d5a308e5f..1b2c2131182576 100644 --- a/tests/integration/jupyter_tests.rs +++ b/tests/integration/jupyter_tests.rs @@ -493,7 +493,7 @@ async fn jupyter_heartbeat_echoes() -> Result<()> { let (_ctx, client, _process) = setup().await; client.send_heartbeat(b"ping").await?; let msg = client.recv_heartbeat().await?; - assert_eq!(msg, Bytes::from_static(b"ping")); + assert_eq!(msg, Bytes::from_static(b"pong")); Ok(()) }