Skip to content

Commit

Permalink
Merge pull request #67 from jkawamoto/display
Browse files Browse the repository at this point in the history
  • Loading branch information
jkawamoto authored Jul 10, 2024
2 parents d5a336f + 7aece6d commit 665954e
Showing 1 changed file with 100 additions and 8 deletions.
108 changes: 100 additions & 8 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

//! Configs and associated enums.
use std::fmt::{Debug, Display, Formatter, Pointer};

use cxx::UniquePtr;

pub use ffi::{
get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed, BatchType,
ComputeType, Device, LogLevel,
BatchType, ComputeType, Device, get_device_count, get_log_level, get_random_seed,
LogLevel, set_log_level, set_random_seed,
};

#[cxx::bridge]
Expand All @@ -36,7 +38,7 @@ pub(crate) mod ffi {
/// # assert_eq!(device, Device::CPU);
/// ```
///
#[derive(Debug)]
#[derive(Copy, Clone, Debug)]
#[repr(i32)]
enum Device {
CPU,
Expand Down Expand Up @@ -81,7 +83,7 @@ pub(crate) mod ffi {
/// # assert_eq!(compute_type, ComputeType::DEFAULT);
/// ```
///
#[derive(Debug)]
#[derive(Copy, Clone, Debug)]
#[repr(i32)]
enum ComputeType {
DEFAULT,
Expand Down Expand Up @@ -116,7 +118,7 @@ pub(crate) mod ffi {
/// let batch_type = BatchType::default();
/// # assert_eq!(batch_type, BatchType::Examples);
/// ```
#[derive(Debug)]
#[derive(Copy, Clone, Debug)]
#[repr(i32)]
enum BatchType {
Examples,
Expand Down Expand Up @@ -146,7 +148,7 @@ pub(crate) mod ffi {
/// let log_level = LogLevel::default();
/// # assert_eq!(log_level, LogLevel::Warning);
/// ```
#[derive(Debug)]
#[derive(Copy, Clone, Debug)]
#[repr(i32)]
enum LogLevel {
Off = -3,
Expand Down Expand Up @@ -223,24 +225,77 @@ impl Default for Device {
}
}

impl Display for Device {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
match *self {
Device::CPU => write!(f, "CPU"),
Device::CUDA => write!(f, "CUDA"),
_ => write!(f, "Unknown"),
}
}
}

impl Default for ComputeType {
fn default() -> Self {
Self::DEFAULT
}
}

impl Display for ComputeType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match *self {
ComputeType::DEFAULT => write!(f, "default"),
ComputeType::AUTO => write!(f, "auto"),
ComputeType::FLOAT32 => write!(f, "float32"),
ComputeType::INT8 => write!(f, "int8"),
ComputeType::INT8_FLOAT32 => write!(f, "int8_float32"),
ComputeType::INT8_FLOAT16 => write!(f, "int8_float16"),
ComputeType::INT8_BFLOAT16 => write!(f, "int8_bfloat16"),
ComputeType::INT16 => write!(f, "int16"),
ComputeType::FLOAT16 => write!(f, "float16"),
ComputeType::BFLOAT16 => write!(f, "bfloat16"),
_ => write!(f, "unknown"),
}
}
}

impl Default for BatchType {
fn default() -> Self {
Self::Examples
}
}

impl Display for BatchType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match *self {
BatchType::Examples => write!(f, "examples"),
BatchType::Tokens => write!(f, "tokens"),
_ => write!(f, "unknown"),
}
}
}

impl Default for LogLevel {
fn default() -> Self {
Self::Warning
}
}

impl Display for LogLevel {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match *self {
LogLevel::Off => write!(f, "off"),
LogLevel::Critical => write!(f, "critical"),
LogLevel::Error => write!(f, "error"),
LogLevel::Warning => write!(f, "warning"),
LogLevel::Info => write!(f, "info"),
LogLevel::Debug => write!(f, "debug"),
LogLevel::Trace => write!(f, "trace"),
_ => write!(f, "unknown"),
}
}
}

/// The `Config` structure holds the configuration settings for CTranslator2.
///
/// # Examples
Expand Down Expand Up @@ -314,10 +369,47 @@ mod tests {
use rand::random;

use crate::config::{
get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed, Config,
Device, LogLevel,
BatchType, ComputeType, Config, Device, get_device_count,
get_log_level, get_random_seed, LogLevel, set_log_level, set_random_seed,
};

#[test]
fn test_device_display() {
assert_eq!(format!("{}", Device::CPU), "CPU");
assert_eq!(format!("{}", Device::CUDA), "CUDA");
}

#[test]
fn test_compute_type_display() {
assert_eq!(format!("{}", ComputeType::DEFAULT), "default");
assert_eq!(format!("{}", ComputeType::AUTO), "auto");
assert_eq!(format!("{}", ComputeType::FLOAT32), "float32");
assert_eq!(format!("{}", ComputeType::INT8), "int8");
assert_eq!(format!("{}", ComputeType::INT8_FLOAT32), "int8_float32");
assert_eq!(format!("{}", ComputeType::INT8_FLOAT16), "int8_float16");
assert_eq!(format!("{}", ComputeType::INT8_BFLOAT16), "int8_bfloat16");
assert_eq!(format!("{}", ComputeType::INT16), "int16");
assert_eq!(format!("{}", ComputeType::FLOAT16), "float16");
assert_eq!(format!("{}", ComputeType::BFLOAT16), "bfloat16");
}

#[test]
fn test_batch_type_display() {
assert_eq!(format!("{}", BatchType::Examples), "examples");
assert_eq!(format!("{}", BatchType::Tokens), "tokens");
}

#[test]
fn test_log_level_display() {
assert_eq!(format!("{}", LogLevel::Off), "off");
assert_eq!(format!("{}", LogLevel::Critical), "critical");
assert_eq!(format!("{}", LogLevel::Error), "error");
assert_eq!(format!("{}", LogLevel::Warning), "warning");
assert_eq!(format!("{}", LogLevel::Info), "info");
assert_eq!(format!("{}", LogLevel::Debug), "debug");
assert_eq!(format!("{}", LogLevel::Trace), "trace");
}

#[test]
fn test_config_to_ffi() {
let config = Config::default();
Expand Down

0 comments on commit 665954e

Please sign in to comment.