diff --git a/build.rs b/build.rs index 027d02e..314d1e9 100644 --- a/build.rs +++ b/build.rs @@ -20,10 +20,12 @@ fn main() { println!("cargo:rerun-if-changed=src/translator.cpp"); println!("cargo:rerun-if-changed=src/generator.rs"); println!("cargo:rerun-if-changed=src/generator.cpp"); + println!("cargo:rerun-if-changed=src/storage_view.rs"); println!("cargo:rerun-if-changed=include/types.h"); println!("cargo:rerun-if-changed=include/config.h"); println!("cargo:rerun-if-changed=include/translator.h"); println!("cargo:rerun-if-changed=include/generator.h"); + println!("cargo:rerun-if-changed=include/storage_view.h"); println!("cargo:rerun-if-changed=CTranslate2"); println!("cargo:rerun-if-env-changed=LIBRARY_PATH"); if let Ok(library_path) = env::var("LIBRARY_PATH") { @@ -84,6 +86,7 @@ fn main() { "src/config.rs", "src/translator.rs", "src/generator.rs", + "src/storage_view.rs", ]) .file("src/translator.cpp") .file("src/generator.cpp") diff --git a/include/storage_view.h b/include/storage_view.h new file mode 100644 index 0000000..9afbc56 --- /dev/null +++ b/include/storage_view.h @@ -0,0 +1,63 @@ +// storage_view.h +// +// Copyright (c) 2023-2024 Junpei Kawamoto +// +// This software is released under the MIT License. +// +// http://opensource.org/licenses/mit-license.php + +#pragma once + +#include +#include +#include + +#include + +#include "rust/cxx.h" + +using ctranslate2::Device; +using ctranslate2::StorageView; + +inline std::unique_ptr storage_view_from_float( + const rust::Slice shape, + const rust::Slice init, + const Device device +) { + return std::make_unique( + ctranslate2::Shape(shape.begin(), shape.end()), + std::vector(init.begin(), init.end()), + device + ); +} + +inline std::unique_ptr storage_view_from_int8( + const rust::Slice shape, + const rust::Slice init, + const Device device +) { + return std::make_unique( + ctranslate2::Shape(shape.begin(), shape.end()), + std::vector(init.begin(), init.end()), + device + ); +} + +inline std::unique_ptr storage_view_from_int16( + const rust::Slice shape, + const rust::Slice init, + const Device device +) { + return std::make_unique( + ctranslate2::Shape(shape.begin(), shape.end()), + std::vector(init.begin(), init.end()), + device + ); +} + +rust::String to_string(const StorageView& storage) { + std::ostringstream oss; + oss << storage; + + return rust::String(oss.str()); +} diff --git a/src/config.rs b/src/config.rs index 37b41da..150396f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,7 +8,7 @@ //! Configs and associated enums. -use std::fmt::{Debug, Display, Formatter, Pointer}; +use std::fmt::{Debug, Display, Formatter}; use cxx::UniquePtr; diff --git a/src/lib.rs b/src/lib.rs index 9dcbfea..59c8bec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -138,9 +138,10 @@ use std::path::Path; use anyhow::{anyhow, Result}; use crate::auto::Tokenizer as AutoTokenizer; -use crate::config::Config; pub use crate::config::{set_log_level, set_random_seed}; +use crate::config::Config; pub use crate::generator::GenerationOptions; +pub use crate::storage_view::StorageView; pub use crate::translator::TranslationOptions; pub mod auto; @@ -148,6 +149,7 @@ pub mod bpe; pub mod config; pub mod generator; pub mod sentencepiece; +pub mod storage_view; pub mod tokenizers; pub mod translator; mod types; diff --git a/src/storage_view.rs b/src/storage_view.rs new file mode 100644 index 0000000..df3ade8 --- /dev/null +++ b/src/storage_view.rs @@ -0,0 +1,162 @@ +// storage_view.rs +// +// Copyright (c) 2023-2024 Junpei Kawamoto +// +// This software is released under the MIT License. +// +// http://opensource.org/licenses/mit-license.php + +use std::fmt::{Debug, Formatter}; +use std::ops::Deref; + +use anyhow::Result; +use cxx::UniquePtr; + +use crate::config::Device; + +#[cxx::bridge] +pub(crate) mod ffi { + unsafe extern "C++" { + include!("ct2rs/include/storage_view.h"); + + type Device = crate::config::ffi::Device; + + type StorageView; + + fn storage_view_from_float( + shape: &[usize], + init: &[f32], + device: Device, + ) -> Result>; + + fn storage_view_from_int8( + shape: &[usize], + init: &[i8], + device: Device, + ) -> Result>; + + fn storage_view_from_int16( + shape: &[usize], + init: &[i16], + device: Device, + ) -> Result>; + + fn device(self: &StorageView) -> Device; + + fn size(self: &StorageView) -> i64; + + fn rank(self: &StorageView) -> i64; + + fn to_string(storage: &StorageView) -> String; + } +} + +/// A Rust binding to the +/// [`ctranslate2::StorageView`](https://opennmt.net/CTranslate2/python/ctranslate2.StorageView.html). +pub struct StorageView { + ptr: UniquePtr, +} + +impl StorageView { + /// Creates a storage view with the given shape from the given array of float values. + pub fn from_f32(shape: &[usize], init: &[f32], device: Device) -> Result { + Ok(Self { + ptr: ffi::storage_view_from_float(shape, init, device)?, + }) + } + + /// Creates a storage view with the given shape from the given array of int8 values. + pub fn from_i8(shape: &[usize], init: &[i8], device: Device) -> Result { + Ok(Self { + ptr: ffi::storage_view_from_int8(shape, init, device)?, + }) + } + + /// Creates a storage view with the given shape from the given array of int16 values. + pub fn from_i16(shape: &[usize], init: &[i16], device: Device) -> Result { + Ok(Self { + ptr: ffi::storage_view_from_int16(shape, init, device)?, + }) + } + + /// Device where the storage is allocated. + pub fn device(&self) -> Device { + self.ptr.device() + } + + /// Returns the size of this storage. + pub fn size(&self) -> i64 { + self.ptr.size() + } + + /// Returns the rank of this storage. + pub fn rank(&self) -> i64 { + self.ptr.rank() + } + + /// Returns true if this storage is empty. + pub fn empty(&self) -> bool { + self.size() == 0 + } +} + +impl Debug for StorageView { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", ffi::to_string(self.ptr.deref())) + } +} + +impl Deref for StorageView { + type Target = ffi::StorageView; + + fn deref(&self) -> &Self::Target { + &self.ptr + } +} + +unsafe impl Send for StorageView {} +unsafe impl Sync for StorageView {} + +#[cfg(test)] +mod tests { + use crate::config::Device; + use crate::storage_view::StorageView; + + #[test] + fn test_from_f32() { + let shape = vec![1, 2, 4]; + let data = vec![1., 2., 3., 4., 5., 6., 7., 8.]; + let v = StorageView::from_f32(&shape, &data, Default::default()).unwrap(); + + assert_eq!(v.size(), data.len() as i64); + assert_eq!(v.rank(), shape.len() as i64); + assert!(!v.empty()); + assert_eq!(v.device(), Device::CPU); + + println!("{:?}", v); + } + + #[test] + fn test_from_i8() { + let shape = vec![2, 2]; + let data = vec![3, 4, 5, 6]; + let v = StorageView::from_i8(&shape, &data, Default::default()).unwrap(); + + assert_eq!(v.size(), data.len() as i64); + assert_eq!(v.rank(), shape.len() as i64); + assert!(!v.empty()); + assert_eq!(v.device(), Device::CPU); + } + + #[test] + fn test_from_i16() { + let shape = vec![2, 2]; + let data = vec![3, 4, 5, 6]; + let v = StorageView::from_i16(&shape, &data, Default::default()).unwrap(); + + assert_eq!(v.size(), data.len() as i64); + assert_eq!(v.rank(), shape.len() as i64); + assert!(!v.empty()); + assert_eq!(v.device(), Device::CPU); + } +}