Skip to content

Commit

Permalink
Merge pull request #69 from jkawamoto/lint
Browse files Browse the repository at this point in the history
Fix lint errors
  • Loading branch information
jkawamoto authored Jul 14, 2024
2 parents e2b74a3 + c14f04e commit e3587aa
Show file tree
Hide file tree
Showing 18 changed files with 121 additions and 132 deletions.
1 change: 0 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ jobs:
- uses: actions/checkout@v4
with:
submodules: recursive
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- name: Build
run: cargo build -vv --no-default-features -F "${{ matrix.feature }}"
Expand Down
10 changes: 8 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.6
rev: v18.1.8
hooks:
- id: clang-format
types_or: [c++, c, cuda]
- repo: https://github.com/google/yamlfmt
rev: v0.12.1
rev: v0.13.0
hooks:
- id: yamlfmt
- repo: local
Expand All @@ -25,3 +25,9 @@ repos:
language: system
files: \.(toml|md)$
types: [text]
- id: rustfmt
name: run rustfmt
entry: cargo fmt --
language: system
files: \.rs$
types: [text]
2 changes: 1 addition & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ fn link_libraries<T: AsRef<Path>>(root: T) {
.iter()
.for_each(|name| {
let parent = path.parent();
if parent != current_dir.as_ref().map(|p: &PathBuf| p.as_path()) {
if parent != current_dir.as_deref() {
let dir = parent.unwrap();
println!("cargo:rustc-link-search={}", dir.display());
current_dir = Some(dir.to_path_buf())
Expand Down
22 changes: 10 additions & 12 deletions examples/bart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,18 @@ fn main() -> Result<()> {

let t = Translator::new(&args.path, &cfg)?;

let source =
BufReader::new(File::open(args.prompt)?)
.lines()
.fold(Ok(String::new()), |acc, line| {
acc.and_then(|mut acc| {
line.map(|l| {
acc.push_str(&l);
acc
})
})
})?;
let source = BufReader::new(File::open(args.prompt)?).lines().try_fold(
String::new(),
|mut acc, line| {
line.map(|l| {
acc.push_str(&l);
acc
})
},
)?;

let now = time::Instant::now();
let res = t.translate_batch(&vec![source], &Default::default(), None)?;
let res = t.translate_batch(&[source], &Default::default(), None)?;
let elapsed = now.elapsed();

for (r, _) in res {
Expand Down
22 changes: 10 additions & 12 deletions examples/falcon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,19 @@ fn main() -> Result<()> {
};

let g = Generator::new(&args.path, &cfg)?;
let prompts =
BufReader::new(File::open(args.prompt)?)
.lines()
.fold(Ok(String::new()), |acc, line| {
acc.and_then(|mut acc| {
line.map(|l| {
acc.push_str(&l);
acc
})
})
})?;
let prompts = BufReader::new(File::open(args.prompt)?).lines().try_fold(
String::new(),
|mut acc, line| {
line.map(|l| {
acc.push_str(&l);
acc
})
},
)?;

let now = time::Instant::now();
let res = g.generate_batch(
&vec![prompts],
&[prompts],
&GenerationOptions {
max_length: 200,
sampling_topk: 10,
Expand Down
22 changes: 10 additions & 12 deletions examples/gpt-2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,19 @@ fn main() -> Result<()> {
};

let g = Generator::new(&args.path, &cfg)?;
let prompts =
BufReader::new(File::open(args.prompt)?)
.lines()
.fold(Ok(String::new()), |acc, line| {
acc.and_then(|mut acc| {
line.map(|l| {
acc.push_str(&l);
acc
})
})
})?;
let prompts = BufReader::new(File::open(args.prompt)?).lines().try_fold(
String::new(),
|mut acc, line| {
line.map(|l| {
acc.push_str(&l);
acc
})
},
)?;

let now = time::Instant::now();
let res = g.generate_batch(
&vec![prompts],
&[prompts],
&GenerationOptions {
max_length: 30,
sampling_topk: 10,
Expand Down
22 changes: 10 additions & 12 deletions examples/gpt-j.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,19 @@ fn main() -> Result<()> {
};

let g = Generator::new(&args.path, &cfg)?;
let prompts =
BufReader::new(File::open(args.prompt)?)
.lines()
.fold(Ok(String::new()), |acc, line| {
acc.and_then(|mut acc| {
line.map(|l| {
acc.push_str(&l);
acc
})
})
})?;
let prompts = BufReader::new(File::open(args.prompt)?).lines().try_fold(
String::new(),
|mut acc, line| {
line.map(|l| {
acc.push_str(&l);
acc
})
},
)?;

let now = time::Instant::now();
let res = g.generate_batch(
&vec![prompts],
&[prompts],
&GenerationOptions {
max_length: 30,
sampling_topk: 10,
Expand Down
22 changes: 10 additions & 12 deletions examples/gpt-neox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,19 @@ fn main() -> Result<()> {
};

let g = Generator::new(&args.path, &cfg)?;
let prompts =
BufReader::new(File::open(args.prompt)?)
.lines()
.fold(Ok(String::new()), |acc, line| {
acc.and_then(|mut acc| {
line.map(|l| {
acc.push_str(&l);
acc
})
})
})?;
let prompts = BufReader::new(File::open(args.prompt)?).lines().try_fold(
String::new(),
|mut acc, line| {
line.map(|l| {
acc.push_str(&l);
acc
})
},
)?;

let now = time::Instant::now();
let res = g.generate_batch(
&vec![prompts],
&[prompts],
&GenerationOptions {
max_length: 64,
sampling_topk: 20,
Expand Down
22 changes: 10 additions & 12 deletions examples/mpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,19 @@ fn main() -> Result<()> {
};

let g = Generator::new(&args.path, &cfg)?;
let prompts =
BufReader::new(File::open(args.prompt)?)
.lines()
.fold(Ok(String::new()), |acc, line| {
acc.and_then(|mut acc| {
line.map(|l| {
acc.push_str(&l);
acc
})
})
})?;
let prompts = BufReader::new(File::open(args.prompt)?).lines().try_fold(
String::new(),
|mut acc, line| {
line.map(|l| {
acc.push_str(&l);
acc
})
},
)?;

let now = time::Instant::now();
let res = g.generate_batch(
&vec![prompts],
&[prompts],
&GenerationOptions {
max_length: 30,
sampling_topk: 10,
Expand Down
24 changes: 11 additions & 13 deletions examples/opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ use std::time;
use anyhow::Result;
use clap::Parser;

use ct2rs::{GenerationOptions, Generator};
use ct2rs::bpe;
use ct2rs::config::{Config, Device};
use ct2rs::{GenerationOptions, Generator};

/// Generate text using OPT models.
#[derive(Parser, Debug)]
Expand Down Expand Up @@ -75,21 +75,19 @@ fn main() -> Result<()> {
bpe::new(&args.path, Some("Ġ".to_string()))?,
&cfg,
)?;
let prompts =
BufReader::new(File::open(args.prompt)?)
.lines()
.fold(Ok(String::new()), |acc, line| {
acc.and_then(|mut acc| {
line.map(|l| {
acc.push_str(&l);
acc
})
})
})?;
let prompts = BufReader::new(File::open(args.prompt)?).lines().try_fold(
String::new(),
|mut acc, line| {
line.map(|l| {
acc.push_str(&l);
acc
})
},
)?;

let now = time::Instant::now();
let res = g.generate_batch(
&vec![prompts],
&[prompts],
&GenerationOptions {
beam_size: 15,
max_length: 50,
Expand Down
22 changes: 10 additions & 12 deletions examples/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,19 @@ fn main() -> Result<()> {
};

let t = Translator::new(&args.path, &cfg)?;
let source =
BufReader::new(File::open(args.prompt)?)
.lines()
.fold(Ok(String::new()), |acc, line| {
acc.and_then(|mut acc| {
line.map(|l| {
acc.push_str(&l);
acc
})
})
})?;
let source = BufReader::new(File::open(args.prompt)?).lines().try_fold(
String::new(),
|mut acc, line| {
line.map(|l| {
acc.push_str(&l);
acc
})
},
)?;

let mut out = stdout();
let _ = t.translate_batch(
&vec![source],
&[source],
&TranslationOptions {
// beam_size must be 1 to use the stream API.
beam_size: 1,
Expand Down
8 changes: 4 additions & 4 deletions examples/whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ use anyhow::Result;
use clap::Parser;
use hound::WavReader;
use ndarray::{Array2, Ix3};
use rustfft::FftPlanner;
use rustfft::num_complex::Complex;
use rustfft::FftPlanner;
use serde::Deserialize;

use ct2rs::{auto, Tokenizer};
use ct2rs::storage_view::StorageView;
use ct2rs::whisper::Whisper;
use ct2rs::{auto, Tokenizer};

const PREPROCESSOR_CONFIG_FILE: &str = "preprocessor_config.json";

Expand Down Expand Up @@ -104,7 +104,7 @@ fn main() -> Result<()> {
// Transcribe.
let res = model.generate(
&storage_view,
&vec![vec![
&[vec![
"<|startoftranscript|>",
&lang[0][0].language,
"<|transcribe|>",
Expand Down Expand Up @@ -242,7 +242,7 @@ impl PreprocessorConfig {
let rows = aux.mel_filters.len();
let cols = aux
.mel_filters
.get(0)
.first()
.map(|row| row.len())
.unwrap_or_default();

Expand Down
10 changes: 5 additions & 5 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ use std::fmt::{Debug, Display, Formatter};
use cxx::UniquePtr;

pub use ffi::{
BatchType, ComputeType, Device, get_device_count, get_log_level, get_random_seed,
LogLevel, set_log_level, set_random_seed,
get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed, BatchType,
ComputeType, Device, LogLevel,
};

#[cxx::bridge]
Expand Down Expand Up @@ -369,8 +369,8 @@ mod tests {
use rand::random;

use crate::config::{
BatchType, ComputeType, Config, Device, get_device_count,
get_log_level, get_random_seed, LogLevel, set_log_level, set_random_seed,
get_device_count, get_log_level, get_random_seed, set_log_level, set_random_seed,
BatchType, ComputeType, Config, Device, LogLevel,
};

#[test]
Expand Down Expand Up @@ -431,7 +431,7 @@ mod tests {

#[test]
fn test_log_level() {
for l in vec![
for l in [
LogLevel::Off,
LogLevel::Critical,
LogLevel::Error,
Expand Down
4 changes: 2 additions & 2 deletions src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ impl Generator {
/// # Ok(())
/// # }
/// ```
pub fn generate_batch<'a, T: AsRef<str>, U: AsRef<str>, V: AsRef<str>, W: AsRef<str>>(
pub fn generate_batch<T: AsRef<str>, U: AsRef<str>, V: AsRef<str>, W: AsRef<str>>(
&self,
start_tokens: &[Vec<T>],
options: &GenerationOptions<U, V, W>,
callback: Option<&'a mut dyn FnMut(GenerationStepResult) -> bool>,
callback: Option<&mut dyn FnMut(GenerationStepResult) -> bool>,
) -> Result<Vec<GenerationResult>> {
Ok(self
.ptr
Expand Down
Loading

0 comments on commit e3587aa

Please sign in to comment.