[ResNet] Add fine-tuning example #29

Merged · 3 commits · May 7, 2024

Changes from all commits
33 changes: 17 additions & 16 deletions resnet-burn/Cargo.toml
@@ -1,25 +1,26 @@
-[package]
-authors = ["guillaumelagrange <[email protected]>"]
-license = "MIT OR Apache-2.0"
-name = "resnet-burn"
-version = "0.1.0"
-edition = "2021"
+[workspace]
+# Try
+# require version 2 to avoid "feature" additiveness for dev-dependencies
+# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2
+resolver = "2"

+members = [
+    "resnet",
+    "examples/*",
+]
+
-[features]
-default = []
-std = []
-pretrained = ["burn/network", "std", "dep:dirs"]
+[workspace.package]
+edition = "2021"
+version = "0.2.0"
+readme = "README.md"
+license = "MIT OR Apache-2.0"

-[dependencies]
+[workspace.dependencies]
 # Note: default-features = false is needed to disable std
 burn = { version = "0.13.0", default-features = false }
 burn-import = "0.13.0"
-dirs = { version = "5.0.1", optional = true }
+dirs = "5.0.1"
 serde = { version = "1.0.192", default-features = false, features = [
     "derive",
     "alloc",
 ] } # alloc is for no_std, derive is needed
-
-[dev-dependencies]
-burn = { version = "0.13.0", features = ["ndarray"] }
-image = { version = "0.24.9", features = ["png", "jpeg"] }
41 changes: 38 additions & 3 deletions resnet-burn/README.md
@@ -2,7 +2,7 @@

To this day, [ResNet](https://arxiv.org/abs/1512.03385)s are still a strong baseline for your image
classification tasks. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for
the ResNet variants in [resnet.rs](resnet/src/resnet.rs).

The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).
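As a sketch, a `no_std` consumer would pull the crate in without default features (feature names are assumptions based on the old manifest above; check the crate's Cargo.toml for the current gates):

```toml
[dependencies]
resnet-burn = { git = "https://github.com/tracel-ai/models", package = "resnet-burn", default-features = false }
```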

@@ -28,12 +28,47 @@ resnet-burn = { git = "https://github.com/tracel-ai/models", package = "resnet-b

### Example Usage

#### Inference

The [inference example](examples/inference/examples/inference.rs) initializes a ResNet-18 from the
ImageNet
[pre-trained weights](https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights)
with the `NdArray` backend and performs inference on the provided input image.

You can run the example with the following command:

```sh
cargo run --release --example inference samples/dog.jpg
```
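Under the hood, the example boils down to roughly the following sketch. This assumes the crate exposes the `ResNet` model and a `weights` enum as used elsewhere in this PR; image decoding and pre-processing are elided (see the example source):

```rust
use burn::backend::NdArray;
use resnet_burn::{weights, ResNet};

fn main() {
    let device = Default::default();

    // Download (on first use) and load the ImageNet pre-trained weights.
    let model: ResNet<NdArray> =
        ResNet::resnet18_pretrained(weights::ResNet18::ImageNet1kV1, &device)
            .expect("failed to load pre-trained weights");

    // `input` would be the decoded image as a [1, 3, 224, 224] float tensor,
    // normalized with the ImageNet statistics (elided here):
    // let output = model.forward(input);
    let _ = model;
}
```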

#### Fine-tuning

For this [multi-label image classification fine-tuning example](examples/finetune), a sample of the
planets dataset from the Kaggle competition
[Planet: Understanding the Amazon from Space](https://www.kaggle.com/c/planet-understanding-the-amazon-from-space)
is downloaded from a
[fastai mirror](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L55). The
sample dataset is a collection of satellite images with multiple labels describing the scene, as
illustrated below.

<img src="./samples/dataset.jpg" alt="Planet dataset sample" width="1000"/>

To achieve this task, a ResNet-18 pre-trained on the ImageNet dataset is fine-tuned on the target
planets dataset. The training recipe used is fairly simple. The main objective is to demonstrate how to re-use a
pre-trained model for a different downstream task.
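Concretely, the setup amounts to something like the sketch below. The `with_classes` helper and the label count are illustrative assumptions (the example's `CLASSES` constant defines the actual labels), not a documented API:

```rust
use burn::backend::{ndarray::NdArray, Autodiff};
use resnet_burn::{weights, ResNet};

// Autodiff wraps the backend so gradients can be computed during training.
type MyBackend = Autodiff<NdArray>;

fn main() {
    let device = Default::default();

    // Load the ImageNet weights, then swap the 1000-class head for a freshly
    // initialized output layer sized for the planet labels.
    let model: ResNet<MyBackend> =
        ResNet::resnet18_pretrained(weights::ResNet18::ImageNet1kV1, &device)
            .expect("failed to load pre-trained weights")
            .with_classes(17); // assumed helper; 17 labels in the full Kaggle set
    let _ = model;
}
```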

Without any bells and whistles, our model achieves over 90% multi-label accuracy (i.e., hamming
score) on the validation set after just 5 epochs.
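Here, hamming score means label-wise accuracy: the fraction of individual label decisions, after thresholding the sigmoid outputs, that match the targets. A minimal, framework-independent sketch of the computation:

```rust
/// Hamming score for multi-label classification: the fraction of
/// individual (sample, label) decisions that are correct.
fn hamming_score(predictions: &[Vec<bool>], targets: &[Vec<bool>]) -> f64 {
    let (mut correct, mut total) = (0usize, 0usize);
    for (pred, target) in predictions.iter().zip(targets) {
        correct += pred.iter().zip(target).filter(|(p, t)| p == t).count();
        total += pred.len();
    }
    correct as f64 / total as f64
}

fn main() {
    // Two samples, three labels each: 5 of the 6 decisions are correct.
    let preds = vec![vec![true, false, true], vec![false, false, true]];
    let targets = vec![vec![true, false, true], vec![false, true, true]];
    assert!((hamming_score(&preds, &targets) - 5.0 / 6.0).abs() < 1e-9);
}
```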

Run the example with the Torch GPU backend:

```sh
export TORCH_CUDA_VERSION=cu121
cargo run --release --example finetune --features tch-gpu
```

Run it with our WGPU backend:

```sh
cargo run --release --example finetune --features wgpu
```
2 changes: 2 additions & 0 deletions resnet-burn/examples/finetune/.gitignore
@@ -0,0 +1,2 @@
# Downloaded files
planet_sample/
24 changes: 24 additions & 0 deletions resnet-burn/examples/finetune/Cargo.toml
@@ -0,0 +1,24 @@
[package]
authors = ["guillaumelagrange <[email protected]>"]
name = "finetune"
license.workspace = true
version.workspace = true
edition.workspace = true

[features]
default = ["burn/default"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
resnet-burn = { path = "../../resnet", features = ["pretrained"] }
burn = { workspace = true, features = ["train", "vision", "network"] }

# Dataset files
csv = "1.3.0"
flate2 = "1.0.28"
rand = { version = "0.8.5", default-features = false, features = [
    "std_rng",
] }
serde = { version = "1.0.192", features = ["std", "derive"] }
tar = "0.4.40"
44 changes: 44 additions & 0 deletions resnet-burn/examples/finetune/examples/finetune.rs
@@ -0,0 +1,44 @@
use burn::{backend::Autodiff, tensor::backend::Backend};
use finetune::{inference::infer, training::train};

#[allow(dead_code)]
const ARTIFACT_DIR: &str = "/tmp/resnet-finetune";

#[allow(dead_code)]
fn run<B: Backend>(device: B::Device) {
    train::<Autodiff<B>>(ARTIFACT_DIR, device.clone());
    infer::<B>(ARTIFACT_DIR, device, 0.5);
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn::backend::libtorch::{LibTorch, LibTorchDevice};

    pub fn run() {
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;

        super::run::<LibTorch>(device);
    }
}

#[cfg(feature = "wgpu")]
mod wgpu {
    use burn::backend::wgpu::{Wgpu, WgpuDevice};

    pub fn run() {
        super::run::<Wgpu>(WgpuDevice::default());
    }
}

fn main() {
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "wgpu")]
    wgpu::run();
}
132 changes: 132 additions & 0 deletions resnet-burn/examples/finetune/src/data.rs
@@ -0,0 +1,132 @@
use burn::{
    data::{
        dataloader::batcher::Batcher,
        dataset::vision::{Annotation, ImageDatasetItem, PixelDepth},
    },
    prelude::*,
};

use super::dataset::CLASSES;

// ImageNet mean and std values
const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const STD: [f32; 3] = [0.229, 0.224, 0.225];

// Planets patch size
const WIDTH: usize = 256;
const HEIGHT: usize = 256;

/// Create a multi-hot encoded tensor.
///
/// # Example
///
/// ```rust, ignore
/// let multi_hot = multi_hot::<B>(&[2, 5, 8], 10, &device);
/// println!("{}", multi_hot.to_data());
/// // [0, 0, 1, 0, 0, 1, 0, 0, 1, 0]
/// ```
pub fn multi_hot<B: Backend>(
    indices: &[usize],
    num_classes: usize,
    device: &B::Device,
) -> Tensor<B, 1, Int> {
    Tensor::zeros(Shape::new([num_classes]), device).scatter(
        0,
        Tensor::from_ints(
            indices
                .iter()
                .map(|i| *i as i32)
                .collect::<Vec<_>>()
                .as_slice(),
            device,
        ),
        Tensor::ones(Shape::new([indices.len()]), device),
    )
}

/// Normalizer using the ImageNet mean and standard deviation. Normalizing with the same
/// statistics as the pre-training data helps accelerate fine-tuning, since the pre-trained
/// weights expect inputs in this normalized range.
#[derive(Clone)]
pub struct Normalizer<B: Backend> {
    pub mean: Tensor<B, 4>,
    pub std: Tensor<B, 4>,
}

impl<B: Backend> Normalizer<B> {
    /// Creates a new normalizer.
    pub fn new(device: &Device<B>) -> Self {
        let mean = Tensor::from_floats(MEAN, device).reshape([1, 3, 1, 1]);
        let std = Tensor::from_floats(STD, device).reshape([1, 3, 1, 1]);
        Self { mean, std }
    }

    /// Normalizes the input image according to the ImageNet dataset.
    ///
    /// The input image should be in the range [0, 1].
    /// The output is standardized per channel (approximately within [-2.1, 2.6]).
    ///
    /// The normalization is done according to the following formula:
    /// `input = (input - mean) / std`
    pub fn normalize(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        (input - self.mean.clone()) / self.std.clone()
    }
}

#[derive(Clone)]
pub struct ClassificationBatcher<B: Backend> {
    normalizer: Normalizer<B>,
    device: B::Device,
}

#[derive(Clone, Debug)]
pub struct ClassificationBatch<B: Backend> {
    pub images: Tensor<B, 4>,
    pub targets: Tensor<B, 2, Int>,
}

impl<B: Backend> ClassificationBatcher<B> {
    pub fn new(device: B::Device) -> Self {
        Self {
            normalizer: Normalizer::<B>::new(&device),
            device,
        }
    }
}

impl<B: Backend> Batcher<ImageDatasetItem, ClassificationBatch<B>> for ClassificationBatcher<B> {
    fn batch(&self, items: Vec<ImageDatasetItem>) -> ClassificationBatch<B> {
        fn image_as_vec_u8(item: ImageDatasetItem) -> Vec<u8> {
            // Convert Vec<PixelDepth> to Vec<u8> (Planet images are u8)
            item.image
                .into_iter()
                .map(|p: PixelDepth| -> u8 { p.try_into().unwrap() })
                .collect::<Vec<u8>>()
        }

        let targets = items
            .iter()
            .map(|item| {
                // Expect multi-hot encoded class labels as target (e.g., [0, 1, 0, 0, 1])
                if let Annotation::MultiLabel(y) = &item.annotation {
                    multi_hot(y, CLASSES.len(), &self.device)
                } else {
                    panic!("Invalid target type")
                }
            })
            .collect();

        let images = items
            .into_iter()
            .map(|item| Data::new(image_as_vec_u8(item), Shape::new([HEIGHT, WIDTH, 3])))
            .map(|data| Tensor::<B, 3>::from_data(data.convert(), &self.device).permute([2, 0, 1]))
            .map(|tensor| tensor / 255) // normalize between [0, 1]
            .collect();

        let images = Tensor::stack(images, 0);
        let targets = Tensor::stack(targets, 0);

        let images = self.normalizer.normalize(images);

        ClassificationBatch { images, targets }
    }
}
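For context, a batcher like this is typically handed to Burn's `DataLoaderBuilder`. The sketch below shows the wiring, assuming the types above are in scope; the batch size, seed, and worker count are illustrative values, not taken from this PR:

```rust
use std::sync::Arc;

use burn::{
    data::{
        dataloader::{DataLoader, DataLoaderBuilder},
        dataset::{vision::ImageDatasetItem, Dataset},
    },
    prelude::*,
};

pub fn build_dataloader<B: Backend>(
    device: B::Device,
    dataset: impl Dataset<ImageDatasetItem> + 'static,
) -> Arc<dyn DataLoader<ClassificationBatch<B>>> {
    let batcher = ClassificationBatcher::<B>::new(device);
    DataLoaderBuilder::new(batcher)
        .batch_size(32) // illustrative batch size
        .shuffle(42) // seed for deterministic shuffling
        .num_workers(4)
        .build(dataset)
}
```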