Skip to content

Commit

Permalink
feat: add StorageView
Browse files Browse the repository at this point in the history
  • Loading branch information
jkawamoto committed Jul 12, 2024
1 parent 665954e commit 73afadd
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 2 deletions.
3 changes: 3 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ fn main() {
println!("cargo:rerun-if-changed=src/translator.cpp");
println!("cargo:rerun-if-changed=src/generator.rs");
println!("cargo:rerun-if-changed=src/generator.cpp");
println!("cargo:rerun-if-changed=src/storage_view.rs");
println!("cargo:rerun-if-changed=include/types.h");
println!("cargo:rerun-if-changed=include/config.h");
println!("cargo:rerun-if-changed=include/translator.h");
println!("cargo:rerun-if-changed=include/generator.h");
println!("cargo:rerun-if-changed=include/storage_view.h");
println!("cargo:rerun-if-changed=CTranslate2");
println!("cargo:rerun-if-env-changed=LIBRARY_PATH");
if let Ok(library_path) = env::var("LIBRARY_PATH") {
Expand Down Expand Up @@ -84,6 +86,7 @@ fn main() {
"src/config.rs",
"src/translator.rs",
"src/generator.rs",
"src/storage_view.rs",
])
.file("src/translator.cpp")
.file("src/generator.cpp")
Expand Down
63 changes: 63 additions & 0 deletions include/storage_view.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// storage_view.h
//
// Copyright (c) 2023-2024 Junpei Kawamoto
//
// This software is released under the MIT License.
//
// http://opensource.org/licenses/mit-license.php

#pragma once

#include <memory>
#include <sstream>
#include <string>

#include <ctranslate2/storage_view.h>

#include "rust/cxx.h"

using ctranslate2::Device;
using ctranslate2::StorageView;

// Builds a heap-allocated StorageView holding float values.
//
// `shape` supplies the dimensions and `init` the element values; both are
// borrowed Rust slices whose contents are copied into owned C++ containers
// before the StorageView is constructed on `device`.
inline std::unique_ptr<StorageView> storage_view_from_float(
    const rust::Slice<const size_t> shape,
    const rust::Slice<const float> init,
    const Device device
) {
    // Owned copy of the caller's data; the slice itself is only borrowed.
    const std::vector<float> values(init.begin(), init.end());

    return std::make_unique<StorageView>(
        ctranslate2::Shape(shape.begin(), shape.end()),
        values,
        device
    );
}

// Builds a heap-allocated StorageView holding int8 values.
//
// `shape` supplies the dimensions and `init` the element values; both are
// borrowed Rust slices whose contents are copied into owned C++ containers
// before the StorageView is constructed on `device`.
inline std::unique_ptr<StorageView> storage_view_from_int8(
    const rust::Slice<const size_t> shape,
    const rust::Slice<const int8_t> init,
    const Device device
) {
    // Owned copy of the caller's data; the slice itself is only borrowed.
    const std::vector<int8_t> values(init.begin(), init.end());

    return std::make_unique<StorageView>(
        ctranslate2::Shape(shape.begin(), shape.end()),
        values,
        device
    );
}

// Builds a heap-allocated StorageView holding int16 values.
//
// `shape` supplies the dimensions and `init` the element values; both are
// borrowed Rust slices whose contents are copied into owned C++ containers
// before the StorageView is constructed on `device`.
inline std::unique_ptr<StorageView> storage_view_from_int16(
    const rust::Slice<const size_t> shape,
    const rust::Slice<const int16_t> init,
    const Device device
) {
    // Owned copy of the caller's data; the slice itself is only borrowed.
    const std::vector<int16_t> values(init.begin(), init.end());

    return std::make_unique<StorageView>(
        ctranslate2::Shape(shape.begin(), shape.end()),
        values,
        device
    );
}

// Renders `storage` as text by streaming it into a string buffer with
// `operator<<` and returns the result as a Rust string.
//
// Marked `inline` because the function is defined in a header: without it,
// including this header from more than one translation unit produces
// duplicate definitions and a link error.
inline rust::String to_string(const StorageView& storage) {
    std::ostringstream oss;
    oss << storage;

    return rust::String(oss.str());
}
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

//! Configs and associated enums.
use std::fmt::{Debug, Display, Formatter, Pointer};
use std::fmt::{Debug, Display, Formatter};

use cxx::UniquePtr;

Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,18 @@ use std::path::Path;
use anyhow::{anyhow, Result};

use crate::auto::Tokenizer as AutoTokenizer;
use crate::config::Config;
pub use crate::config::{set_log_level, set_random_seed};
use crate::config::Config;
pub use crate::generator::GenerationOptions;
pub use crate::storage_view::StorageView;
pub use crate::translator::TranslationOptions;

pub mod auto;
pub mod bpe;
pub mod config;
pub mod generator;
pub mod sentencepiece;
pub mod storage_view;
pub mod tokenizers;
pub mod translator;
mod types;
Expand Down
162 changes: 162 additions & 0 deletions src/storage_view.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// storage_view.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
//
// This software is released under the MIT License.
//
// http://opensource.org/licenses/mit-license.php

use std::fmt::{Debug, Formatter};
use std::ops::Deref;

use anyhow::Result;
use cxx::UniquePtr;

use crate::config::Device;

// cxx bridge for the C++ helpers defined in `include/storage_view.h`.
#[cxx::bridge]
pub(crate) mod ffi {
unsafe extern "C++" {
include!("ct2rs/include/storage_view.h");

// Reuses the `Device` binding declared by the config module's bridge.
type Device = crate::config::ffi::Device;

/// Opaque handle to the C++ `ctranslate2::StorageView` type.
type StorageView;

/// Creates a C++ storage view with the given shape, copying `init` as its
/// float values. Declared as returning `Result` so that cxx reports a C++
/// exception as an error.
fn storage_view_from_float(
shape: &[usize],
init: &[f32],
device: Device,
) -> Result<UniquePtr<StorageView>>;

/// Creates a C++ storage view with the given shape, copying `init` as its
/// int8 values. Declared as returning `Result` so that cxx reports a C++
/// exception as an error.
fn storage_view_from_int8(
shape: &[usize],
init: &[i8],
device: Device,
) -> Result<UniquePtr<StorageView>>;

/// Creates a C++ storage view with the given shape, copying `init` as its
/// int16 values. Declared as returning `Result` so that cxx reports a C++
/// exception as an error.
fn storage_view_from_int16(
shape: &[usize],
init: &[i16],
device: Device,
) -> Result<UniquePtr<StorageView>>;

/// Bound to the C++ member function `StorageView::device`.
fn device(self: &StorageView) -> Device;

/// Bound to the C++ member function `StorageView::size`.
fn size(self: &StorageView) -> i64;

/// Bound to the C++ member function `StorageView::rank`.
fn rank(self: &StorageView) -> i64;

/// Formats the storage through the C++ stream operator and returns the text.
fn to_string(storage: &StorageView) -> String;
}
}

/// A Rust binding to the
/// [`ctranslate2::StorageView`](https://opennmt.net/CTranslate2/python/ctranslate2.StorageView.html).
///
/// Values are created with [`StorageView::from_f32`], [`StorageView::from_i8`] or
/// [`StorageView::from_i16`]; each wraps a C++ object held through a `UniquePtr`.
pub struct StorageView {
// Owning pointer to the C++ object returned by the `ffi` constructors.
ptr: UniquePtr<ffi::StorageView>,
}

impl StorageView {
    /// Builds a storage view of `f32` values laid out according to `shape`.
    ///
    /// The contents of `init` are copied into the new storage on `device`.
    ///
    /// # Errors
    ///
    /// Returns an error when the underlying C++ constructor reports a failure.
    pub fn from_f32(shape: &[usize], init: &[f32], device: Device) -> Result<Self> {
        let ptr = ffi::storage_view_from_float(shape, init, device)?;
        Ok(Self { ptr })
    }

    /// Builds a storage view of `i8` values laid out according to `shape`.
    ///
    /// The contents of `init` are copied into the new storage on `device`.
    ///
    /// # Errors
    ///
    /// Returns an error when the underlying C++ constructor reports a failure.
    pub fn from_i8(shape: &[usize], init: &[i8], device: Device) -> Result<Self> {
        let ptr = ffi::storage_view_from_int8(shape, init, device)?;
        Ok(Self { ptr })
    }

    /// Builds a storage view of `i16` values laid out according to `shape`.
    ///
    /// The contents of `init` are copied into the new storage on `device`.
    ///
    /// # Errors
    ///
    /// Returns an error when the underlying C++ constructor reports a failure.
    pub fn from_i16(shape: &[usize], init: &[i16], device: Device) -> Result<Self> {
        let ptr = ffi::storage_view_from_int16(shape, init, device)?;
        Ok(Self { ptr })
    }

    /// Reports the device on which this storage is allocated.
    pub fn device(&self) -> Device {
        self.ptr.device()
    }

    /// Reports the size of this storage, as given by the C++ object.
    pub fn size(&self) -> i64 {
        self.ptr.size()
    }

    /// Reports the rank of this storage, as given by the C++ object.
    pub fn rank(&self) -> i64 {
        self.ptr.rank()
    }

    /// Tells whether this storage has a size of zero.
    pub fn empty(&self) -> bool {
        let len = self.size();
        len == 0
    }
}

/// Debug output is the text produced by the C++ side for the wrapped storage.
impl Debug for StorageView {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Borrow the C++ object behind the owning pointer, then let C++ render it.
        let inner: &ffi::StorageView = &self.ptr;
        let rendered = ffi::to_string(inner);
        f.write_str(&rendered)
    }
}

