diff --git a/Cargo.lock b/Cargo.lock index f701f3d..4bc56f7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -80,9 +80,9 @@ dependencies = [ [[package]] name = "anybytes" -version = "0.1.0" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60637d581e64c883b418fe0bd56d08dde5a3ae6812e82743af45ab2ee9cfdd37" +checksum = "f4a2cf0427c0ec3291d0ceee9e4d2fe349f969e4f1b4ad8447b12a66263070f3" dependencies = [ "bytes", "memmap2", @@ -123,13 +123,13 @@ checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" [[package]] name = "async-trait" -version = "0.1.80" +version = "0.1.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -198,7 +198,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.68", + "syn 2.0.70", "which", ] @@ -314,9 +314,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.101" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac367972e516d45567c7eafc73d24e1c193dcf200a8d94e9db7b3d38b349572d" +checksum = "066fce287b1d4eafef758e89e09d724a24808a9196fe9756b8ca90e86d0719a2" dependencies = [ "jobserver", "libc", @@ -347,7 +347,7 @@ dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -545,7 +545,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -566,7 +566,7 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -650,7 +650,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -811,7 +811,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -1252,7 +1252,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e310b3a6b5907f99202fcdb4960ff45b93735d7c7d96b760fcff8db2dc0e103d" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -1279,9 +1279,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "mach2" @@ -1453,9 +1453,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.0" +version = "0.36.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "576dfe1fc8f9df304abb159d767a29d0476f7750fbf8aa7ad07816004a207434" +checksum = "081b846d1d56ddfc18fdf1a922e4f6e07a11768ea1b92dec44e42b72712ccfce" dependencies = [ "memchr", ] @@ -1489,9 +1489,9 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "oorandom" -version = "11.1.3" +version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" [[package]] name = "oxigraph" @@ -1616,7 +1616,7 @@ dependencies = [ "libc", "redox_syscall 0.5.2", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1739,7 +1739,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -1818,7 +1818,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -1831,7 +1831,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -2104,9 +2104,9 @@ checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" [[package]] name = "serde" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" dependencies = [ "serde_derive", ] @@ -2123,20 +2123,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] name = "serde_json" -version = "1.0.118" +version = "1.0.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d947f6b3163d8857ea16c4fa0dd4840d52f3041039a85decd46867eb1abef2e4" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" dependencies = [ "itoa", "ryu", @@ -2314,9 +2314,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.68" +version = "2.0.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "901fa70d88b9d6c98022e23b4136f9f3e54e4662c3bc1bd1d84a42a9a0f0c1e9" +checksum = "2f0209b68b3613b093e0ec905354eccaedcfe83b8cb37cbdeae64026c3064c16" dependencies = [ "proc-macro2", "quote", @@ -2339,9 +2339,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" [[package]] name = "tempfile" @@ -2390,7 +2390,7 @@ checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -2405,9 +2405,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c55115c6fbe2d2bef26eb09ad74bde02d8255476fc0c7b515ef09fbb35742d82" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" dependencies = [ "tinyvec_macros", ] @@ -2438,7 +2438,7 @@ checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -2460,7 +2460,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] @@ -2475,7 +2475,7 @@ dependencies = [ [[package]] name = "tribles" version = "0.2.0-alpha-1" -source = "git+https://github.com/triblesspace/tribles-rust.git#ce4ef272600932a0a35c3ca9ed632e1fce70fdfb" +source = "git+https://github.com/triblesspace/tribles-rust.git#a8707a22f568ae97d83822129210a9dee9d979ef" dependencies = [ "anybytes", "anyhow", @@ -2637,7 +2637,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", "wasm-bindgen-shared", ] @@ -2659,7 +2659,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2754,7 +2754,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2763,7 +2763,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2783,18 +2783,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -2805,9 +2805,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -2817,9 +2817,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -2829,15 +2829,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -2847,9 +2847,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -2859,9 +2859,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -2871,9 +2871,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -2883,9 +2883,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winreg" @@ -2899,9 +2899,9 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "byteorder", "zerocopy-derive", @@ -2909,13 +2909,13 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.68", + "syn 2.0.70", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index d19dc7d..4cdd3f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ pyo3 = "0.21.1" tribles = {git = "https://github.com/triblesspace/tribles-rust.git"} digest = "0.10.7" zerocopy = { version = "0.7.34", features = ["derive"] } -anybytes = "0.1.0" +anybytes = "0.1.2" hdf5 = { version = "0.8" } ndarray = { version = "0.15" } diff --git a/notebooks/experiments.py b/notebooks/experiments.py index e1bb2b0..375ba2c 100644 --- a/notebooks/experiments.py +++ b/notebooks/experiments.py @@ -1,6 +1,6 @@ import marimo -__generated_with = "0.6.23" +__generated_with = "0.6.26" app = marimo.App() @@ -22,7 +22,7 @@ def __(): import dorf import importlib importlib.reload(dorf) - return importlib, dorf + return dorf, importlib @app.cell @@ -48,13 +48,13 @@ def __(): @app.cell -def __(mnist_dataset, dorf): - dorf.bench_mnist(mnist_dataset, True) +def __(dorf, mnist_dataset): + dorf.bench_mnist_dorf(mnist_dataset, True) return @app.cell -def __(mo, dorf): +def __(dorf, mo): with mo.redirect_stdout(): dorf.printstuff() return diff --git a/src/benchmarks/ann_mnist_784_euclidean.rs b/src/benchmarks/ann_mnist_784_euclidean.rs index ce33dcc..776eadd 100644 --- a/src/benchmarks/ann_mnist_784_euclidean.rs +++ b/src/benchmarks/ann_mnist_784_euclidean.rs @@ -1,9 +1,13 @@ use cpu_time::ProcessTime; +use tribles::{types::hash::Blake3, BlobSet}; +use zerocopy::{F32, LE}; use std::{ io::Write, time::{Duration, SystemTime}}; use anndists::dist::*; use hnsw_rs::prelude::*; +use crate::ml::{Embedding, SW, ZC}; + pub fn run_hnsw(stdout: &mut impl Write, fname: String, parallel: bool) -> Result<(), hdf5::Error> { // # load dataset let file = hdf5::File::open(&fname)?; @@ -151,7 +155,10 @@ pub fn run_hnsw(stdout: &mut impl Write, fname: String, parallel: bool) -> Resul Ok(()) } + pub fn run_dorf(stdout: &mut impl Write, fname: String, parallel: bool) -> Result<(), hdf5::Error> { + writeln!(stdout, "Loading dataset...").unwrap(); + // # load dataset let file = hdf5::File::open(&fname)?; @@ -164,20 +171,71 @@ pub fn run_dorf(stdout: &mut impl Write, fname: String, parallel: bool) -> Resul .read_2d::()?; // load test data - let test_data: Vec<_> = file.dataset("test")? + let test_data: Vec> = file.dataset("test")? .read_2d::()? .rows().into_iter() - .map(|row| row.to_vec()) + .map(|row| row.as_slice().unwrap().try_into().unwrap()) .collect(); // load train data - let train_data: Vec<_> = file.dataset("train")? + let train_data: Vec>> = file.dataset("train")? .read_2d::()? .rows().into_iter() - .map(|row| row.to_vec()) + .map(|row| { + let slice = row.as_slice().unwrap(); + assert!(!slice.iter().any(|i| i.is_nan())); + let embedding:Embedding<784, f32> = slice.try_into().unwrap(); + assert!(!embedding.iter().any(|i| i.is_nan())); + let zc_embedding: ZC> = embedding.into(); + assert!(!zc_embedding.iter().any(|i| i.is_nan())); + + zc_embedding + }) .collect(); - //let train_embeddings = + writeln!(stdout, "Loading complete...").unwrap(); + + let blobs: BlobSet = BlobSet::new(); + let mut sw = SW::new(blobs, |n: &ZC>, o: &ZC>| { + assert!(n.len() == 784); + assert!(o.len() == 784); + assert!(!n.iter().any(|i| i.is_nan())); + assert!(!o.iter().any(|i| i.is_nan())); + + DistL2::eval(&DistL2{}, n, o) +}); + + writeln!(stdout, "Caching embeddings...").unwrap(); + let start = ProcessTime::now(); + for d in train_data { + sw.insert(d); + } + let cpu_time: Duration = start.elapsed(); + writeln!(stdout, "Caching completed in {:?}...", cpu_time).unwrap(); + + writeln!(stdout, "Preparing sw...").unwrap(); + let start = ProcessTime::now(); + sw.prepare(); + let cpu_time: Duration = start.elapsed(); + writeln!(stdout, "Preparations completed in {:?}...", cpu_time).unwrap(); + + writeln!(stdout, "Stepping sw with {:?} nodes...", sw.nodes.len()).unwrap(); + let start = ProcessTime::now(); + sw.step(); + let cpu_time: Duration = start.elapsed(); + writeln!(stdout, "Stepping completed in {:?}...", cpu_time).unwrap(); Ok(()) +} + + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn dorf() { + run_dorf(&mut std::io::stdout(), "/Users/jp/Desktop/triblespace/dorf/datasets/fashion-mnist-784-euclidean.hdf5".into(), false).unwrap(); + } } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 4123498..29d66fd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,9 +11,17 @@ fn bench_mnist_hnsw(fname: String, parallel: bool) { benchmarks::ann_mnist_784_euclidean::run_hnsw(&mut stdout, fname, parallel).unwrap(); } +/// Run bench on mnist784 dataset. +#[pyfunction] +fn bench_mnist_dorf(fname: String, parallel: bool) { + let mut stdout = stdio::stdout(); + benchmarks::ann_mnist_784_euclidean::run_dorf(&mut stdout, fname, parallel).unwrap(); +} + /// A Python module implemented in Rust. #[pymodule] fn dorf(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(bench_mnist_hnsw, m)?)?; + m.add_function(wrap_pyfunction!(bench_mnist_dorf, m)?)?; Ok(()) } diff --git a/src/ml.rs b/src/ml.rs index c4ce936..932605b 100644 --- a/src/ml.rs +++ b/src/ml.rs @@ -1,15 +1,26 @@ +use std::fmt::Debug; use std::marker::PhantomData; +use std::ops::Deref; +use std::sync::Arc; +use anndists::dist::{DistL2, Distance}; use anybytes::ByteOwner; use digest::consts::U32; use digest::Digest; use tribles::types::Hash; -use tribles::{BlobParseError, BlobSet, Bloblike, Bytes, Handle}; -use tribles::types::hash::Blake3; +use tribles::{BlobParseError, BlobSet, Bloblike, Bytes, Handle, Value}; use zerocopy::{AsBytes, FromBytes, FromZeroes}; -#[derive(AsBytes, FromZeroes, FromBytes)] +#[cfg(not(target_endian = "little"))] +compile_error!("This crate does not compile on BE architectures. +The reason being that most libraries just assume that they run on LE platforms, +e.g. when performing zero copy reads from transmuted arrays. +So long as Rust does not refine its handling of endianess, e.g. by introducing explicit endian +number types in core, we have no other choice than to pave the cow paths and assume +that all native numbers are little endian."); + +#[derive(AsBytes, FromZeroes, FromBytes, Debug)] #[repr(transparent)] pub struct Embedding([T; LEN]); @@ -34,22 +45,21 @@ impl ByteOwner for Box { +pub struct ZC { bytes: Bytes, _type: PhantomData, } -impl ZeroCopy -where T: ByteOwner { - pub fn from(owner: T) -> ZeroCopy { - ZeroCopy { - bytes: Bytes::from_owner(owner), - _type: PhantomData - } +impl std::fmt::Debug for ZC +where T: FromBytes + Debug { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let inner: &T = self; + Debug::fmt(inner, f) } } -impl std::ops::Deref for ZeroCopy + +impl std::ops::Deref for ZC where T: FromBytes { type Target = T; @@ -59,7 +69,27 @@ where T: FromBytes { } } -impl Bloblike for ZeroCopy +impl From for ZC +where T: ByteOwner { + fn from(value: T) -> Self { + ZC { + bytes: Bytes::from_owner(value), + _type: PhantomData + } + } +} + +impl From> for ZC +where T: ByteOwner { + fn from(value: Arc) -> Self { + ZC { + bytes: Bytes::from_arc(value), + _type: PhantomData + } + } +} + +impl Bloblike for ZC where T: FromBytes { fn into_blob(self) -> Bytes { self.bytes @@ -69,7 +99,7 @@ where T: FromBytes { if ::ref_from(&blob).is_none() { Err(BlobParseError::new("wrong size or alignment of bytes for type")) } else { - Ok(ZeroCopy {bytes: blob, _type: PhantomData}) + Ok(ZC {bytes: blob, _type: PhantomData}) } } @@ -82,36 +112,91 @@ where T: FromBytes { } } +#[derive(Debug)] pub enum EmbeddingError { BadLength } impl TryFrom> for Embedding { - type Error = EmbeddingError; + type Error = Vec; fn try_from(vec: Vec) -> Result { - let v = vec.try_into().map_err(|_| EmbeddingError::BadLength)?; + let v = vec.try_into()?; Ok(Embedding(v)) } } -/* -pub struct SW { - blobs: BlobSet +impl<'a, const LEN: usize, T> TryFrom<&'a [T]> for Embedding +where [T; LEN]: TryFrom<&'a [T]> { + type Error = EmbeddingError; + + fn try_from(value: &'a [T]) -> Result { + if value.len() != LEN { + return Err(EmbeddingError::BadLength) + } + let Ok(arr) = value.try_into() else { panic!("failed conversion despite correct length") }; + Ok(Embedding(arr)) + } +} + + +pub struct SW { + blobs: BlobSet, + pub nodes: Vec>, + pub steps: Vec>, + pub dist: F, } -impl SW -where H: Digest{ - fn new(blobs: BlobSet) -> Self { - return Self{ blobs }; +impl SW +where H: Digest, + T: Bloblike + ?Sized, + F: Fn(&T, &T) -> f32 { + pub fn new(blobs: BlobSet, dist: F) -> Self { + return Self { blobs, nodes: vec![], steps: vec![], dist }; } - fn embed(&mut self, data: T) -> Handle { - let embedding = T.into(); - let handle = self.blobs.put(embedding); + pub fn insert(&mut self, node: T) -> Handle { + let handle = self.blobs.insert(node); + self.nodes.push(handle); handle } + + pub fn prepare(&mut self) { + self.nodes.sort(); + self.nodes.dedup(); + + let mut step_0 = vec![self.nodes.len() - 1]; + step_0.extend(0..self.nodes.len() - 1); + + self.steps.push(step_0); + } + + pub fn step(&mut self) { + if let Some(step) = self.steps.last() { + let mut next = vec![]; + + for (node_i, &target_i) in step.into_iter().enumerate() { + let hop_i = step[target_i]; + let node = self.blobs.get(self.nodes[node_i]).unwrap().unwrap(); + let target = self.blobs.get(self.nodes[target_i]).unwrap().unwrap(); + let hop = self.blobs.get(self.nodes[hop_i]).unwrap().unwrap(); + + let target_distance = (self.dist)(&node, &target); + let hop_distance = (self.dist)(&node, &hop); + + let next_i = if hop_distance < target_distance { + target_i + } else { + hop_i + }; + + next.push(next_i); + } + + self.steps.push(next); + } + } } -*/ + /* #[cfg(test)] mod tests { @@ -128,7 +213,7 @@ mod tests { NS! { pub namespace library { "47390346743AC0879BA0E77B95B9683F" as title: ShortString; - "7B7D8B046B3FCC7CD7888C5FF03D34E8" as embedded_title: Handle; + "7B7D8B046B3FCC7CD7888C5FF03D34E8" as embedded_title: Handle>; } } @@ -137,20 +222,21 @@ mod tests { let blobs = BlobSet::::new(); let mut books = TribleSet::new(); let mut book_embeddings = HNSW::new(blobs); + let embedder = Embedder::new(); books.union(library::entity!({ title: ShortString::new("LOTR").unwrap(), - embedded_title: book_embeddings.embed("LOTR") + embedded_title: book_embeddings.insert(embedder.embed("LOTR")) })); books.union(library::entity!({ title: ShortString::new("Dragonrider").unwrap(), - embedded_title: book_embeddings.embed("Dragonrider") + embedded_title: book_embeddings.insert(embedder.embed("Dragonrider")) })); books.union(library::entity!({ title: ShortString::new("Highlander").unwrap(), - embedded_title: book_embeddings.embed("Highlander") + embedded_title: book_embeddings.insert(embedder.embed("Highlander")) })); let similar: Vec<_> = find!(