Skip to content

Commit

Permalink
port Python::import to Bound API
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu committed Feb 13, 2024
1 parent e308c8d commit db87a73
Show file tree
Hide file tree
Showing 26 changed files with 148 additions and 104 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,10 @@ use pyo3::types::IntoPyDict;

fn main() -> PyResult<()> {
Python::with_gil(|py| {
let sys = py.import("sys")?;
let sys = py.import_bound("sys")?;
let version: String = sys.getattr("version")?.extract()?;

let locals = [("os", py.import("os")?)].into_py_dict_bound(py);
let locals = [("os", py.import_bound("os")?)].into_py_dict_bound(py);
let code = "os.getenv('USER') or os.getenv('USERNAME') or 'Unknown'";
let user: String = py.eval_bound(code, None, Some(&locals))?.extract()?;

Expand Down
4 changes: 2 additions & 2 deletions guide/src/python_from_rust.md
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ fn main() -> PyResult<()> {
let path = Path::new("/usr/share/python_app");
let py_app = fs::read_to_string(path.join("app.py"))?;
let from_python = Python::with_gil(|py| -> PyResult<Py<PyAny>> {
let syspath: &PyList = py.import("sys")?.getattr("path")?.downcast()?;
let syspath = py.import_bound("sys")?.getattr("path")?.downcast_into::<PyList>()?;
syspath.insert(0, &path)?;
let app: Py<PyAny> = PyModule::from_code(py, &py_app, "", "")?
.getattr("run")?
Expand Down Expand Up @@ -498,7 +498,7 @@ use pyo3::prelude::*;

# fn main() -> PyResult<()> {
Python::with_gil(|py| -> PyResult<()> {
let signal = py.import("signal")?;
let signal = py.import_bound("signal")?;
// Set SIGINT to have the default action
signal
.getattr("signal")?
Expand Down
7 changes: 4 additions & 3 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,7 @@ impl_element!(f64, Float);
mod tests {
use super::PyBuffer;
use crate::ffi;
use crate::types::any::PyAnyMethods;
use crate::Python;

#[test]
Expand Down Expand Up @@ -890,11 +891,11 @@ mod tests {
fn test_array_buffer() {
Python::with_gil(|py| {
let array = py
.import("array")
.import_bound("array")
.unwrap()
.call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None)
.unwrap();
let buffer = PyBuffer::get(array).unwrap();
let buffer = PyBuffer::get(array.as_gil_ref()).unwrap();
assert_eq!(buffer.dimensions(), 1);
assert_eq!(buffer.item_count(), 4);
assert_eq!(buffer.format().to_str().unwrap(), "f");
Expand Down Expand Up @@ -924,7 +925,7 @@ mod tests {
assert_eq!(buffer.to_vec(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);

// F-contiguous fns
let buffer = PyBuffer::get(array).unwrap();
let buffer = PyBuffer::get(array.as_gil_ref()).unwrap();
let slice = buffer.as_fortran_slice(py).unwrap();
assert_eq!(slice.len(), 4);
assert_eq!(slice[1].get(), 11.0);
Expand Down
30 changes: 15 additions & 15 deletions src/conversions/chrono.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,15 +540,15 @@ impl DatetimeTypes {
static TYPES: GILOnceCell<DatetimeTypes> = GILOnceCell::new();
TYPES
.get_or_try_init(py, || {
let datetime = py.import("datetime")?;
let datetime = py.import_bound("datetime")?;
let timezone = datetime.getattr("timezone")?;
Ok::<_, PyErr>(Self {
date: datetime.getattr("date")?.into(),
datetime: datetime.getattr("datetime")?.into(),
time: datetime.getattr("time")?.into(),
timedelta: datetime.getattr("timedelta")?.into(),
timezone: timezone.into(),
timezone_utc: timezone.getattr("utc")?.into(),
timezone: timezone.into(),
tzinfo: datetime.getattr("tzinfo")?.into(),
})
})
Expand Down Expand Up @@ -683,7 +683,7 @@ mod tests {
let delta = delta.to_object(py);
let py_delta = new_py_datetime_ob(py, "timedelta", (py_days, py_seconds, py_ms));
assert!(
delta.as_ref(py).eq(py_delta).unwrap(),
delta.bind(py).eq(&py_delta).unwrap(),
"{}: {} != {}",
name,
delta,
Expand Down Expand Up @@ -779,7 +779,7 @@ mod tests {
.to_object(py);
let py_date = new_py_datetime_ob(py, "date", (year, month, day));
assert_eq!(
date.as_ref(py).compare(py_date).unwrap(),
date.bind(py).compare(&py_date).unwrap(),
Ordering::Equal,
"{}: {} != {}",
name,
Expand Down Expand Up @@ -838,7 +838,7 @@ mod tests {
),
);
assert_eq!(
datetime.as_ref(py).compare(py_datetime).unwrap(),
datetime.bind(py).compare(&py_datetime).unwrap(),
Ordering::Equal,
"{}: {} != {}",
name,
Expand Down Expand Up @@ -880,7 +880,7 @@ mod tests {
(year, month, day, hour, minute, second, py_ms, py_tz),
);
assert_eq!(
datetime.as_ref(py).compare(py_datetime).unwrap(),
datetime.bind(py).compare(&py_datetime).unwrap(),
Ordering::Equal,
"{}: {} != {}",
name,
Expand Down Expand Up @@ -1005,7 +1005,7 @@ mod tests {
Python::with_gil(|py| {
let utc = Utc.to_object(py);
let py_utc = python_utc(py);
assert!(utc.as_ref(py).is(py_utc));
assert!(utc.bind(py).is(&py_utc));
})
}

Expand Down Expand Up @@ -1036,7 +1036,7 @@ mod tests {
.to_object(py);
let py_time = new_py_datetime_ob(py, "time", (hour, minute, second, py_ms));
assert!(
time.as_ref(py).eq(py_time).unwrap(),
time.bind(py).eq(&py_time).unwrap(),
"{}: {} != {}",
name,
time,
Expand Down Expand Up @@ -1071,21 +1071,21 @@ mod tests {
})
}

fn new_py_datetime_ob<'a>(
py: Python<'a>,
fn new_py_datetime_ob<'py>(
py: Python<'py>,
name: &str,
args: impl IntoPy<Py<PyTuple>>,
) -> &'a PyAny {
py.import("datetime")
) -> Bound<'py, PyAny> {
py.import_bound("datetime")
.unwrap()
.getattr(name)
.unwrap()
.call1(args)
.unwrap()
}

fn python_utc(py: Python<'_>) -> &PyAny {
py.import("datetime")
fn python_utc(py: Python<'_>) -> Bound<'_, PyAny> {
py.import_bound("datetime")
.unwrap()
.getattr("timezone")
.unwrap()
Expand All @@ -1108,7 +1108,7 @@ mod tests {
fn test_pyo3_offset_fixed_frompyobject_created_in_python(timestamp in 0..(i32::MAX as i64), timedelta in -86399i32..=86399i32) {
Python::with_gil(|py| {

let globals = [("datetime", py.import("datetime").unwrap())].into_py_dict_bound(py);
let globals = [("datetime", py.import_bound("datetime").unwrap())].into_py_dict_bound(py);
let code = format!("datetime.datetime.fromtimestamp({}).replace(tzinfo=datetime.timezone(datetime.timedelta(seconds={})))", timestamp, timedelta);
let t = py.eval_bound(&code, Some(&globals), None).unwrap();

Expand Down
13 changes: 8 additions & 5 deletions src/conversions/chrono_tz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ mod tests {
#[test]
fn test_topyobject() {
Python::with_gil(|py| {
let assert_eq = |l: PyObject, r: &PyAny| {
assert!(l.as_ref(py).eq(r).unwrap());
let assert_eq = |l: PyObject, r: Bound<'_, PyAny>| {
assert!(l.bind(py).eq(r).unwrap());
};

assert_eq(
Expand All @@ -105,11 +105,14 @@ mod tests {
});
}

fn new_zoneinfo<'a>(py: Python<'a>, name: &str) -> &'a PyAny {
fn new_zoneinfo<'py>(py: Python<'py>, name: &str) -> Bound<'py, PyAny> {
zoneinfo_class(py).call1((name,)).unwrap()
}

fn zoneinfo_class(py: Python<'_>) -> &PyAny {
py.import("zoneinfo").unwrap().getattr("ZoneInfo").unwrap()
fn zoneinfo_class(py: Python<'_>) -> Bound<'_, PyAny> {
py.import_bound("zoneinfo")
.unwrap()
.getattr("ZoneInfo")
.unwrap()
}
}
2 changes: 1 addition & 1 deletion src/conversions/rust_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ static DECIMAL_CLS: GILOnceCell<Py<PyType>> = GILOnceCell::new();
fn get_decimal_cls(py: Python<'_>) -> PyResult<&PyType> {
DECIMAL_CLS
.get_or_try_init(py, || {
py.import(intern!(py, "decimal"))?
py.import_bound(intern!(py, "decimal"))?
.getattr(intern!(py, "Decimal"))?
.extract()
})
Expand Down
43 changes: 26 additions & 17 deletions src/conversions/std/time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ fn unix_epoch_py(py: Python<'_>) -> &PyObject {
}
#[cfg(Py_LIMITED_API)]
{
let datetime = py.import("datetime")?;
let datetime = py.import_bound("datetime")?;
let utc = datetime.getattr("timezone")?.getattr("utc")?;
Ok::<_, PyErr>(
datetime
Expand Down Expand Up @@ -216,8 +216,8 @@ mod tests {
#[test]
fn test_duration_topyobject() {
Python::with_gil(|py| {
let assert_eq = |l: PyObject, r: &PyAny| {
assert!(l.as_ref(py).eq(r).unwrap());
let assert_eq = |l: PyObject, r: Bound<'_, PyAny>| {
assert!(l.bind(py).eq(r).unwrap());
};

assert_eq(
Expand Down Expand Up @@ -300,8 +300,8 @@ mod tests {
#[test]
fn test_time_topyobject() {
Python::with_gil(|py| {
let assert_eq = |l: PyObject, r: &PyAny| {
assert!(l.as_ref(py).eq(r).unwrap());
let assert_eq = |l: PyObject, r: Bound<'_, PyAny>| {
assert!(l.bind(py).eq(r).unwrap());
};

assert_eq(
Expand Down Expand Up @@ -331,7 +331,7 @@ mod tests {
minute: u8,
second: u8,
microsecond: u32,
) -> &PyAny {
) -> Bound<'_, PyAny> {
datetime_class(py)
.call1((
year,
Expand All @@ -346,13 +346,11 @@ mod tests {
.unwrap()
}

fn max_datetime(py: Python<'_>) -> &PyAny {
fn max_datetime(py: Python<'_>) -> Bound<'_, PyAny> {
let naive_max = datetime_class(py).getattr("max").unwrap();
let kargs = PyDict::new_bound(py);
kargs.set_item("tzinfo", tz_utc(py)).unwrap();
naive_max
.call_method("replace", (), Some(kargs.as_gil_ref()))
.unwrap()
naive_max.call_method("replace", (), Some(&kargs)).unwrap()
}

#[test]
Expand All @@ -365,26 +363,37 @@ mod tests {
})
}

fn tz_utc(py: Python<'_>) -> &PyAny {
py.import("datetime")
fn tz_utc(py: Python<'_>) -> Bound<'_, PyAny> {
py.import_bound("datetime")
.unwrap()
.getattr("timezone")
.unwrap()
.getattr("utc")
.unwrap()
}

fn new_timedelta(py: Python<'_>, days: i32, seconds: i32, microseconds: i32) -> &PyAny {
fn new_timedelta(
py: Python<'_>,
days: i32,
seconds: i32,
microseconds: i32,
) -> Bound<'_, PyAny> {
timedelta_class(py)
.call1((days, seconds, microseconds))
.unwrap()
}

fn datetime_class(py: Python<'_>) -> &PyAny {
py.import("datetime").unwrap().getattr("datetime").unwrap()
fn datetime_class(py: Python<'_>) -> Bound<'_, PyAny> {
py.import_bound("datetime")
.unwrap()
.getattr("datetime")
.unwrap()
}

fn timedelta_class(py: Python<'_>) -> &PyAny {
py.import("datetime").unwrap().getattr("timedelta").unwrap()
fn timedelta_class(py: Python<'_>) -> Bound<'_, PyAny> {
py.import_bound("datetime")
.unwrap()
.getattr("timedelta")
.unwrap()
}
}
3 changes: 2 additions & 1 deletion src/coroutine/waker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::sync::GILOnceCell;
use crate::types::any::PyAnyMethods;
use crate::types::PyCFunction;
use crate::{intern, wrap_pyfunction, Py, PyAny, PyObject, PyResult, Python};
use pyo3_macros::pyfunction;
Expand Down Expand Up @@ -56,7 +57,7 @@ impl LoopAndFuture {
fn new(py: Python<'_>) -> PyResult<Self> {
static GET_RUNNING_LOOP: GILOnceCell<PyObject> = GILOnceCell::new();
let import = || -> PyResult<_> {
let module = py.import("asyncio")?;
let module = py.import_bound("asyncio")?;
Ok(module.getattr("get_running_loop")?.into())
};
let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?;
Expand Down
3 changes: 2 additions & 1 deletion src/err/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,7 @@ impl_signed_integer!(isize);
mod tests {
use super::PyErrState;
use crate::exceptions::{self, PyTypeError, PyValueError};
use crate::types::any::PyAnyMethods;
use crate::{PyErr, PyTypeInfo, Python};

#[test]
Expand Down Expand Up @@ -1174,7 +1175,7 @@ mod tests {
let cls = py.get_type::<exceptions::PyUserWarning>();

// Reset warning filter to default state
let warnings = py.import("warnings").unwrap();
let warnings = py.import_bound("warnings").unwrap();
warnings.call_method0("resetwarnings").unwrap();

// First, test the warning is emitted
Expand Down
7 changes: 4 additions & 3 deletions src/exceptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,14 @@ macro_rules! import_exception {
fn type_object_raw(py: $crate::Python<'_>) -> *mut $crate::ffi::PyTypeObject {
use $crate::sync::GILOnceCell;
use $crate::prelude::PyTracebackMethods;
use $crate::prelude::PyAnyMethods;
static TYPE_OBJECT: GILOnceCell<$crate::Py<$crate::types::PyType>> =
GILOnceCell::new();

TYPE_OBJECT
.get_or_init(py, || {
let imp = py
.import(stringify!($module))
.import_bound(stringify!($module))
.unwrap_or_else(|err| {
let traceback = err
.traceback_bound(py)
Expand Down Expand Up @@ -812,7 +813,7 @@ mod tests {
Python::with_gil(|py| {
let err: PyErr = gaierror::new_err(());
let socket = py
.import("socket")
.import_bound("socket")
.map_err(|e| e.display(py))
.expect("could not import socket");

Expand All @@ -836,7 +837,7 @@ mod tests {
Python::with_gil(|py| {
let err: PyErr = MessageError::new_err(());
let email = py
.import("email")
.import_bound("email")
.map_err(|e| e.display(py))
.expect("could not import email");

Expand Down
2 changes: 1 addition & 1 deletion src/gil.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ where

// Import the threading module - this ensures that it will associate this thread as the "main"
// thread, which is important to avoid an `AssertionError` at finalization.
pool.python().import("threading").unwrap();
pool.python().import_bound("threading").unwrap();

// Execute the closure.
let result = f(pool.python());
Expand Down
3 changes: 2 additions & 1 deletion src/impl_/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ mod tests {
assert_eq!((*module_def.ffi_def.get()).m_doc, DOC.as_ptr() as _);

Python::with_gil(|py| {
module_def.initializer.0(py, py.import("builtins").unwrap()).unwrap();
module_def.initializer.0(py, py.import_bound("builtins").unwrap().into_gil_ref())
.unwrap();
assert!(INIT_CALLED.load(Ordering::SeqCst));
})
}
Expand Down
Loading

0 comments on commit db87a73

Please sign in to comment.