Skip to content

Commit

Permalink
feat: add StorageView
Browse files Browse the repository at this point in the history
  • Loading branch information
jkawamoto committed Jul 11, 2024
1 parent 665954e commit bc9cd42
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 2 deletions.
3 changes: 3 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ fn main() {
println!("cargo:rerun-if-changed=src/translator.cpp");
println!("cargo:rerun-if-changed=src/generator.rs");
println!("cargo:rerun-if-changed=src/generator.cpp");
println!("cargo:rerun-if-changed=src/storage_view.rs");
println!("cargo:rerun-if-changed=include/types.h");
println!("cargo:rerun-if-changed=include/config.h");
println!("cargo:rerun-if-changed=include/translator.h");
println!("cargo:rerun-if-changed=include/generator.h");
println!("cargo:rerun-if-changed=include/storage_view.h");
println!("cargo:rerun-if-changed=CTranslate2");
println!("cargo:rerun-if-env-changed=LIBRARY_PATH");
if let Ok(library_path) = env::var("LIBRARY_PATH") {
Expand Down Expand Up @@ -84,6 +86,7 @@ fn main() {
"src/config.rs",
"src/translator.rs",
"src/generator.rs",
"src/storage_view.rs",
])
.file("src/translator.cpp")
.file("src/generator.cpp")
Expand Down
54 changes: 54 additions & 0 deletions include/storage_view.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// storage_view.h
//
// Copyright (c) 2023-2024 Junpei Kawamoto
//
// This software is released under the MIT License.
//
// http://opensource.org/licenses/mit-license.php

#pragma once

#include <memory>

#include <ctranslate2/storage_view.h>

#include "rust/cxx.h"

using ctranslate2::Device;
using ctranslate2::StorageView;

inline std::unique_ptr<StorageView> storage_view_from_float(
const rust::Vec<int64_t>& shape,
const rust::Vec<float>& init,
const Device device
) {
return std::make_unique<StorageView>(
ctranslate2::Shape(shape.begin(), shape.end()),
std::vector<float>(init.begin(), init.end()),
device
);
}

inline std::unique_ptr<StorageView> storage_view_from_int8(
const rust::Vec<int64_t>& shape,
const rust::Vec<int8_t>& init,
const Device device
) {
return std::make_unique<StorageView>(
ctranslate2::Shape(shape.begin(), shape.end()),
std::vector<int8_t>(init.begin(), init.end()),
device
);
}

inline std::unique_ptr<StorageView> storage_view_from_int16(
const rust::Vec<int64_t>& shape,
const rust::Vec<int16_t>& init,
const Device device
) {
return std::make_unique<StorageView>(
ctranslate2::Shape(shape.begin(), shape.end()),
std::vector<int16_t>(init.begin(), init.end()),
device
);
}
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

//! Configs and associated enums.
use std::fmt::{Debug, Display, Formatter, Pointer};
use std::fmt::{Debug, Display, Formatter};

use cxx::UniquePtr;

Expand Down
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,16 +138,18 @@ use std::path::Path;
use anyhow::{anyhow, Result};

use crate::auto::Tokenizer as AutoTokenizer;
use crate::config::Config;
pub use crate::config::{set_log_level, set_random_seed};
use crate::config::Config;
pub use crate::generator::GenerationOptions;
pub use crate::storage_view::StorageView;
pub use crate::translator::TranslationOptions;

pub mod auto;
pub mod bpe;
pub mod config;
pub mod generator;
pub mod sentencepiece;
pub mod storage_view;
pub mod tokenizers;
pub mod translator;
mod types;
Expand Down
141 changes: 141 additions & 0 deletions src/storage_view.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// storage_view.rs
//
// Copyright (c) 2023-2024 Junpei Kawamoto
//
// This software is released under the MIT License.
//
// http://opensource.org/licenses/mit-license.php

use anyhow::Result;
use cxx::UniquePtr;

use crate::config::Device;

#[cxx::bridge]
mod ffi {
unsafe extern "C++" {
include!("ct2rs/include/storage_view.h");

type Device = crate::config::ffi::Device;

type StorageView;

fn storage_view_from_float(
shape: &Vec<i64>,
init: &Vec<f32>,
device: Device,
) -> Result<UniquePtr<StorageView>>;

fn storage_view_from_int8(
shape: &Vec<i64>,
init: &Vec<i8>,
device: Device,
) -> Result<UniquePtr<StorageView>>;

fn storage_view_from_int16(
shape: &Vec<i64>,
init: &Vec<i16>,
device: Device,
) -> Result<UniquePtr<StorageView>>;

fn device(self: &StorageView) -> Device;

fn size(self: &StorageView) -> i64;

fn rank(self: &StorageView) -> i64;
}
}

/// A Rust binding to the
/// [`ctranslate2::StorageView`](https://opennmt.net/CTranslate2/python/ctranslate2.StorageView.html).
pub struct StorageView {
ptr: UniquePtr<ffi::StorageView>,
}

impl StorageView {
/// Creates a storage view with the given shape from the given array of float values.
pub fn from_f32(shape: &Vec<i64>, init: &Vec<f32>, device: Device) -> Result<Self> {
Ok(Self {
ptr: ffi::storage_view_from_float(shape, init, device)?,
})
}

/// Creates a storage view with the given shape from the given array of int8 values.
pub fn from_i8(shape: &Vec<i64>, init: &Vec<i8>, device: Device) -> Result<Self> {
Ok(Self {
ptr: ffi::storage_view_from_int8(shape, init, device)?,
})
}

/// Creates a storage view with the given shape from the given array of int16 values.
pub fn from_i16(shape: &Vec<i64>, init: &Vec<i16>, device: Device) -> Result<Self> {
Ok(Self {
ptr: ffi::storage_view_from_int16(shape, init, device)?,
})
}

/// Device where the storage is allocated.
pub fn device(&self) -> Device {
self.ptr.device()
}

/// Returns the size of this storage.
pub fn size(&self) -> i64 {
self.ptr.size()
}

/// Returns the rank of this storage.
pub fn rank(&self) -> i64 {
self.ptr.rank()
}

/// Returns true if this storage is empty.
pub fn empty(&self) -> bool {
self.size() == 0
}
}

unsafe impl Send for StorageView {}
unsafe impl Sync for StorageView {}

#[cfg(test)]
mod tests {
use crate::config::Device;
use crate::storage_view::StorageView;

#[test]
fn test_from_f32() {
let shape = vec![2, 2];
let data = vec![3., 4., 5., 6.];
let v = StorageView::from_f32(&shape, &data, Default::default()).unwrap();

assert_eq!(v.size(), data.len() as i64);
assert_eq!(v.rank(), shape.len() as i64);
assert!(!v.empty());
assert_eq!(v.device(), Device::CPU);
}

#[test]
fn test_from_i8() {
let shape = vec![2, 2];
let data = vec![3, 4, 5, 6];
let v = StorageView::from_i8(&shape, &data, Default::default()).unwrap();

assert_eq!(v.size(), data.len() as i64);
assert_eq!(v.rank(), shape.len() as i64);
assert!(!v.empty());
assert_eq!(v.device(), Device::CPU);
}

#[test]
fn test_from_i16() {
let shape = vec![2, 2];
let data = vec![3, 4, 5, 6];
let v = StorageView::from_i16(&shape, &data, Default::default()).unwrap();

assert_eq!(v.size(), data.len() as i64);
assert_eq!(v.rank(), shape.len() as i64);
assert!(!v.empty());
assert_eq!(v.device(), Device::CPU);
}
}

0 comments on commit bc9cd42

Please sign in to comment.