replace sequential impl
Arjun31415 committed Apr 18, 2024
1 parent 848c172 commit 2a8d223
Showing 3 changed files with 69 additions and 94 deletions.
26 changes: 7 additions & 19 deletions mobilenet-burn/src/model/conv_norm.rs
@@ -7,24 +7,22 @@ use burn::{
},
tensor::{self, backend::Backend, Tensor},
};
use serde::{Deserialize, Serialize};

#[derive(Module, Debug, Clone, Default)]
pub struct ReLU6 {}
impl ReLU6 {
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
tensor::activation::relu(input).clamp_max(6)
}
}

#[derive(Module, Debug)]
pub struct Conv2dNormActivation<B: Backend> {
conv: Conv2d<B>,
norm_layer: NormalizationLayer<B, 4>,
norm_layer: BatchNorm<B, 4>,
activation: ReLU6,
}
#[derive(Module, Debug)]
pub enum NormalizationLayer<B: Backend, const D: usize> {
BatchNorm(BatchNorm<B, D>),
}

#[derive(Config, Debug)]
pub struct Conv2dNormActivationConfig {
pub in_channels: usize,
@@ -48,20 +46,12 @@ pub struct Conv2dNormActivationConfig {
#[config(default = false)]
pub bias: bool,

pub norm_type: NormalizationType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NormalizationType {
BatchNorm(BatchNormConfig),
pub batch_norm_config: BatchNormConfig,
}

impl Conv2dNormActivationConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> Conv2dNormActivation<B> {
let norm_layer = match &self.norm_type {
NormalizationType::BatchNorm(config) => {
NormalizationLayer::BatchNorm(config.init(device))
}
};
let norm_layer = self.batch_norm_config.init(device);
Conv2dNormActivation {
conv: Conv2dConfig::new(
[self.in_channels, self.out_channels],
@@ -81,9 +71,7 @@ impl Conv2dNormActivationConfig {
impl<B: Backend> Conv2dNormActivation<B> {
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.conv.forward(input);
let x = match &self.norm_layer {
NormalizationLayer::BatchNorm(norm) => norm.forward(x),
};
let x = self.norm_layer.forward(x);
self.activation.forward(x)
}
}
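For reference, a minimal usage sketch of the simplified block (not part of the commit): the function name, channel sizes, and input shape are illustrative assumptions, the Conv2dNormActivation types above are assumed to be in scope, and the three-argument new(in_channels, out_channels, batch_norm_config) order follows the calls in inverted_residual.rs below.

use burn::nn::BatchNormConfig;
use burn::tensor::{backend::Backend, Tensor};

// Sketch: build a 3x3 conv + BatchNorm + ReLU6 block and run a forward pass.
fn demo_conv_block<B: Backend>(device: &B::Device) -> Tensor<B, 4> {
    let block: Conv2dNormActivation<B> =
        Conv2dNormActivationConfig::new(3, 32, BatchNormConfig::new(32))
            .with_kernel_size(3)
            .with_stride(2)
            .init(device);

    // Dummy NCHW input: [batch, channels, height, width].
    let input = Tensor::<B, 4>::zeros([1, 3, 224, 224], device);
    block.forward(input) // conv -> batch norm -> ReLU6
}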
90 changes: 42 additions & 48 deletions mobilenet-burn/src/model/inverted_residual.rs
@@ -1,25 +1,24 @@
use super::conv_norm::NormalizationLayer;
use super::conv_norm::NormalizationType;
use super::conv_norm::{Conv2dNormActivation, Conv2dNormActivationConfig};
use alloc::vec;
use alloc::vec::Vec;
use burn::config::Config;
use burn::nn::conv::Conv2dConfig;
use burn::nn::BatchNormConfig;
use burn::nn::{BatchNorm, BatchNormConfig};
use burn::tensor::Tensor;
use burn::{module::Module, nn::conv::Conv2d, tensor::backend::Backend};

#[derive(Module, Debug)]
enum InvertedResidualSequentialType<B: Backend> {
Conv(Conv2d<B>),
ConvNormActivation(Conv2dNormActivation<B>),
NormLayer(NormalizationLayer<B, 4>),
pub struct PointWiseLinear<B: Backend> {
conv: Conv2d<B>,
norm: BatchNorm<B, 4>,
}

/// Inverted Residual Block
/// Ref: https://paperswithcode.com/method/inverted-residual-block
#[derive(Module, Debug)]
pub struct InvertedResidual<B: Backend> {
use_res_connect: bool,
layers: Vec<InvertedResidualSequentialType<B>>,
pw: Option<Conv2dNormActivation<B>>, // pointwise, only when expand ratio != 1
dw: Conv2dNormActivation<B>,
pw_linear: PointWiseLinear<B>,
}

#[derive(Config, Debug)]
@@ -28,64 +27,59 @@ pub struct InvertedResidualConfig {
pub oup: usize,
pub stride: usize,
pub expand_ratio: usize,
pub norm_type: NormalizationType,
pub norm_type: BatchNormConfig,
}

impl InvertedResidualConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> InvertedResidual<B> {
let mut layers = Vec::new();
let hidden_dim = self.inp * self.expand_ratio;
if self.expand_ratio != 1 {
layers.push(InvertedResidualSequentialType::ConvNormActivation(
let pw = if self.expand_ratio != 1 {
Some(
Conv2dNormActivationConfig::new(self.inp, hidden_dim, self.norm_type.clone())
.with_kernel_size(1)
.init(device),
));
}
let mut temp_layer: Vec<InvertedResidualSequentialType<B>> = vec![
InvertedResidualSequentialType::ConvNormActivation(
Conv2dNormActivationConfig::new(hidden_dim, hidden_dim, self.norm_type.clone())
.with_stride(self.stride)
.with_groups(hidden_dim)
.init(device),
),
InvertedResidualSequentialType::Conv(
Conv2dConfig::new([hidden_dim, self.oup], [1, 1])
.with_stride([1, 1])
.with_padding(burn::nn::PaddingConfig2d::Explicit(0, 0))
.with_bias(false)
.init(device),
),
match self.norm_type {
NormalizationType::BatchNorm(_) => InvertedResidualSequentialType::NormLayer(
NormalizationLayer::BatchNorm(BatchNormConfig::new(self.oup).init(device)),
),
},
];
layers.append(&mut temp_layer);
)
} else {
None
};
let dw = Conv2dNormActivationConfig::new(hidden_dim, hidden_dim, self.norm_type.clone())
.with_stride(self.stride)
.with_groups(hidden_dim)
.init(device);
let pw_linear = PointWiseLinear {
conv: Conv2dConfig::new([hidden_dim, self.oup], [1, 1])
.with_stride([1, 1])
.with_padding(burn::nn::PaddingConfig2d::Explicit(0, 0))
.with_bias(false)
.init(device),
norm: BatchNormConfig::new(self.oup).init(device),
};
InvertedResidual {
use_res_connect: self.stride == 1 && self.inp == self.oup,
layers,
pw_linear,
dw,
pw,
}
}
}
impl<B: Backend> InvertedResidual<B> {
pub fn forward(&self, x: &Tensor<B, 4>) -> Tensor<B, 4> {
let mut out = x.clone();
for layer in &self.layers {
match layer {
InvertedResidualSequentialType::Conv(conv) => out = conv.forward(out),
InvertedResidualSequentialType::ConvNormActivation(conv_norm) => {
out = conv_norm.forward(out)
}
InvertedResidualSequentialType::NormLayer(NormalizationLayer::BatchNorm(x)) => {
out = x.forward(out)
}
}
if let Some(pw) = &self.pw {
out = pw.forward(out);
}
out = self.dw.forward(out);
out = self.pw_linear.forward(out);

if self.use_res_connect {
out = out + x.clone();
}
out
}
}

impl<B: Backend> PointWiseLinear<B> {
pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
self.norm.forward(self.conv.forward(x))
}
}
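With the sequential Vec of enum-wrapped layers gone, the block is three named stages: an optional pointwise expansion (pw, only when expand_ratio != 1), a depthwise Conv2dNormActivation (dw), and a pointwise-linear projection (pw_linear), with the residual added when stride == 1 and inp == oup. A hedged construction sketch follows (not part of the commit): the new(inp, oup, stride, expand_ratio, norm_type) argument order is assumed from the config's field order, the channel numbers are illustrative, and the passed BatchNormConfig is sized to the expanded width inp * expand_ratio because it is cloned into the pw and dw stages.

use burn::nn::BatchNormConfig;
use burn::tensor::{backend::Backend, Tensor};

// Sketch: one bottleneck block with expand_ratio = 6 and stride = 1,
// so the residual connection is active (stride == 1 and inp == oup).
fn demo_bottleneck<B: Backend>(device: &B::Device, x: Tensor<B, 4>) -> Tensor<B, 4> {
    let (inp, oup, stride, expand_ratio) = (32, 32, 1, 6);
    let norm = BatchNormConfig::new(inp * expand_ratio);
    let block: InvertedResidual<B> =
        InvertedResidualConfig::new(inp, oup, stride, expand_ratio, norm).init(device);

    // forward takes a reference so the input can be reused for the skip connection.
    block.forward(&x)
}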
47 changes: 20 additions & 27 deletions mobilenet-burn/src/model/mobilenet.rs
@@ -5,7 +5,7 @@ use burn::{
module::Module,
nn::{
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig,
BatchNormConfig, Dropout, DropoutConfig, Linear, LinearConfig,
},
tensor::{backend::Backend, Tensor},
};
@@ -24,15 +24,13 @@ use {
burn_import::pytorch::{LoadArgs, PyTorchFileRecorder},
};

use super::{
conv_norm::{Conv2dNormActivation, NormalizationType},
inverted_residual::InvertedResidual,
};
use super::{conv_norm::Conv2dNormActivation, inverted_residual::InvertedResidual};

#[derive(Debug, Module)]
pub struct MobileNetV2<B: Backend> {
features: Vec<ConvBlock<B>>,
classifier: Vec<ClassifierLayersType<B>>,
// classifier: Vec<ClassifierLayersType<B>>,
classifier: Classifier<B>,
avg_pool: AdaptiveAvgPool2d,
}

@@ -50,7 +48,7 @@ pub struct MobileNetV2Config {
#[config(default = "8")]
round_nearest: usize,

norm_layer: NormalizationType,
norm_layer: BatchNormConfig,

#[config(default = "0.2")]
dropout: f64,
@@ -62,10 +60,17 @@ enum ConvBlock<B: Backend> {
Conv(Conv2dNormActivation<B>),
}


#[derive(Module, Debug)]
enum ClassifierLayersType<B: Backend> {
Dropout(Dropout),
Linear(Linear<B>),
struct Classifier<B: Backend> {
dropout: Dropout,
linear: Linear<B>,
}
impl<B: Backend> Classifier<B> {
fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let x = self.dropout.forward(input);
self.linear.forward(x)
}
}

impl MobileNetV2Config {
@@ -128,12 +133,10 @@ impl MobileNetV2Config {
.init(device),
));

let classifier = vec![
ClassifierLayersType::Dropout(DropoutConfig::new(self.dropout).init()),
ClassifierLayersType::Linear(
LinearConfig::new(last_channel, self.num_classes).init(device),
),
];
let classifier = Classifier {
dropout: DropoutConfig::new(self.dropout).init(),
linear: LinearConfig::new(last_channel, self.num_classes).init(device),
};

MobileNetV2 {
features,
@@ -155,19 +158,9 @@ impl<B: Backend> MobileNetV2<B> {
}
}
}

x = self.avg_pool.forward(x);
x = x.flatten(1, 1);
for layer in &self.classifier {
match layer {
ClassifierLayersType::Dropout(dropout) => {
x = dropout.forward(x);
}
ClassifierLayersType::Linear(linear) => {
x = linear.forward(x);
}
}
}
x = self.classifier.forward(x);
x
}
}
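The classifier head is now a single Classifier module (dropout followed by linear) instead of a Vec of ClassifierLayersType variants, so the tail of forward collapses to one call. A small sketch of that head in isolation (not part of the commit): it assumes same-module access to the private Classifier struct, the 1280-feature / 1000-class sizes are assumed typical MobileNetV2 defaults, and the construction mirrors the config calls above.

use burn::nn::{DropoutConfig, LinearConfig};
use burn::tensor::{backend::Backend, Tensor};

// Sketch: build the new classifier head and apply it to pooled features.
fn demo_head<B: Backend>(device: &B::Device, pooled: Tensor<B, 4>) -> Tensor<B, 4> {
    let classifier = Classifier {
        dropout: DropoutConfig::new(0.2).init(),
        linear: LinearConfig::new(1280, 1000).init(device),
    };
    // dropout -> linear, as in Classifier::forward above.
    classifier.forward(pooled)
}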