Skip to content

Commit

Permalink
Format.
Browse files Browse the repository at this point in the history
  • Loading branch information
twitchax committed Dec 21, 2023
1 parent e6c1777 commit 8121669
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 55 deletions.
4 changes: 3 additions & 1 deletion src/bin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ fn start(args: Args) -> Void {
klib::ml::train::run_training::<Autodiff<NdArray<f32>>>(device, &config, true, true)?;
}
_ => {
return Err(anyhow::Error::msg("Invalid device (must choose either `gpu` [requires `ml_gpu` feature], `wgpu` [requires `ml_gpu` feature] or `cpu`)."));
return Err(anyhow::Error::msg(
"Invalid device (must choose either `gpu` [requires `ml_gpu` feature], `wgpu` [requires `ml_gpu` feature] or `cpu`).",
));
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/core/pitch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ pub trait HasFrequency {
/// Essentially, midway between the frequency and the next frequency on either side.
fn frequency_range(&self) -> (f32, f32) {
    let center = self.frequency();

    // Going by this method's doc comment, `1.0 / 17.462` and `1.0 / 16.8196` act as the
    // relative gaps to the next frequency below and above; halving each gap places the
    // bound midway to that neighbor.
    // NOTE(review): the origin of the two constants is not visible here — confirm.
    let lower = center * (1.0 - 1.0 / 17.462 / 2.0);
    let upper = center * (1.0 + 1.0 / 16.8196 / 2.0);

    (lower, upper)
}

/// Returns the tight frequency range of the type (usually a [`Note`]).
///
/// The result is a `(low, high)` pair around [`HasFrequency::frequency`]; each bound sits
/// one eighth of the way from the frequency toward the next frequency on that side.
fn tight_frequency_range(&self) -> (f32, f32) {
    let center = self.frequency();

    // `1.0 / 17.462` and `1.0 / 16.8196` are used as the relative gaps to the next frequency
    // below and above; dividing by 8.0 keeps only the innermost eighth of each gap.
    // NOTE(review): the origin of the two constants is not visible here — confirm.
    let lower = center * (1.0 - 1.0 / 17.462 / 8.0);
    let upper = center * (1.0 + 1.0 / 16.8196 / 8.0);

    (lower, upper)
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/ml/base/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use burn::tensor::{backend::Backend, Data, Tensor};

use super::{
helpers::{get_deterministic_guess, mel_filter_banks_from, u128_to_binary, note_binned_convolution},
helpers::{get_deterministic_guess, mel_filter_banks_from, note_binned_convolution, u128_to_binary},
KordItem, INPUT_SPACE_SIZE, NUM_CLASSES,
};

Expand Down
8 changes: 4 additions & 4 deletions src/ml/base/gather.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

use std::path::Path;

use crate::core::{
base::{Parsable, Void},
note::{HasNoteId, Note},
};
use crate::{
analyze::base::{get_frequency_space, get_smoothed_frequency_space},
ml::base::{KordItem, FREQUENCY_SPACE_SIZE},
};
use crate::core::{
base::{Parsable, Void},
note::{HasNoteId, Note},
};

use crate::analyze::mic::get_audio_data_from_microphone;

Expand Down
20 changes: 9 additions & 11 deletions src/ml/base/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@ use std::{
path::{Path, PathBuf},
};

use burn::{tensor::{backend::Backend, Tensor}, module::Module};
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};

use crate::{
analyze::base::get_notes_from_smoothed_frequency_space,
core::{
base::Res,
helpers::{inv_mel, mel},
note::{HasNoteId, Note, ALL_PITCH_NOTES_WITH_FREQUENCY}, pitch::HasFrequency,
note::{HasNoteId, Note, ALL_PITCH_NOTES_WITH_FREQUENCY},
pitch::HasFrequency,
},
};

Expand Down Expand Up @@ -113,7 +117,7 @@ pub fn note_binned_convolution(spectrum: &[f32]) -> [f32; NUM_CLASSES] {

for (note, _) in ALL_PITCH_NOTES_WITH_FREQUENCY.iter().skip(7).take(90) {
let id_index = note.id_index();

let (low, high) = note.tight_frequency_range();
let low = low.round() as usize;
let high = high.round() as usize;
Expand All @@ -137,13 +141,7 @@ pub fn note_binned_convolution(spectrum: &[f32]) -> [f32; NUM_CLASSES] {
pub fn harmonic_convolution(spectrum: &[f32]) -> [f32; FREQUENCY_SPACE_SIZE] {
let mut harmonic_convolution = [0f32; FREQUENCY_SPACE_SIZE];

let (peak, _) = spectrum.iter().enumerate().fold((0usize, 0f32), |(k, max), (j, x)| {
if *x > max {
(j, *x)
} else {
(k, max)
}
});
let (peak, _) = spectrum.iter().enumerate().fold((0usize, 0f32), |(k, max), (j, x)| if *x > max { (j, *x) } else { (k, max) });

for center in (peak / 2)..4000 {
let mut sum = spectrum[center];
Expand Down Expand Up @@ -239,4 +237,4 @@ impl<B: Backend> Sigmoid<B> {
//let scaled = input;
scaled.clone().exp().div(scaled.exp().add_scalar(1.0))
}
}
}
12 changes: 4 additions & 8 deletions src/ml/base/mlp.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Multilayer Perceptron module.

use burn::{
nn::{self, LayerNormConfig, LayerNorm},
tensor::{backend::Backend, Tensor}, module::Module,
module::Module,
nn::{self, LayerNorm, LayerNormConfig},
tensor::{backend::Backend, Tensor},
};

/// Multilayer Perceptron module.
Expand All @@ -28,12 +29,7 @@ impl<B: Backend> Mlp<B> {
let dropout = nn::DropoutConfig::new(mlp_dropout).init();
let activation = nn::ReLU::new();

Self {
linears,
norm,
dropout,
activation
}
Self { linears, norm, dropout, activation }
}

/// Applies the forward pass on the input tensor.
Expand Down
27 changes: 12 additions & 15 deletions src/ml/base/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,21 @@
use core::f32;

use burn::{
nn::{self, attention::{MultiHeadAttentionConfig, MultiHeadAttention, MhaInput}},
tensor::{backend::Backend, Tensor}, module::Module,
module::Module,
nn::{
self,
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
},
tensor::{backend::Backend, Tensor},
};

use super::{helpers::Sigmoid, INPUT_SPACE_SIZE, NUM_CLASSES};

#[cfg(feature = "ml_train")]
use crate::ml::train::{
data::KordBatch,
helpers::KordClassificationOutput,
};
use crate::ml::train::{data::KordBatch, helpers::KordClassificationOutput};

/// The Kord model.
///
///
/// This model is a transformer model that uses multi-head attention to classify notes from a frequency space.
#[derive(Module, Debug)]
pub struct KordModel<B: Backend> {
Expand All @@ -32,11 +33,7 @@ impl<B: Backend> KordModel<B> {
let output = nn::LinearConfig::new(INPUT_SPACE_SIZE, NUM_CLASSES).init::<B>();
let sigmoid = Sigmoid::new(sigmoid_strength);

Self {
mha,
output,
sigmoid,
}
Self { mha, output, sigmoid }
}

/// Applies the forward pass on the input tensor.
Expand All @@ -50,14 +47,14 @@ impl<B: Backend> KordModel<B> {

// Reshape the output to remove the sequence dimension.
let mut x = attn.context.reshape([batch_size, input_size]);

// Perform the final linear layer to map to the output dimensions.
x = self.output.forward(x);

// Apply the sigmoid function to the output to achieve multi-classification.
x = self.sigmoid.forward(x);

x
x
}

/// Applies the forward classification pass on the input tensor.
Expand Down Expand Up @@ -90,4 +87,4 @@ impl<B: Backend> KordModel<B> {

KordClassificationOutput { loss, output, targets }
}
}
}
3 changes: 2 additions & 1 deletion src/ml/infer/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
use burn::{
config::Config,
module::Module,
tensor::backend::Backend, record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
record::{BinBytesRecorder, FullPrecisionSettings, Recorder},
tensor::backend::Backend,
};
use burn_ndarray::{NdArray, NdArrayDevice};
use serde::{de::DeserializeOwned, Serialize};
Expand Down
12 changes: 7 additions & 5 deletions src/ml/train/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
use std::sync::Arc;

use burn::{
backend::Autodiff,
config::Config,
data::dataloader::DataLoaderBuilder,
lr_scheduler::constant::ConstantLr,
module::Module,
optim::{decay::WeightDecayConfig, AdamConfig},
backend::Autodiff,
record::{BinFileRecorder, FullPrecisionSettings, Recorder},
tensor::backend::{AutodiffBackend, Backend},
train::{metric::LossMetric, LearnerBuilder}, lr_scheduler::constant::ConstantLr, record::{BinFileRecorder, FullPrecisionSettings, Recorder},
train::{metric::LossMetric, LearnerBuilder},
};
use serde::{de::DeserializeOwned, Serialize};

Expand All @@ -31,8 +33,8 @@ use super::{
use crate::ml::base::TrainConfig;

/// Run the training.
///
/// Given the [`TrainConfig`], this function will run the training and return the overall accuracy on
///
/// Given the [`TrainConfig`], this function will run the training and return the overall accuracy on
/// the validation / test set.
pub fn run_training<B: AutodiffBackend>(device: B::Device, config: &TrainConfig, print_accuracy_report: bool, save_model: bool) -> Res<f32>
where
Expand Down Expand Up @@ -157,7 +159,7 @@ pub fn compute_overall_accuracy<B: Backend>(model_trained: &KordModel<B>, device
}

/// Run hyper parameter tuning.
///
///
/// This method sweeps through the hyper parameters and runs training for each combination. The best
/// hyper parameters are then printed at the end.
#[coverage(off)]
Expand Down
9 changes: 2 additions & 7 deletions src/ml/train/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ use burn::{
},
train::{
metric::{
MetricMetadata,
state::{FormatOptions, NumericMetricState},
Adaptor, LossInput, Metric, MetricEntry, Numeric,
Adaptor, LossInput, Metric, MetricEntry, MetricMetadata, Numeric,
},
TrainOutput, TrainStep, ValidStep,
},
Expand All @@ -24,11 +23,7 @@ use crate::{
note::{HasNoteId, Note, ALL_PITCH_NOTES},
pitch::HasFrequency,
},
ml::base::{
helpers::load_kord_item,
model::KordModel,
KordItem, FREQUENCY_SPACE_SIZE, NUM_CLASSES,
},
ml::base::{helpers::load_kord_item, model::KordModel, KordItem, FREQUENCY_SPACE_SIZE, NUM_CLASSES},
};

use super::data::KordBatch;
Expand Down

0 comments on commit 8121669

Please sign in to comment.