-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from tracel-ai/resnet/fine-tune
[ResNet] Add fine-tuning example
- Loading branch information
Showing
22 changed files
with
679 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,25 +1,26 @@ | ||
[package] | ||
authors = ["guillaumelagrange <[email protected]>"] | ||
license = "MIT OR Apache-2.0" | ||
name = "resnet-burn" | ||
version = "0.1.0" | ||
edition = "2021" | ||
[workspace] | ||
# Try | ||
# require version 2 to avoid "feature" additiveness for dev-dependencies | ||
# https://doc.rust-lang.org/cargo/reference/resolver.html#feature-resolver-version-2 | ||
resolver = "2" | ||
|
||
members = [ | ||
"resnet", | ||
"examples/*", | ||
] | ||
|
||
[features] | ||
default = [] | ||
std = [] | ||
pretrained = ["burn/network", "std", "dep:dirs"] | ||
[workspace.package] | ||
edition = "2021" | ||
version = "0.2.0" | ||
readme = "README.md" | ||
license = "MIT OR Apache-2.0" | ||
|
||
[dependencies] | ||
[workspace.dependencies] | ||
# Note: default-features = false is needed to disable std | ||
burn = { version = "0.13.0", default-features = false } | ||
burn-import = "0.13.0" | ||
dirs = { version = "5.0.1", optional = true } | ||
dirs = "5.0.1" | ||
serde = { version = "1.0.192", default-features = false, features = [ | ||
"derive", | ||
"alloc", | ||
] } # alloc is for no_std, derive is needed | ||
|
||
[dev-dependencies] | ||
burn = { version = "0.13.0", features = ["ndarray"] } | ||
image = { version = "0.24.9", features = ["png", "jpeg"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Downloaded files | ||
planet_sample/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
[package] | ||
authors = ["guillaumelagrange <[email protected]>"] | ||
name = "finetune" | ||
license.workspace = true | ||
version.workspace = true | ||
edition.workspace = true | ||
|
||
[features] | ||
default = ["burn/default"] | ||
tch-gpu = ["burn/tch"] | ||
wgpu = ["burn/wgpu"] | ||
|
||
[dependencies] | ||
resnet-burn = { path = "../../resnet", features = ["pretrained"] } | ||
burn = { workspace = true, features = ["train", "vision", "network"] } | ||
|
||
# Dataset files | ||
csv = "1.3.0" | ||
flate2 = "1.0.28" | ||
rand = { version = "0.8.5", default-features = false, features = [ | ||
"std_rng", | ||
] } | ||
serde = { version = "1.0.192", features = ["std", "derive"] } | ||
tar = "0.4.40" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
use burn::{backend::Autodiff, tensor::backend::Backend}; | ||
use finetune::{inference::infer, training::train}; | ||
|
||
#[allow(dead_code)] | ||
const ARTIFACT_DIR: &str = "/tmp/resnet-finetune"; | ||
|
||
#[allow(dead_code)] | ||
fn run<B: Backend>(device: B::Device) { | ||
train::<Autodiff<B>>(ARTIFACT_DIR, device.clone()); | ||
infer::<B>(ARTIFACT_DIR, device, 0.5); | ||
} | ||
|
||
#[cfg(feature = "tch-gpu")]
mod tch_gpu {
    use burn::backend::libtorch::{LibTorch, LibTorchDevice};

    /// Runs the fine-tuning example on the LibTorch GPU backend:
    /// MPS on macOS, CUDA device 0 everywhere else.
    pub fn run() {
        #[cfg(target_os = "macos")]
        let device = LibTorchDevice::Mps;
        #[cfg(not(target_os = "macos"))]
        let device = LibTorchDevice::Cuda(0);

        super::run::<LibTorch>(device);
    }
}
|
||
#[cfg(feature = "wgpu")]
mod wgpu {
    // Fix: `Wgpu` was previously imported twice — once via `backend::wgpu`
    // and once from the crate root — which is a compile error (E0252,
    // duplicate definition), and `burn::Wgpu` is not a crate-root item.
    use burn::backend::wgpu::{Wgpu, WgpuDevice};

    /// Runs the fine-tuning example on the wgpu backend with the default device.
    pub fn run() {
        super::run::<Wgpu>(WgpuDevice::default());
    }
}
|
||
/// Entry point: dispatches to whichever backend feature was enabled at build time.
fn main() {
    #[cfg(feature = "tch-gpu")]
    tch_gpu::run();
    #[cfg(feature = "wgpu")]
    wgpu::run();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
use burn::{ | ||
data::{ | ||
dataloader::batcher::Batcher, | ||
dataset::vision::{Annotation, ImageDatasetItem, PixelDepth}, | ||
}, | ||
prelude::*, | ||
}; | ||
|
||
use super::dataset::CLASSES; | ||
|
||
// ImageNet mean and std values | ||
const MEAN: [f32; 3] = [0.485, 0.456, 0.406]; | ||
const STD: [f32; 3] = [0.229, 0.224, 0.225]; | ||
|
||
// Planets patch size | ||
const WIDTH: usize = 256; | ||
const HEIGHT: usize = 256; | ||
|
||
/// Create a multi-hot encoded tensor. | ||
/// | ||
/// # Example | ||
/// | ||
/// ```rust, ignore | ||
/// let multi_hot = multi_hot::<B>(&[2, 5, 8], 10, &device); | ||
/// println!("{}", multi_hot.to_data()); | ||
/// // [0, 0, 1, 0, 0, 1, 0, 0, 1, 0] | ||
/// ``` | ||
pub fn multi_hot<B: Backend>( | ||
indices: &[usize], | ||
num_classes: usize, | ||
device: &B::Device, | ||
) -> Tensor<B, 1, Int> { | ||
Tensor::zeros(Shape::new([num_classes]), device).scatter( | ||
0, | ||
Tensor::from_ints( | ||
indices | ||
.iter() | ||
.map(|i| *i as i32) | ||
.collect::<Vec<_>>() | ||
.as_slice(), | ||
device, | ||
), | ||
Tensor::ones(Shape::new([indices.len()]), device), | ||
) | ||
} | ||
|
||
/// Normalizer with ImageNet values as it helps accelerate training since we are fine-tuning from | ||
/// ImageNet pre-trained weights and the model expects the data to be in this normalized range. | ||
#[derive(Clone)] | ||
pub struct Normalizer<B: Backend> { | ||
pub mean: Tensor<B, 4>, | ||
pub std: Tensor<B, 4>, | ||
} | ||
|
||
impl<B: Backend> Normalizer<B> { | ||
/// Creates a new normalizer. | ||
pub fn new(device: &Device<B>) -> Self { | ||
let mean = Tensor::from_floats(MEAN, device).reshape([1, 3, 1, 1]); | ||
let std = Tensor::from_floats(STD, device).reshape([1, 3, 1, 1]); | ||
Self { mean, std } | ||
} | ||
|
||
/// Normalizes the input image according to the ImageNet dataset. | ||
/// | ||
/// The input image should be in the range [0, 1]. | ||
/// The output image will be in the range [-1, 1]. | ||
/// | ||
/// The normalization is done according to the following formula: | ||
/// `input = (input - mean) / std` | ||
pub fn normalize(&self, input: Tensor<B, 4>) -> Tensor<B, 4> { | ||
(input - self.mean.clone()) / self.std.clone() | ||
} | ||
} | ||
|
||
#[derive(Clone)] | ||
pub struct ClassificationBatcher<B: Backend> { | ||
normalizer: Normalizer<B>, | ||
device: B::Device, | ||
} | ||
|
||
#[derive(Clone, Debug)] | ||
pub struct ClassificationBatch<B: Backend> { | ||
pub images: Tensor<B, 4>, | ||
pub targets: Tensor<B, 2, Int>, | ||
} | ||
|
||
impl<B: Backend> ClassificationBatcher<B> { | ||
pub fn new(device: B::Device) -> Self { | ||
Self { | ||
normalizer: Normalizer::<B>::new(&device), | ||
device, | ||
} | ||
} | ||
} | ||
|
||
impl<B: Backend> Batcher<ImageDatasetItem, ClassificationBatch<B>> for ClassificationBatcher<B> { | ||
fn batch(&self, items: Vec<ImageDatasetItem>) -> ClassificationBatch<B> { | ||
fn image_as_vec_u8(item: ImageDatasetItem) -> Vec<u8> { | ||
// Convert Vec<PixelDepth> to Vec<u8> (Planet images are u8) | ||
item.image | ||
.into_iter() | ||
.map(|p: PixelDepth| -> u8 { p.try_into().unwrap() }) | ||
.collect::<Vec<u8>>() | ||
} | ||
|
||
let targets = items | ||
.iter() | ||
.map(|item| { | ||
// Expect multi-hot encoded class labels as target (e.g., [0, 1, 0, 0, 1]) | ||
if let Annotation::MultiLabel(y) = &item.annotation { | ||
multi_hot(y, CLASSES.len(), &self.device) | ||
} else { | ||
panic!("Invalid target type") | ||
} | ||
}) | ||
.collect(); | ||
|
||
let images = items | ||
.into_iter() | ||
.map(|item| Data::new(image_as_vec_u8(item), Shape::new([HEIGHT, WIDTH, 3]))) | ||
.map(|data| Tensor::<B, 3>::from_data(data.convert(), &self.device).permute([2, 0, 1])) | ||
.map(|tensor| tensor / 255) // normalize between [0, 1] | ||
.collect(); | ||
|
||
let images = Tensor::stack(images, 0); | ||
let targets = Tensor::stack(targets, 0); | ||
|
||
let images = self.normalizer.normalize(images); | ||
|
||
ClassificationBatch { images, targets } | ||
} | ||
} |
Oops, something went wrong.