diff --git a/newsfragments/3920.added.md b/newsfragments/3920.added.md new file mode 100644 index 00000000000..e19b0eb0e77 --- /dev/null +++ b/newsfragments/3920.added.md @@ -0,0 +1 @@ +Add `pyo3::stdio::stdout` and `pyo3::stdio::stderr` to enable direct print to python `sys.stdout` and `sys.stderr`. \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index c2591cec91e..934483aad33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -430,6 +430,7 @@ pub mod impl_; mod instance; pub mod marker; pub mod marshal; +pub mod stdio; #[macro_use] pub mod sync; pub mod panic; diff --git a/src/stdio.rs b/src/stdio.rs new file mode 100644 index 00000000000..b994a158783 --- /dev/null +++ b/src/stdio.rs @@ -0,0 +1,87 @@ +//! Enables direct write access to I/O streams in Python's `sys` module. + +//! In some cases printing to Rust's `std::io::stdout` or `std::io::stderr` will not appear +//! in the Python interpreter, e.g. in Jupyter notebooks. This module provides a way to write +//! directly to Python's I/O streams from Rust in such cases. + +//! ```rust +//! let mut stdout = pyo3::stdio::stdout(); +//! +//! // This may not appear in Jupyter notebooks... +//! println!("Hello, world!"); +//! +//! // ...but this will. +//! writeln!(stdout, "Hello, world!").unwrap(); +//! ``` + +use crate::intern; +use crate::prelude::*; +use crate::types::PyString; +use std::io::{LineWriter, Write}; + +/// Implements `std::io::Write` for a `PyAny` object. The underlying +/// Python object must provide both `write` and `flush` methods. +pub struct PyWriter(Py); + +impl PyWriter { + /// Construct a new `PyWriter` from a `PyAny` object. + pub fn buffered(self) -> LineWriter { + LineWriter::new(self) + } +} + +/// A GIL-attached equivalent to PyWriter. +pub struct PyWriterBound<'a, 'py>(&'a Bound<'py, PyAny>); + +fn get_stdio_pywriter(stream: &str) -> PyWriter { + Python::with_gil(|py| { + let module = PyModule::import_bound(py, "sys").unwrap(); + let stream = module.getattr(stream).unwrap(); + PyWriter(stream.into()) + }) +} + +/// Construct a new `PyWriter` for Python's `sys.stdout` stream. +pub fn stdout() -> PyWriter { + get_stdio_pywriter("stdout") +} + +/// Construct a new `PyWriter` for Python's `sys.stderr` stream. +pub fn stderr() -> PyWriter { + get_stdio_pywriter("stderr") +} + +/// Construct a new `PyWriter` for Python's `sys.__stdout__` stream. +pub fn __stdout__() -> PyWriter { + get_stdio_pywriter("__stdout__") +} + +/// Construct a new `PyWriter` for Python's `sys.__stderr__` stream. +pub fn __stderr__() -> PyWriter { + get_stdio_pywriter("__stderr__") +} + +impl Write for PyWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + Python::with_gil(|py| PyWriterBound(self.0.bind(py)).write(buf)) + } + fn flush(&mut self) -> std::io::Result<()> { + Python::with_gil(|py| PyWriterBound(self.0.bind(py)).flush()) + } +} + +impl Write for PyWriterBound<'_, '_> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let str = PyString::new_bound(self.0.py(), &String::from_utf8_lossy(buf)); + self.0 + .call_method1(intern!(self.0.py(), "write"), (str,)) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + Ok(buf.len()) + } + fn flush(&mut self) -> std::io::Result<()> { + self.0 + .call_method0(intern!(self.0.py(), "flush")) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + Ok(()) + } +} diff --git a/tests/test_stdio.rs b/tests/test_stdio.rs new file mode 100644 index 00000000000..aba353bdc18 --- /dev/null +++ b/tests/test_stdio.rs @@ -0,0 +1,47 @@ +#![cfg(feature = "macros")] + +use pyo3::prelude::*; + +#[macro_use] +#[path = "../src/tests/common.rs"] +mod common; + +#[test] +fn test_stdio() { + use pyo3::stdio::*; + use pyo3::types::IntoPyDict; + use std::io::Write; + + let stream_fcns = [stdout, stderr, __stdout__, __stderr__]; + let stream_names = ["stdout", "stderr", "__stdout__", "__stderr__"]; + + for (stream_fcn, stream_name) in stream_fcns.iter().zip(stream_names.iter()) { + Python::with_gil(|py| { + py.run_bound("import sys, io", None, None).unwrap(); + + // redirect stdout or stderr output to a StringIO object + let target = py.eval_bound("io.StringIO()", None, None).unwrap(); + let locals = [("target", &target)].into_py_dict_bound(py); + py.run_bound( + &format!("sys.{} = target", stream_name), + None, + Some(&locals), + ) + .unwrap(); + + let mut stream = stream_fcn(); + assert!(writeln!(stream, "writing to {}", stream_name).is_ok()); + + Python::run_bound( + py, + &format!( + "assert target.getvalue() == 'writing to {}\\n'", + stream_name + ), + Some(&locals), + None, + ) + .unwrap(); + }); + } +}