/// Gives crate-internal code direct access to the underlying FFI object.
impl Deref for StorageView {
    type Target = ffi::StorageView;

    fn deref(&self) -> &Self::Target {
        // Forward to the owning pointer's own dereference.
        self.ptr.deref()
    }
}

// NOTE(review): these impls assert that the wrapped C++ object may be moved to and
// shared between threads. Nothing in this file demonstrates that; soundness rests on
// the thread-safety of `ctranslate2::StorageView` itself — confirm against CTranslate2.
unsafe impl Send for StorageView {}
unsafe impl Sync for StorageView {}

#[cfg(test)]
mod tests {
    use crate::config::Device;
    use crate::storage_view::StorageView;

    /// A 1x2x4 float storage reports eight elements, rank three, and the CPU device.
    #[test]
    fn test_from_f32() {
        let shape = [1usize, 2, 4];
        let data = [1f32, 2., 3., 4., 5., 6., 7., 8.];
        let v = StorageView::from_f32(&shape, &data, Device::default()).unwrap();

        assert_eq!(v.size(), data.len() as i64);
        assert_eq!(v.rank(), shape.len() as i64);
        assert!(!v.empty());
        assert_eq!(v.device(), Device::CPU);

        // Exercise the Debug implementation as well.
        println!("{:?}", v);
    }

    /// A 2x2 int8 storage reports four elements, rank two, and the CPU device.
    #[test]
    fn test_from_i8() {
        let shape = [2usize, 2];
        let data = [3i8, 4, 5, 6];
        let v = StorageView::from_i8(&shape, &data, Device::default()).unwrap();

        assert_eq!(v.size(), data.len() as i64);
        assert_eq!(v.rank(), shape.len() as i64);
        assert!(!v.empty());
        assert_eq!(v.device(), Device::CPU);
    }

    /// A 2x2 int16 storage reports four elements, rank two, and the CPU device.
    #[test]
    fn test_from_i16() {
        let shape = [2usize, 2];
        let data = [3i16, 4, 5, 6];
        let v = StorageView::from_i16(&shape, &data, Device::default()).unwrap();

        assert_eq!(v.size(), data.len() as i64);
        assert_eq!(v.rank(), shape.len() as i64);
        assert!(!v.empty());
        assert_eq!(v.device(), Device::CPU);
    }
}

0 comments on commit 73afadd

Please sign in to comment.