Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pyo3 update. #2545

Merged
merged 2 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions candle-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.22.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
safetensors = { workspace = true }
Expand Down Expand Up @@ -121,4 +121,4 @@ required-features = ["onnx"]

[[example]]
name = "colpali"
required-features = ["pdf2image"]
required-features = ["pdf2image"]
4 changes: 2 additions & 2 deletions candle-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ candle-nn = { workspace = true }
candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] }
pyo3 = { version = "0.22.0", features = ["extension-module", "abi3-py38"] }

[build-dependencies]
pyo3-build-config = "0.21"
pyo3-build-config = "0.22"

[features]
default = []
Expand Down
10 changes: 3 additions & 7 deletions candle-pyo3/py_src/candle/utils/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,15 @@ def has_mkl() -> bool:
pass

@staticmethod
def load_ggml(
path: Union[str, PathLike], device: Optional[Device] = None
) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
def load_ggml(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
"""
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
"""
pass

@staticmethod
def load_gguf(
path: Union[str, PathLike], device: Optional[Device] = None
) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
def load_gguf(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
"""
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values.
Expand All @@ -60,7 +56,7 @@ def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]:
pass

@staticmethod
def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]):
def save_gguf(path, tensors, metadata):
"""
Save quanitzed tensors and metadata to a GGUF file.
"""
Expand Down
19 changes: 9 additions & 10 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;

use half::{bf16, f16};
Expand Down Expand Up @@ -115,7 +114,7 @@ impl PyDevice {
}

impl<'source> FromPyObject<'source> for PyDevice {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
let device: String = ob.extract()?;
let device = match device.as_str() {
"cpu" => PyDevice::Cpu,
Expand Down Expand Up @@ -217,11 +216,11 @@ enum Indexer {
IndexSelect(Tensor),
}

#[derive(Clone, Debug)]
#[derive(Debug)]
struct TorchTensor(PyObject);

impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
Ok(TorchTensor(numpy_value))
}
Expand Down Expand Up @@ -540,7 +539,7 @@ impl PyTensor {
))
} else if let Ok(slice) = py_indexer.downcast::<pyo3::types::PySlice>() {
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
let index = slice.indices(dims[current_dim] as c_long)?;
let index = slice.indices(dims[current_dim] as isize)?;
Ok((
Indexer::Slice(index.start as usize, index.stop as usize),
current_dim + 1,
Expand Down Expand Up @@ -1284,7 +1283,7 @@ fn save_safetensors(
}

#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
#[pyo3(signature = (path, device = None))]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
Expand Down Expand Up @@ -1325,7 +1324,7 @@ fn load_ggml(
}

#[pyfunction]
#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
#[pyo3(signature = (path, device = None))]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
Expand Down Expand Up @@ -1384,7 +1383,7 @@ fn load_gguf(

#[pyfunction]
#[pyo3(
text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])"
signature = (path, tensors, metadata)
)]
/// Save quanitzed tensors and metadata to a GGUF file.
fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> {
Expand Down Expand Up @@ -1430,7 +1429,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
Ok(v)
}
let tensors = tensors
.extract::<&PyDict>(py)
.downcast_bound::<PyDict>(py)
.map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
.iter()
.map(|(key, value)| {
Expand All @@ -1443,7 +1442,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>)
.collect::<PyResult<Vec<_>>>()?;

let metadata = metadata
.extract::<&PyDict>(py)
.downcast_bound::<PyDict>(py)
.map_err(|_| PyErr::new::<PyValueError, _>("expected a dict"))?
.iter()
.map(|(key, value)| {
Expand Down
12 changes: 6 additions & 6 deletions candle-pyo3/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use pyo3::prelude::*;
pub struct PyShape(Vec<usize>);

impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
Expand All @@ -16,10 +16,10 @@ impl<'source> pyo3::FromPyObject<'source> for PyShape {
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(&first_element)?;
Ok(PyShape(dims))
} else {
let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract_bound(tuple)?;
Ok(PyShape(dims))
}
}
Expand All @@ -36,7 +36,7 @@ impl From<PyShape> for ::candle::Shape {
pub struct PyShapeWithHole(Vec<isize>);

impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
Expand All @@ -46,9 +46,9 @@ impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
let dims: Vec<isize> = if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
pyo3::FromPyObject::extract(first_element)?
pyo3::FromPyObject::extract_bound(&first_element)?
} else {
pyo3::FromPyObject::extract(tuple)?
pyo3::FromPyObject::extract_bound(tuple)?
};

// Ensure we have only positive numbers and at most one "hole" (-1)
Expand Down
Loading