diff --git a/src/config.rs b/src/config.rs index 4046a27..37b41da 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,11 +8,13 @@ //! Configs and associated enums. +use std::fmt::{Debug, Display, Formatter, Pointer}; + use cxx::UniquePtr; pub use ffi::{ - get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed, BatchType, - ComputeType, Device, LogLevel, + BatchType, ComputeType, Device, get_device_count, get_log_level, get_random_seed, + LogLevel, set_log_level, set_random_seed, }; #[cxx::bridge] @@ -36,7 +38,7 @@ pub(crate) mod ffi { /// # assert_eq!(device, Device::CPU); /// ``` /// - #[derive(Debug)] + #[derive(Copy, Clone, Debug)] #[repr(i32)] enum Device { CPU, @@ -81,7 +83,7 @@ pub(crate) mod ffi { /// # assert_eq!(compute_type, ComputeType::DEFAULT); /// ``` /// - #[derive(Debug)] + #[derive(Copy, Clone, Debug)] #[repr(i32)] enum ComputeType { DEFAULT, @@ -116,7 +118,7 @@ pub(crate) mod ffi { /// let batch_type = BatchType::default(); /// # assert_eq!(batch_type, BatchType::Examples); /// ``` - #[derive(Debug)] + #[derive(Copy, Clone, Debug)] #[repr(i32)] enum BatchType { Examples, @@ -146,7 +148,7 @@ pub(crate) mod ffi { /// let log_level = LogLevel::default(); /// # assert_eq!(log_level, LogLevel::Warning); /// ``` - #[derive(Debug)] + #[derive(Copy, Clone, Debug)] #[repr(i32)] enum LogLevel { Off = -3, @@ -223,24 +225,77 @@ impl Default for Device { } } +impl Display for Device { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + match *self { + Device::CPU => write!(f, "CPU"), + Device::CUDA => write!(f, "CUDA"), + _ => write!(f, "Unknown"), + } + } +} + impl Default for ComputeType { fn default() -> Self { Self::DEFAULT } } +impl Display for ComputeType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match *self { + ComputeType::DEFAULT => write!(f, "default"), + ComputeType::AUTO => write!(f, "auto"), + ComputeType::FLOAT32 => write!(f, "float32"), + ComputeType::INT8 => write!(f, "int8"), + ComputeType::INT8_FLOAT32 => write!(f, "int8_float32"), + ComputeType::INT8_FLOAT16 => write!(f, "int8_float16"), + ComputeType::INT8_BFLOAT16 => write!(f, "int8_bfloat16"), + ComputeType::INT16 => write!(f, "int16"), + ComputeType::FLOAT16 => write!(f, "float16"), + ComputeType::BFLOAT16 => write!(f, "bfloat16"), + _ => write!(f, "unknown"), + } + } +} + impl Default for BatchType { fn default() -> Self { Self::Examples } } +impl Display for BatchType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match *self { + BatchType::Examples => write!(f, "examples"), + BatchType::Tokens => write!(f, "tokens"), + _ => write!(f, "unknown"), + } + } +} + impl Default for LogLevel { fn default() -> Self { Self::Warning } } +impl Display for LogLevel { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match *self { + LogLevel::Off => write!(f, "off"), + LogLevel::Critical => write!(f, "critical"), + LogLevel::Error => write!(f, "error"), + LogLevel::Warning => write!(f, "warning"), + LogLevel::Info => write!(f, "info"), + LogLevel::Debug => write!(f, "debug"), + LogLevel::Trace => write!(f, "trace"), + _ => write!(f, "unknown"), + } + } +} + /// The `Config` structure holds the configuration settings for CTranslator2. /// /// # Examples @@ -314,10 +369,47 @@ mod tests { use rand::random; use crate::config::{ - get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed, Config, - Device, LogLevel, + BatchType, ComputeType, Config, Device, get_device_count, + get_log_level, get_random_seed, LogLevel, set_log_level, set_random_seed, }; + #[test] + fn test_device_display() { + assert_eq!(format!("{}", Device::CPU), "CPU"); + assert_eq!(format!("{}", Device::CUDA), "CUDA"); + } + + #[test] + fn test_compute_type_display() { + assert_eq!(format!("{}", ComputeType::DEFAULT), "default"); + assert_eq!(format!("{}", ComputeType::AUTO), "auto"); + assert_eq!(format!("{}", ComputeType::FLOAT32), "float32"); + assert_eq!(format!("{}", ComputeType::INT8), "int8"); + assert_eq!(format!("{}", ComputeType::INT8_FLOAT32), "int8_float32"); + assert_eq!(format!("{}", ComputeType::INT8_FLOAT16), "int8_float16"); + assert_eq!(format!("{}", ComputeType::INT8_BFLOAT16), "int8_bfloat16"); + assert_eq!(format!("{}", ComputeType::INT16), "int16"); + assert_eq!(format!("{}", ComputeType::FLOAT16), "float16"); + assert_eq!(format!("{}", ComputeType::BFLOAT16), "bfloat16"); + } + + #[test] + fn test_batch_type_display() { + assert_eq!(format!("{}", BatchType::Examples), "examples"); + assert_eq!(format!("{}", BatchType::Tokens), "tokens"); + } + + #[test] + fn test_log_level_display() { + assert_eq!(format!("{}", LogLevel::Off), "off"); + assert_eq!(format!("{}", LogLevel::Critical), "critical"); + assert_eq!(format!("{}", LogLevel::Error), "error"); + assert_eq!(format!("{}", LogLevel::Warning), "warning"); + assert_eq!(format!("{}", LogLevel::Info), "info"); + assert_eq!(format!("{}", LogLevel::Debug), "debug"); + assert_eq!(format!("{}", LogLevel::Trace), "trace"); + } + #[test] fn test_config_to_ffi() { let config = Config::default();