Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add is_safe to rust UUID #70

Merged
merged 18 commits into from
Nov 21, 2024
3 changes: 3 additions & 0 deletions python/uuid_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import SafeUUID

from ._uuid_utils import (
NAMESPACE_DNS,
NAMESPACE_OID,
Expand Down Expand Up @@ -29,6 +31,7 @@
"RESERVED_NCS",
"RFC_4122",
"UUID",
"SafeUUID",
"__version__",
"getnode",
"uuid1",
Expand Down
29 changes: 23 additions & 6 deletions python/uuid_utils/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import builtins
import sys
from enum import Enum
from uuid import SafeUUID

from typing_extensions import TypeAlias

Expand All @@ -9,11 +9,6 @@ _FieldsType: TypeAlias = tuple[int, int, int, int, int, int]

__version__: str

class SafeUUID(Enum):
safe: int
unsafe: int
unknown: None

class UUID:
"""Instances of the UUID class represent UUIDs as specified in RFC 4122.
UUID objects are immutable, hashable, and usable as dictionary keys.
Expand Down Expand Up @@ -182,3 +177,25 @@ RESERVED_NCS: str
RFC_4122: str
RESERVED_MICROSOFT: str
RESERVED_FUTURE: str

__all__ = [
"NAMESPACE_DNS",
"NAMESPACE_OID",
"NAMESPACE_URL",
"NAMESPACE_X500",
"RESERVED_FUTURE",
"RESERVED_MICROSOFT",
"RESERVED_NCS",
"RFC_4122",
"UUID",
"SafeUUID",
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
"__version__",
"getnode",
"uuid1",
"uuid3",
"uuid4",
"uuid5",
"uuid6",
"uuid7",
"uuid8",
]
28 changes: 26 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use mac_address::get_mac_address;
use pyo3::{
exceptions::{PyTypeError, PyValueError},
ffi,
prelude::*,
pyclass::CompareOp,
types::{PyBytes, PyDict},
};
use rand::RngCore;
use std::hash::Hasher;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::hash_map::DefaultHasher, hash::Hash};
use std::{hash::Hasher, sync::atomic::AtomicPtr};
use std::{
ptr::null_mut,
sync::atomic::{AtomicU64, Ordering},
};
use uuid::{Builder, Bytes, Context, Timestamp, Uuid, Variant, Version};

static NODE: AtomicU64 = AtomicU64::new(0);
Expand Down Expand Up @@ -306,6 +310,11 @@ impl UUID {
uuid: Uuid::from_u128(int),
})
}

#[getter]
fn is_safe(&self) -> *mut ffi::PyObject {
return SAFE_UUID_UNKNOWN.load(Ordering::Relaxed);
}
}

#[pyfunction]
Expand Down Expand Up @@ -429,13 +438,28 @@ fn _getnode() -> u64 {
node
}

// ptr to python stdlib uuid.SafeUUID.unknown
static SAFE_UUID_UNKNOWN: AtomicPtr<ffi::PyObject> = AtomicPtr::new(null_mut());

#[pyfunction]
fn getnode() -> PyResult<u64> {
Ok(_getnode())
}

#[pymodule]
fn _uuid_utils(m: &Bound<'_, PyModule>) -> PyResult<()> {
let safe_uuid_unknown = Python::with_gil(|py| {
return PyModule::import_bound(py, "uuid")
.unwrap()
.getattr("SafeUUID")
.unwrap()
.getattr("unknown")
.unwrap()
.unbind();
});

SAFE_UUID_UNKNOWN.store(safe_uuid_unknown.into_ptr(), Ordering::Relaxed);

m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_class::<UUID>()?;
m.add_function(wrap_pyfunction!(uuid1, m)?)?;
Expand Down
7 changes: 6 additions & 1 deletion tests/test_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pickle
import sys
from datetime import datetime
from uuid import UUID, getnode
from uuid import UUID, SafeUUID, getnode

import pytest
import uuid_utils
Expand Down Expand Up @@ -200,6 +200,11 @@ def test_copy() -> None:
assert copy.deepcopy(uuid) == uuid


def test_is_safe() -> None:
assert uuid_utils.uuid1().is_safe is SafeUUID.unknown
assert uuid_utils.uuid4().is_safe is SafeUUID.unknown


@pytest.mark.xfail(sys.platform == "linux", reason="Might fail in Github Actions")
def test_getnode() -> None:
assert uuid_utils.getnode() == getnode()