Showing 20 changed files with 1,853 additions and 304 deletions.

@@ -0,0 +1,25 @@
[package]
authors = ["Arjun31415", "guillaumelagrange <[email protected]>"]
license = "MIT OR Apache-2.0"
name = "mobilenetv2-burn"
version = "0.1.0"
edition = "2021"

[features]
default = []
std = []
pretrained = ["burn/network", "std", "dep:dirs"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { version = "0.13.0" }
burn-import = { version = "0.13.0" }
dirs = { version = "5.0.1", optional = true }
serde = { version = "1.0.192", default-features = false, features = [
    "derive",
    "alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
burn = { version = "0.13.0", features = ["ndarray"] }
image = { version = "0.24.9", features = ["png", "jpeg"] }

@@ -0,0 +1,16 @@
# NOTICES AND INFORMATION

This file contains notices and information required by the libraries and resources that this repository copied from or derived from. The use of the following resources complies with the licenses provided.

## Sample Image

Image Title: Standing yellow Labrador Retriever dog.
Author: Djmirko
Source: https://commons.wikimedia.org/wiki/File:YellowLabradorLooking_new.jpg
License: https://creativecommons.org/licenses/by-sa/3.0/

## Pre-trained Model

The ImageNet pre-trained model was ported from [`torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V2`](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights).

Unlike [other pre-trained models](https://pytorch.org/vision/stable/models/generated/torchvision.models.regnet_y_128gf.html#torchvision.models.RegNet_Y_128GF_Weights) in `torchvision`, no specific license is linked to these weights, which are assumed to fall under the library's [BSD-3-Clause license](https://github.com/pytorch/vision/blob/main/LICENSE) ([ref](https://github.com/pytorch/vision/issues/160)).

@@ -0,0 +1,40 @@
# MobileNetV2 Burn

[MobileNetV2](https://arxiv.org/abs/1801.04381) is a convolutional neural network architecture for
classification tasks, designed to perform well on mobile devices. You can find the
[Burn](https://github.com/tracel-ai/burn) implementation of MobileNetV2 in
[src/model/mobilenetv2.rs](src/model/mobilenetv2.rs).

The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).

## Usage

### `Cargo.toml`

Add this to your `Cargo.toml`:

```toml
[dependencies]
mobilenetv2-burn = { git = "https://github.com/tracel-ai/models", package = "mobilenetv2-burn", default-features = false }
```

To use the pre-trained ImageNet weights, enable the `pretrained` feature flag.

```toml
[dependencies]
mobilenetv2-burn = { git = "https://github.com/tracel-ai/models", package = "mobilenetv2-burn", features = ["pretrained"] }
```

**Important:** this feature requires `std`.

### Example Usage

The [inference example](examples/inference.rs) initializes a MobileNetV2 from the ImageNet
[pre-trained weights](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html#torchvision.models.MobileNet_V2_Weights)
with the `NdArray` backend and performs inference on the provided input image.

You can run the example with the following command:

```sh
cargo run --release --features pretrained --example inference samples/dog.jpg
```

@@ -0,0 +1,69 @@
use mobilenetv2_burn::model::{imagenet, mobilenetv2::MobileNetV2, weights};

use burn::{
    backend::NdArray,
    tensor::{backend::Backend, Data, Device, Element, Shape, Tensor},
};

const HEIGHT: usize = 224;
const WIDTH: usize = 224;

fn to_tensor<B: Backend, T: Element>(
    data: Vec<T>,
    shape: [usize; 3],
    device: &Device<B>,
) -> Tensor<B, 3> {
    Tensor::<B, 3>::from_data(Data::new(data, Shape::new(shape)).convert(), device)
        // [H, W, C] -> [C, H, W]
        .permute([2, 0, 1])
        / 255 // normalize between [0, 1]
}

pub fn main() {
    // Parse arguments
    let img_path = std::env::args().nth(1).expect("No image path provided");

    // Create MobileNetV2
    let device = Default::default();
    let model: MobileNetV2<NdArray> =
        MobileNetV2::pretrained(weights::MobileNetV2::ImageNet1kV2, &device)
            .map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}"))
            .unwrap();

    // Load image
    let img = image::open(&img_path)
        .map_err(|err| format!("Failed to load image {img_path}.\nError: {err}"))
        .unwrap();

    // Resize to 224x224
    let resized_img = img.resize_exact(
        WIDTH as u32,
        HEIGHT as u32,
        image::imageops::FilterType::Triangle, // also known as bilinear in 2D
    );

    // Create tensor from image data
    let img_tensor = to_tensor(
        resized_img.into_rgb8().into_raw(),
        [HEIGHT, WIDTH, 3],
        &device,
    )
    .unsqueeze::<4>(); // [B, C, H, W]

    // Normalize the image
    let x = imagenet::Normalizer::new(&device).normalize(img_tensor);

    // Forward pass
    let out = model.forward(x);

    // Output class index w/ score (raw)
    let (score, idx) = out.max_dim_with_indices(1);
    let idx = idx.into_scalar() as usize;

    println!(
        "Predicted: {}\nCategory Id: {}\nScore: {:.4}",
        imagenet::CLASSES[idx],
        idx,
        score.into_scalar()
    );
}

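As an aside, the `imagenet::Normalizer` used above presumably applies per-channel ImageNet statistics; the usual torchvision values (mean `[0.485, 0.456, 0.406]`, std `[0.229, 0.224, 0.225]`) are an assumption here, not taken from this commit. A minimal sketch of that kind of normalization with plain Burn tensor ops:

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Normalize a [B, C, H, W] image tensor with per-channel mean/std.
/// The constants are the usual torchvision ImageNet values and are assumed,
/// not copied from this repository's `imagenet::Normalizer`.
fn normalize<B: Backend>(input: Tensor<B, 4>, device: &B::Device) -> Tensor<B, 4> {
    let mean = Tensor::<B, 1>::from_floats([0.485, 0.456, 0.406], device).reshape([1, 3, 1, 1]);
    let std = Tensor::<B, 1>::from_floats([0.229, 0.224, 0.225], device).reshape([1, 3, 1, 1]);
    // The [1, 3, 1, 1] statistics broadcast over the batch and spatial dimensions.
    (input - mean) / std
}
```
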
@@ -0,0 +1,3 @@
#![cfg_attr(not(feature = "std"), no_std)]
pub mod model;
extern crate alloc;

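The crate root above switches to `no_std` when the `std` feature is disabled and pulls in `alloc` so heap-allocated types stay available. As a rough, hypothetical illustration of that pattern (none of these names come from the repository), a module inside such a crate takes its collections from `alloc` rather than `std`:

```rust
// Hypothetical module inside a crate whose root declares
// `#![cfg_attr(not(feature = "std"), no_std)]` and `extern crate alloc;`.
use alloc::vec::Vec;

/// Per-stage output channels (illustrative values only).
pub fn stage_channels() -> Vec<usize> {
    // `Vec` comes from `alloc`, so this compiles with or without the `std` feature.
    [16usize, 24, 32, 64].to_vec()
}
```
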
@@ -0,0 +1,83 @@
use burn::{
    config::Config,
    module::Module,
    nn::{
        conv::{Conv2d, Conv2dConfig},
        BatchNorm, BatchNormConfig, PaddingConfig2d,
    },
    tensor::{self, backend::Backend, Tensor},
};

/// A rectified linear unit where the activation is limited to a maximum of 6.
#[derive(Module, Debug, Clone, Default)]
pub struct ReLU6 {}
impl ReLU6 {
    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        tensor::activation::relu(input).clamp_max(6)
    }
}

/// A Conv2d -> BatchNorm -> activation block.
#[derive(Module, Debug)]
pub struct Conv2dNormActivation<B: Backend> {
    conv: Conv2d<B>,
    norm: BatchNorm<B, 2>,
    activation: ReLU6,
}

/// [Conv2dNormActivation] configuration.
#[derive(Config, Debug)]
pub struct Conv2dNormActivationConfig {
    pub in_channels: usize,
    pub out_channels: usize,

    #[config(default = "3")]
    pub kernel_size: usize,

    #[config(default = "1")]
    pub stride: usize,

    #[config(default = "None")]
    pub padding: Option<usize>,

    #[config(default = "1")]
    pub groups: usize,

    #[config(default = "1")]
    pub dilation: usize,

    #[config(default = false)]
    pub bias: bool,
}

impl Conv2dNormActivationConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2dNormActivation<B> {
        let padding = if let Some(padding) = self.padding {
            padding
        } else {
            (self.kernel_size - 1) / 2 * self.dilation
        };

        Conv2dNormActivation {
            conv: Conv2dConfig::new(
                [self.in_channels, self.out_channels],
                [self.kernel_size, self.kernel_size],
            )
            .with_padding(PaddingConfig2d::Explicit(padding, padding))
            .with_stride([self.stride, self.stride])
            .with_bias(self.bias)
            .with_dilation([self.dilation, self.dilation])
            .with_groups(self.groups)
            .init(device),
            norm: BatchNormConfig::new(self.out_channels).init(device),
            activation: ReLU6 {},
        }
    }
}
impl<B: Backend> Conv2dNormActivation<B> {
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let x = self.conv.forward(input);
        let x = self.norm.forward(x);
        self.activation.forward(x)
    }
}

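For context, here is a minimal, hypothetical usage sketch of the block defined above; it is not part of this commit. The 3-to-32, stride-2 configuration happens to match the MobileNetV2 stem convolution, and the `NdArray` backend is used only because the dev-dependencies already enable it:

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn demo() {
    let device = Default::default();

    // 3 input channels -> 32 output channels with stride 2; kernel size,
    // dilation and groups keep their defaults, so padding resolves to 1.
    let block = Conv2dNormActivationConfig::new(3, 32)
        .with_stride(2)
        .init::<NdArray>(&device);

    // One dummy 224x224 RGB image in [B, C, H, W] layout.
    let x = Tensor::<NdArray, 4>::zeros([1, 3, 224, 224], &device);
    let y = block.forward(x);

    // With padding 1 and stride 2, the spatial size is halved.
    assert_eq!(y.dims(), [1, 32, 112, 112]);
}
```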