From 354190688f3eae64864fb8e351232255ae59ed9d Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Tue, 2 Apr 2024 14:28:57 -0400
Subject: [PATCH 01/11] Add multilabel classification dataset

- Add MultiLabel annotation support
- Refactor de/serialize annotation with AnnotationRaw
- Add ImageFolderDataset::with_items methods
---
 crates/burn-dataset/Cargo.toml                |   3 +-
 .../burn-dataset/src/vision/image_folder.rs   | 191 +++++++++++++++---
 2 files changed, 166 insertions(+), 28 deletions(-)

diff --git a/crates/burn-dataset/Cargo.toml b/crates/burn-dataset/Cargo.toml
index 45b46f302d..a3b5e8a9f8 100644
--- a/crates/burn-dataset/Cargo.toml
+++ b/crates/burn-dataset/Cargo.toml
@@ -21,7 +21,7 @@ fake = ["dep:fake"]
 sqlite = ["__sqlite-shared", "dep:rusqlite"]
 sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"]
 
-vision = ["dep:flate2", "dep:globwalk", "dep:burn-common"]
+vision = ["dep:bincode", "dep:flate2", "dep:globwalk", "dep:burn-common"]
 
 # internal
 __sqlite-shared = [
@@ -33,6 +33,7 @@ __sqlite-shared = [
 ]
 
 [dependencies]
+bincode = { workspace = true, optional = true }
 burn-common = { path = "../burn-common", version = "0.13.0", optional = true, features = [
   "network",
 ] }
diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs
index e98f81b955..84bebcf2fa 100644
--- a/crates/burn-dataset/src/vision/image_folder.rs
+++ b/crates/burn-dataset/src/vision/image_folder.rs
@@ -57,11 +57,13 @@ impl TryFrom<PixelDepth> for f32 {
     }
 }
 
-/// Image target for different tasks.
+/// Annotation type for different tasks.
 #[derive(Debug, Clone, PartialEq)]
 pub enum Annotation {
     /// Image-level label.
     Label(usize),
+    /// Multiple image-level labels.
+    MultiLabel(Vec<usize>),
     /// Object bounding boxes.
     BoundingBoxes(Vec<BoundingBox>),
     /// Segmentation mask.
@@ -97,14 +99,47 @@ pub struct ImageDatasetItem {
     pub annotation: Annotation,
 }
 
+/// Raw annotation types.
+#[derive(Deserialize, Serialize, Debug, Clone)]
+enum AnnotationRaw {
+    Label(String),
+    MultiLabel(Vec<String>),
+    // TODO: bounding boxes and segmentation mask
+}
+
+impl AnnotationRaw {
+    fn bin_config() -> bincode::config::Configuration {
+        bincode::config::standard()
+    }
+
+    fn encode(&self) -> Vec<u8> {
+        bincode::serde::encode_to_vec(self, Self::bin_config()).unwrap()
+    }
+
+    fn decode(annotation: &[u8]) -> Self {
+        let (annotation, _): (AnnotationRaw, usize) =
+            bincode::serde::decode_from_slice(&annotation, Self::bin_config()).unwrap();
+        annotation
+    }
+}
+
 #[derive(Deserialize, Serialize, Debug, Clone)]
 struct ImageDatasetItemRaw {
     /// Image path.
-    pub image_path: PathBuf,
+    image_path: PathBuf,
 
     /// Image annotation.
     /// The annotation bytes can represent a string (category name) or path to annotation file.
-    pub annotation: Vec<u8>,
+    annotation: Vec<u8>,
+}
+
+impl ImageDatasetItemRaw {
+    fn new<P: AsRef<Path>>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw {
+        ImageDatasetItemRaw {
+            image_path: image_path.as_ref().to_path_buf(),
+            annotation: annotation.encode(),
+        }
+    }
 }
 
 struct PathToImageDatasetItem {
@@ -118,9 +153,18 @@ fn parse_image_annotation(annotation: &[u8], classes: &HashMap<String, usize>) -
     // - [ ] Segmentation mask
     // For now, only image classification labels are supported.
+    let annotation = AnnotationRaw::decode(annotation);
+
     // Map class string to label id
-    let name = std::str::from_utf8(annotation).unwrap();
-    Annotation::Label(*classes.get(name).unwrap())
+    match annotation {
+        AnnotationRaw::Label(name) => Annotation::Label(*classes.get(&name).unwrap()),
+        AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel(
+            names
+                .iter()
+                .map(|name| *classes.get(name).unwrap())
+                .collect(),
+        ),
+    }
 }
 
 impl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {
@@ -212,7 +256,7 @@ pub enum ImageLoaderError {
 type ImageDatasetMapper =
     MapperDataset<InMemDataset<ImageDatasetItemRaw>, PathToImageDatasetItem, ImageDatasetItemRaw>;
 
-/// A generic dataset to load classification images from disk.
+/// A generic dataset to load images from disk.
 pub struct ImageFolderDataset {
     dataset: ImageDatasetMapper,
 }
@@ -259,18 +303,6 @@ impl ImageFolderDataset {
         P: AsRef<Path>,
         S: AsRef<str>,
     {
-        /// Check if extension is supported.
-        fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
-            let extension = extension.as_ref();
-            if !SUPPORTED_FILES.contains(&extension) {
-                Err(ImageLoaderError::InvalidFileExtensionError(
-                    extension.to_string(),
-                ))
-            } else {
-                Ok(extension.to_string())
-            }
-        }
-
         // Glob all images with extensions
         let walker = globwalk::GlobWalkerBuilder::from_patterns(
             root.as_ref(),
@@ -278,7 +310,7 @@ impl ImageFolderDataset {
                 "*.{{{}}}", // "*.{ext1,ext2,ext3}
                 extensions
                     .iter()
-                    .map(check_extension)
+                    .map(Self::check_extension)
                     .collect::<Result<Vec<_>, _>>()?
                     .join(",")
             )],
@@ -312,21 +344,99 @@ impl ImageFolderDataset {
             classes.insert(label.clone());
 
-            items.push(ImageDatasetItemRaw {
-                image_path: image_path.to_path_buf(),
-                annotation: label.into_bytes(),
-            })
+            items.push(ImageDatasetItemRaw::new(
+                image_path,
+                AnnotationRaw::Label(label),
+            ))
         }
 
+        Self::with_items(items, &classes.iter().collect::<Vec<_>>())
+    }
+
+    /// Create an image classification dataset with the specified items.
+    ///
+    /// # Arguments
+    ///
+    /// * `items` - List of dataset items, each item represented by a tuple `(image path, label)`.
+    /// * `classes` - Dataset class names.
+    ///
+    /// # Returns
+    /// A new dataset instance.
+    pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
+        items: Vec<(P, String)>,
+        classes: &[S],
+    ) -> Result<Self, ImageLoaderError> {
+        // Parse items and check valid image extension types
+        let items = items
+            .into_iter()
+            .map(|(path, label)| {
+                // Map image path and label
+                let path = path.as_ref();
+                let label = AnnotationRaw::Label(label);
+
+                Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
+
+                Ok(ImageDatasetItemRaw::new(path, label))
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        Self::with_items(items, classes)
+    }
+
+    /// Create a multi-label image classification dataset with the specified items.
+    ///
+    /// # Arguments
+    ///
+    /// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`.
+    /// * `classes` - Dataset class names.
+    ///
+    /// # Returns
+    /// A new dataset instance.
+    pub fn new_multilabel_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
+        items: Vec<(P, Vec<String>)>,
+        classes: &[S],
+    ) -> Result<Self, ImageLoaderError> {
+        // Parse items and check valid image extension types
+        let items = items
+            .into_iter()
+            .map(|(path, labels)| {
+                // Map image path and multi-label
+                let path = path.as_ref();
+                let labels = AnnotationRaw::MultiLabel(labels);
+
+                Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
+
+                Ok(ImageDatasetItemRaw::new(path, labels))
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        Self::with_items(items, classes)
+    }
+
+    /// Create an image dataset with the specified items.
+    ///
+    /// # Arguments
+    ///
+    /// * `items` - Raw dataset items.
+    /// * `classes` - Dataset class names.
+    ///
+    /// # Returns
+    /// A new dataset instance.
+    fn with_items<S: AsRef<str>>(
+        items: Vec<ImageDatasetItemRaw>,
+        classes: &[S],
+    ) -> Result<Self, ImageLoaderError> {
+        // NOTE: right now we don't need to validate the supported image files since
+        // the method is private. We assume it's already validated.
         let dataset = InMemDataset::new(items);
 
         // Class names to index map
-        let mut classes = classes.into_iter().collect::<Vec<_>>();
+        let mut classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();
         classes.sort();
         let classes_map: HashMap<_, _> = classes
             .into_iter()
             .enumerate()
-            .map(|(idx, cls)| (cls, idx))
+            .map(|(idx, cls)| (cls.to_string(), idx))
             .collect();
 
         let mapper = PathToImageDatasetItem {
@@ -336,6 +446,18 @@ impl ImageFolderDataset {
 
         Ok(Self { dataset })
     }
+
+    /// Check if extension is supported.
+    fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
+        let extension = extension.as_ref();
+        if !SUPPORTED_FILES.contains(&extension) {
+            Err(ImageLoaderError::InvalidFileExtensionError(
+                extension.to_string(),
+            ))
+        } else {
+            Ok(extension.to_string())
+        }
+    }
 }
 
 #[cfg(test)]
@@ -417,11 +539,26 @@ mod tests {
     }
 
     #[test]
-    pub fn parse_image_annotation_string() {
+    pub fn parse_image_annotation_label_string() {
        let classes = HashMap::from([("0".to_string(), 0_usize), ("1".to_string(), 1_usize)]);
+        let anno = AnnotationRaw::Label("0".to_string()).encode();
         assert_eq!(
-            parse_image_annotation(&"0".to_string().into_bytes(), &classes),
+            parse_image_annotation(&anno, &classes),
             Annotation::Label(0)
         );
     }
+
+    #[test]
+    pub fn parse_image_annotation_multilabel_string() {
+        let classes = HashMap::from([
+            ("0".to_string(), 0_usize),
+            ("1".to_string(), 1_usize),
+            ("2".to_string(), 2_usize),
+        ]);
+        let anno = AnnotationRaw::MultiLabel(vec!["0".to_string(), "2".to_string()]).encode();
+        assert_eq!(
+            parse_image_annotation(&anno, &classes),
+            Annotation::MultiLabel(vec![0, 2])
+        );
+    }
 }

From 968299de798c4d9d15620636c71b8e3cd08aa70b Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Tue, 2 Apr 2024 14:32:28 -0400
Subject: [PATCH 02/11] Fix custom-image-classification example deps

---
 examples/custom-image-dataset/Cargo.toml |  3 ---
 examples/custom-image-dataset/README.md  | 11 ++++++++++-
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/examples/custom-image-dataset/Cargo.toml b/examples/custom-image-dataset/Cargo.toml
index 13e27c1652..eabd5e43de 100644
--- a/examples/custom-image-dataset/Cargo.toml
+++ b/examples/custom-image-dataset/Cargo.toml
@@ -17,7 +17,4 @@ burn = { path = "../../crates/burn", features = ["train", "vision", "network"] }
 
 # File download
 flate2 = { workspace = true }
-indicatif = { workspace = true }
-reqwest = { workspace = true }
 tar = "0.4.40"
-tokio = { workspace = true }
diff --git a/examples/custom-image-dataset/README.md b/examples/custom-image-dataset/README.md
index d9eee42b27..ac49957e47 100644
--- a/examples/custom-image-dataset/README.md
+++ b/examples/custom-image-dataset/README.md
@@ -51,6 +51,15 @@ The CNN model and training recipe used in this example are fairly simple since t
 demonstrate how to load a custom image classification dataset from disk. Nonetheless, it still
 achieves 70-80% accuracy on the test set after just 30 epochs.
+Run it with the Torch GPU backend: + +```sh +export TORCH_CUDA_VERSION=cu121 +cargo run --example custom-image-dataset --release --features tch-gpu +``` + +Run it with our WGPU backend: + ```sh -cargo run --example custom-image-dataset +cargo run --example custom-image-dataset --release --features wgpu ``` From 813ae504f0615d4430ca7a5b50103fe5597a74fb Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 2 Apr 2024 15:05:11 -0400 Subject: [PATCH 03/11] Add image_folder_dataset_multilabel test --- .../burn-dataset/src/vision/image_folder.rs | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index 84bebcf2fa..2cc099edaf 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -492,6 +492,48 @@ mod tests { assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1)); } + #[test] + pub fn image_folder_dataset_multilabel() { + let root = Path::new(DATASET_ROOT); + let items = vec![ + ( + root.join("orange").join("dot.jpg"), + vec!["dot".to_string(), "orange".to_string()], + ), + ( + root.join("red").join("dot.jpg"), + vec!["dot".to_string(), "red".to_string()], + ), + ( + root.join("red").join("dot.png"), + vec!["dot".to_string(), "red".to_string()], + ), + ]; + let dataset = ImageFolderDataset::new_multilabel_classification_with_items( + items, + &["dot", "orange", "red"], + ) + .unwrap(); + + // Dataset has 3 elements + assert_eq!(dataset.len(), 3); + assert_eq!(dataset.get(3), None); + + // Dataset elements should be: [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2) + assert_eq!( + dataset.get(0).unwrap().annotation, + Annotation::MultiLabel(vec![0, 1]) + ); + assert_eq!( + dataset.get(1).unwrap().annotation, + Annotation::MultiLabel(vec![0, 2]) + ); + assert_eq!( + dataset.get(2).unwrap().annotation, + Annotation::MultiLabel(vec![0, 2]) + ); + } + #[test] #[should_panic] pub fn image_folder_dataset_invalid_extension() { From bd118e75449a067ea9f8f4f901dd814d8d428f66 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 2 Apr 2024 15:12:27 -0400 Subject: [PATCH 04/11] Do not change class names order when provided --- crates/burn-dataset/src/vision/image_folder.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index 2cc099edaf..702abd852e 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -350,7 +350,11 @@ impl ImageFolderDataset { )) } - Self::with_items(items, &classes.iter().collect::>()) + // Sort class names + let mut classes = classes.into_iter().collect::>(); + classes.sort(); + + Self::with_items(items, &classes) } /// Create an image classification dataset with the specified items. 
@@ -431,8 +435,7 @@ impl ImageFolderDataset { let dataset = InMemDataset::new(items); // Class names to index map - let mut classes = classes.iter().map(|c| c.as_ref()).collect::>(); - classes.sort(); + let classes = classes.iter().map(|c| c.as_ref()).collect::>(); let classes_map: HashMap<_, _> = classes .into_iter() .enumerate() From 8ef0115fb233e8fc60e892d12b72d842dc5a19c4 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 3 Apr 2024 12:45:50 -0400 Subject: [PATCH 05/11] Add hamming score and multi-label classification output --- .../burn-train/src/learner/classification.rs | 27 +++- crates/burn-train/src/metric/hamming.rs | 149 ++++++++++++++++++ crates/burn-train/src/metric/mod.rs | 2 + 3 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 crates/burn-train/src/metric/hamming.rs diff --git a/crates/burn-train/src/learner/classification.rs b/crates/burn-train/src/learner/classification.rs index f6b415fa29..ee86a05754 100644 --- a/crates/burn-train/src/learner/classification.rs +++ b/crates/burn-train/src/learner/classification.rs @@ -1,4 +1,4 @@ -use crate::metric::{AccuracyInput, Adaptor, LossInput}; +use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor}; @@ -26,3 +26,28 @@ impl Adaptor> for ClassificationOutput { LossInput::new(self.loss.clone()) } } + +/// Multi-label classification output adapted for multiple metrics. +#[derive(new)] +pub struct MultiLabelClassificationOutput { + /// The loss. + pub loss: Tensor, + + /// The output. + pub output: Tensor, + + /// The targets. + pub targets: Tensor, +} + +impl Adaptor> for MultiLabelClassificationOutput { + fn adapt(&self) -> HammingScoreInput { + HammingScoreInput::new(self.output.clone(), self.targets.clone()) + } +} + +impl Adaptor> for MultiLabelClassificationOutput { + fn adapt(&self) -> LossInput { + LossInput::new(self.loss.clone()) + } +} diff --git a/crates/burn-train/src/metric/hamming.rs b/crates/burn-train/src/metric/hamming.rs new file mode 100644 index 0000000000..79f0e807a8 --- /dev/null +++ b/crates/burn-train/src/metric/hamming.rs @@ -0,0 +1,149 @@ +use super::state::{FormatOptions, NumericMetricState}; +use super::{MetricEntry, MetricMetadata}; +use crate::metric::{Metric, Numeric}; +use burn_core::tensor::{activation::sigmoid, backend::Backend, ElementConversion, Int, Tensor}; + +/// The hamming score, sometimes referred to as multi-label or label-based accuracy. +#[derive(Default)] +pub struct HammingScore { + state: NumericMetricState, + threshold: f32, + sigmoid: bool, + _b: B, +} + +/// The [hamming score](HammingScore) input type. +#[derive(new)] +pub struct HammingScoreInput { + outputs: Tensor, + targets: Tensor, +} + +impl HammingScore { + /// Creates the metric. + pub fn new() -> Self { + Self::default() + } + + /// Sets the threshold. + pub fn with_threshold(mut self, threshold: f32) -> Self { + self.threshold = threshold; + self + } + + /// Sets the sigmoid activation function usage. + pub fn with_sigmoid(mut self, sigmoid: bool) -> Self { + self.sigmoid = sigmoid; + self + } + + /// Creates a new metric instance with default values. 
+ fn default() -> Self { + let mut instance: Self = Default::default(); + instance.threshold = 0.5; + instance + } +} + +impl Metric for HammingScore { + const NAME: &'static str = "Hamming Score"; + + type Input = HammingScoreInput; + + fn update(&mut self, input: &HammingScoreInput, _metadata: &MetricMetadata) -> MetricEntry { + let [batch_size, _n_classes] = input.outputs.dims(); + + let targets = input.targets.clone(); + + let mut outputs = input.outputs.clone(); + + if self.sigmoid { + outputs = sigmoid(outputs); + } + + let score = outputs + .greater_elem(self.threshold) + .equal(targets.bool()) + .float() + .mean() + .into_scalar() + .elem::(); + + self.state.update( + 100.0 * score, + batch_size, + FormatOptions::new(Self::NAME).unit("%").precision(2), + ) + } + + fn clear(&mut self) { + self.state.reset() + } +} + +impl Numeric for HammingScore { + fn value(&self) -> f64 { + self.state.value() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::TestBackend; + + #[test] + fn test_hamming_score() { + let device = Default::default(); + let mut metric = HammingScore::::new(); + + let x = Tensor::from_data( + [ + [0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1] + [0.43, 0.31, 0.21, 0.63, 0.53], // [0, 0, 0, 1, 1] + [0.44, 0.25, 0.71, 0.39, 0.73], // [0, 0, 1, 0, 1] + [0.49, 0.37, 0.68, 0.39, 0.31], // [0, 0, 1, 0, 0] + ], + &device, + ); + let y = Tensor::from_data( + [ + [0, 1, 0, 1, 1], + [0, 0, 0, 1, 1], + [0, 0, 1, 0, 1], + [0, 0, 1, 0, 0], + ], + &device, + ); + + let _entry = metric.update( + &HammingScoreInput::new(x.clone(), y.clone()), + &MetricMetadata::fake(), + ); + assert_eq!(100.0, metric.value()); + + // Invert all targets: y = (1 - y) + let y = y.neg().add_scalar(1); + let _entry = metric.update( + &HammingScoreInput::new(x.clone(), y), // invert targets (1 - y) + &MetricMetadata::fake(), + ); + assert_eq!(0.0, metric.value()); + + // Invert 5 target values -> 1 - (5/20) = 0.75 + let y = Tensor::from_data( + [ + [0, 1, 1, 0, 1], + [0, 0, 0, 0, 1], + [0, 0, 0, 0, 1], + [0, 1, 1, 0, 0], + ], + &device, + ); + let _entry = metric.update( + &HammingScoreInput::new(x, y), // invert targets (1 - y) + &MetricMetadata::fake(), + ); + assert_eq!(75.0, metric.value()); + } +} diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs index 37ad5af73b..7d443da067 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -9,6 +9,7 @@ mod cpu_temp; mod cpu_use; #[cfg(feature = "metrics")] mod cuda; +mod hamming; mod learning_rate; mod loss; #[cfg(feature = "metrics")] @@ -22,6 +23,7 @@ pub use cpu_temp::*; pub use cpu_use::*; #[cfg(feature = "metrics")] pub use cuda::*; +pub use hamming::*; pub use learning_rate::*; pub use loss::*; #[cfg(feature = "metrics")] From 858599b5b3e4cdaea998c9457d347c55d41f07af Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 3 Apr 2024 14:43:50 -0400 Subject: [PATCH 06/11] Add new_classification_with_items test --- .../burn-dataset/src/vision/image_folder.rs | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index 702abd852e..aa5d2fb55b 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -495,6 +495,27 @@ mod tests { assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1)); } + #[test] + pub fn image_folder_dataset_with_items() { + let root = 
Path::new(DATASET_ROOT); + let items = vec![ + (root.join("orange").join("dot.jpg"), "orange".to_string()), + (root.join("red").join("dot.jpg"), "red".to_string()), + (root.join("red").join("dot.png"), "red".to_string()), + ]; + let dataset = + ImageFolderDataset::new_classification_with_items(items, &["orange", "red"]).unwrap(); + + // Dataset has 3 elements + assert_eq!(dataset.len(), 3); + assert_eq!(dataset.get(3), None); + + // Dataset elements should be: orange (0), red (1), red (1) + assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0)); + assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1)); + assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1)); + } + #[test] pub fn image_folder_dataset_multilabel() { let root = Path::new(DATASET_ROOT); From 32c8ebf3627da00ed0ac1a856f5314abc831b69c Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 3 Apr 2024 15:10:06 -0400 Subject: [PATCH 07/11] Fix clippy suggestions --- crates/burn-dataset/src/vision/image_folder.rs | 2 +- crates/burn-train/src/metric/hamming.rs | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index aa5d2fb55b..75909bb90f 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -118,7 +118,7 @@ impl AnnotationRaw { fn decode(annotation: &[u8]) -> Self { let (annotation, _): (AnnotationRaw, usize) = - bincode::serde::decode_from_slice(&annotation, Self::bin_config()).unwrap(); + bincode::serde::decode_from_slice(annotation, Self::bin_config()).unwrap(); annotation } } diff --git a/crates/burn-train/src/metric/hamming.rs b/crates/burn-train/src/metric/hamming.rs index 79f0e807a8..2f0d72df43 100644 --- a/crates/burn-train/src/metric/hamming.rs +++ b/crates/burn-train/src/metric/hamming.rs @@ -39,9 +39,10 @@ impl HammingScore { /// Creates a new metric instance with default values. fn default() -> Self { - let mut instance: Self = Default::default(); - instance.threshold = 0.5; - instance + Self { + threshold: 0.5, + ..Default::default() + } } } From 87c9c6249e5efc3e881c9a55c127632990e7db3f Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 4 Apr 2024 09:41:43 -0400 Subject: [PATCH 08/11] Implement default trait for hamming score --- crates/burn-train/src/metric/hamming.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/burn-train/src/metric/hamming.rs b/crates/burn-train/src/metric/hamming.rs index 2f0d72df43..6ef2b9ff7b 100644 --- a/crates/burn-train/src/metric/hamming.rs +++ b/crates/burn-train/src/metric/hamming.rs @@ -4,7 +4,6 @@ use crate::metric::{Metric, Numeric}; use burn_core::tensor::{activation::sigmoid, backend::Backend, ElementConversion, Int, Tensor}; /// The hamming score, sometimes referred to as multi-label or label-based accuracy. -#[derive(Default)] pub struct HammingScore { state: NumericMetricState, threshold: f32, @@ -36,12 +35,16 @@ impl HammingScore { self.sigmoid = sigmoid; self } +} +impl Default for HammingScore { /// Creates a new metric instance with default values. 
fn default() -> Self { Self { + state: NumericMetricState::default(), threshold: 0.5, - ..Default::default() + sigmoid: false, + _b: B::default(), } } } From 311ad4597e273f296a7a777bbb9e7548c7f043c9 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 4 Apr 2024 09:52:36 -0400 Subject: [PATCH 09/11] Remove de/serialization and use AnnotationRaw as type --- crates/burn-dataset/Cargo.toml | 3 +- .../burn-dataset/src/vision/image_folder.rs | 34 +++++-------------- 2 files changed, 10 insertions(+), 27 deletions(-) diff --git a/crates/burn-dataset/Cargo.toml b/crates/burn-dataset/Cargo.toml index a3b5e8a9f8..45b46f302d 100644 --- a/crates/burn-dataset/Cargo.toml +++ b/crates/burn-dataset/Cargo.toml @@ -21,7 +21,7 @@ fake = ["dep:fake"] sqlite = ["__sqlite-shared", "dep:rusqlite"] sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"] -vision = ["dep:bincode", "dep:flate2", "dep:globwalk", "dep:burn-common"] +vision = ["dep:flate2", "dep:globwalk", "dep:burn-common"] # internal __sqlite-shared = [ @@ -33,7 +33,6 @@ __sqlite-shared = [ ] [dependencies] -bincode = { workspace = true, optional = true } burn-common = { path = "../burn-common", version = "0.13.0", optional = true, features = [ "network", ] } diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index 75909bb90f..14f3e3d4cd 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -107,37 +107,20 @@ enum AnnotationRaw { // TODO: bounding boxes and segmentation mask } -impl AnnotationRaw { - fn bin_config() -> bincode::config::Configuration { - bincode::config::standard() - } - - fn encode(&self) -> Vec { - bincode::serde::encode_to_vec(self, Self::bin_config()).unwrap() - } - - fn decode(annotation: &[u8]) -> Self { - let (annotation, _): (AnnotationRaw, usize) = - bincode::serde::decode_from_slice(annotation, Self::bin_config()).unwrap(); - annotation - } -} - #[derive(Deserialize, Serialize, Debug, Clone)] struct ImageDatasetItemRaw { /// Image path. image_path: PathBuf, /// Image annotation. - /// The annotation bytes can represent a string (category name) or path to annotation file. - annotation: Vec, + annotation: AnnotationRaw, } impl ImageDatasetItemRaw { fn new>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw { ImageDatasetItemRaw { image_path: image_path.as_ref().to_path_buf(), - annotation: annotation.encode(), + annotation: annotation, } } } @@ -147,17 +130,18 @@ struct PathToImageDatasetItem { } /// Parse the image annotation to the corresponding type. -fn parse_image_annotation(annotation: &[u8], classes: &HashMap) -> Annotation { +fn parse_image_annotation( + annotation: &AnnotationRaw, + classes: &HashMap, +) -> Annotation { // TODO: add support for other annotations // - [ ] Object bounding boxes // - [ ] Segmentation mask // For now, only image classification labels are supported. 
- let annotation = AnnotationRaw::decode(annotation); - // Map class string to label id match annotation { - AnnotationRaw::Label(name) => Annotation::Label(*classes.get(&name).unwrap()), + AnnotationRaw::Label(name) => Annotation::Label(*classes.get(name).unwrap()), AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel( names .iter() @@ -607,7 +591,7 @@ mod tests { #[test] pub fn parse_image_annotation_label_string() { let classes = HashMap::from([("0".to_string(), 0_usize), ("1".to_string(), 1_usize)]); - let anno = AnnotationRaw::Label("0".to_string()).encode(); + let anno = AnnotationRaw::Label("0".to_string()); assert_eq!( parse_image_annotation(&anno, &classes), Annotation::Label(0) @@ -621,7 +605,7 @@ mod tests { ("1".to_string(), 1_usize), ("2".to_string(), 2_usize), ]); - let anno = AnnotationRaw::MultiLabel(vec!["0".to_string(), "2".to_string()]).encode(); + let anno = AnnotationRaw::MultiLabel(vec!["0".to_string(), "2".to_string()]); assert_eq!( parse_image_annotation(&anno, &classes), Annotation::MultiLabel(vec![0, 2]) From b165044a7b55872d659f6df5757c86991401f7f1 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 4 Apr 2024 10:10:34 -0400 Subject: [PATCH 10/11] Fix clippy --- crates/burn-dataset/src/vision/image_folder.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs index 14f3e3d4cd..f850714f35 100644 --- a/crates/burn-dataset/src/vision/image_folder.rs +++ b/crates/burn-dataset/src/vision/image_folder.rs @@ -120,7 +120,7 @@ impl ImageDatasetItemRaw { fn new>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw { ImageDatasetItemRaw { image_path: image_path.as_ref().to_path_buf(), - annotation: annotation, + annotation, } } } From a63fa6e20255e5d60b51e4ee5836ff28d09cda98 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 4 Apr 2024 12:47:49 -0400 Subject: [PATCH 11/11] Fix metric backend phantom data --- crates/burn-train/src/metric/acc.rs | 4 +++- crates/burn-train/src/metric/hamming.rs | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/burn-train/src/metric/acc.rs b/crates/burn-train/src/metric/acc.rs index efdbf88cff..ad24fe077f 100644 --- a/crates/burn-train/src/metric/acc.rs +++ b/crates/burn-train/src/metric/acc.rs @@ -1,3 +1,5 @@ +use core::marker::PhantomData; + use super::state::{FormatOptions, NumericMetricState}; use super::{MetricEntry, MetricMetadata}; use crate::metric::{Metric, Numeric}; @@ -9,7 +11,7 @@ use burn_core::tensor::{ElementConversion, Int, Tensor}; pub struct AccuracyMetric { state: NumericMetricState, pad_token: Option, - _b: B, + _b: PhantomData, } /// The [accuracy metric](AccuracyMetric) input type. diff --git a/crates/burn-train/src/metric/hamming.rs b/crates/burn-train/src/metric/hamming.rs index 6ef2b9ff7b..6291833de7 100644 --- a/crates/burn-train/src/metric/hamming.rs +++ b/crates/burn-train/src/metric/hamming.rs @@ -1,3 +1,5 @@ +use core::marker::PhantomData; + use super::state::{FormatOptions, NumericMetricState}; use super::{MetricEntry, MetricMetadata}; use crate::metric::{Metric, Numeric}; @@ -8,7 +10,7 @@ pub struct HammingScore { state: NumericMetricState, threshold: f32, sigmoid: bool, - _b: B, + _b: PhantomData, } /// The [hamming score](HammingScore) input type. @@ -44,7 +46,7 @@ impl Default for HammingScore { state: NumericMetricState::default(), threshold: 0.5, sigmoid: false, - _b: B::default(), + _b: PhantomData, } } }
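
Usage sketch (not part of the patch series): taken together, these patches add a multi-label path through `ImageFolderDataset` and a matching `HammingScore` metric. The snippet below shows how the two pieces could be wired together under stated assumptions — the `NdArray` backend, the image paths, the class names, and the `burn::...` re-export paths (with the `train` and `vision` features enabled) are all assumptions, not something the patches themselves prescribe.

```rust
// Illustrative sketch only; paths and class names are hypothetical.
use burn::backend::NdArray;
use burn::data::dataset::vision::ImageFolderDataset;
use burn::train::metric::HammingScore;

fn main() {
    // Hypothetical multi-label items: each image path carries several class names.
    // The paths must point at real images for items to be loadable later.
    let items = vec![
        (
            "images/dot_orange.jpg".to_string(),
            vec!["dot".to_string(), "orange".to_string()],
        ),
        (
            "images/dot_red.jpg".to_string(),
            vec!["dot".to_string(), "red".to_string()],
        ),
    ];

    // Multi-label dataset from PATCH 01; class order is kept as given (PATCH 04),
    // so "dot" -> 0, "orange" -> 1, "red" -> 2, and fetching the first item would
    // yield Annotation::MultiLabel(vec![0, 1]).
    let _dataset = ImageFolderDataset::new_multilabel_classification_with_items(
        items,
        &["dot", "orange", "red"],
    )
    .unwrap();

    // Matching metric from PATCHES 05/08: apply sigmoid to the model outputs and
    // threshold at 0.5. In a Learner this would typically be registered against a
    // MultiLabelClassificationOutput via the usual metric builder methods.
    let _metric = HammingScore::<NdArray>::new()
        .with_sigmoid(true)
        .with_threshold(0.5);
}
```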