Skip to content

Commit

Permalink
ダウンローダーに--only <TARGET>...--exclude <TARGET>...を追加
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip committed Oct 15, 2023
1 parent 77bf10f commit a3705e7
Showing 1 changed file with 92 additions and 14 deletions.
106 changes: 92 additions & 14 deletions crates/download/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
borrow::Cow,
collections::HashSet,
env,
future::Future,
io::{self, Cursor, Read},
Expand All @@ -26,7 +27,7 @@ use once_cell::sync::Lazy;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use strum::{Display, IntoStaticStr};
use tokio::task::{JoinError, JoinSet};
use tracing::info;
use tracing::{info, warn};
use url::Url;
use zip::ZipArchive;

Expand All @@ -48,7 +49,20 @@ static OPEN_JTALK_DIC_URL: Lazy<Url> = Lazy::new(|| {

#[derive(clap::Parser)]
struct Args {
/// ダウンロードするライブラリを最小限にするように指定
/// ダウンロード対象を限定する
#[arg(
long,
num_args(1..),
value_name("TARGET"),
conflicts_with_all(["exclude", "min"]))
]
only: Vec<DownloadTarget>,

/// ダウンロード対象を除外する
#[arg(long, num_args(1..), value_name("TARGET"), conflicts_with("min"))]
exclude: Vec<DownloadTarget>,

/// `--only core`のエイリアス
#[arg(long, conflicts_with("additional_libraries_version"))]
min: bool,

Expand All @@ -65,7 +79,12 @@ struct Args {
additional_libraries_version: String,

/// ダウンロードするデバイスを指定する(cudaはlinuxのみ)
#[arg(value_enum, long, default_value(<&str>::from(Device::default())))]
#[arg(
value_enum,
long,
default_value(<&str>::from(Device::default())),
required_if_eq("only", "additional-libraries")
)]
device: Device,

/// ダウンロードするcpuのアーキテクチャを指定する
Expand All @@ -87,6 +106,14 @@ struct Args {
additional_libraries_repo: RepoName,
}

#[derive(ValueEnum, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum DownloadTarget {
Core,
Models,
AdditionalLibraries,
Dict,
}

#[derive(Default, ValueEnum, Display, IntoStaticStr, Clone, Copy, PartialEq)]
#[strum(serialize_all = "kebab-case")]
enum Device {
Expand Down Expand Up @@ -133,8 +160,9 @@ impl Os {
}
}

#[derive(parse_display::FromStr, Clone)]
#[derive(parse_display::FromStr, parse_display::Display, Clone)]
#[from_str(regex = "(?<owner>[a-zA-Z0-9_]+)/(?<repo>[a-zA-Z0-9_]+)")]
#[display("{owner}/{repo}")]
struct RepoName {
owner: String,
repo: String,
Expand All @@ -145,6 +173,8 @@ async fn main() -> anyhow::Result<()> {
setup_logger();

let Args {
only,
exclude,
min,
output,
version,
Expand All @@ -156,6 +186,51 @@ async fn main() -> anyhow::Result<()> {
additional_libraries_repo,
} = Args::parse();

let targets: HashSet<_> = if !only.is_empty() {
assert!(exclude.is_empty() && !min);
only.into_iter().collect()
} else if !exclude.is_empty() {
assert!(!min);
DownloadTarget::value_variants()
.iter()
.copied()
.filter(|t| !exclude.contains(t))
.collect()
} else if min {
[DownloadTarget::Core].into()
} else {
DownloadTarget::value_variants().iter().copied().collect()
};

if !(targets.contains(&DownloadTarget::Core) || targets.contains(&DownloadTarget::Models)) {
if version != "latest" {
warn!(
"`--version={version}`が指定されていますが、`core`も`models`もダウンロード対象から\
除外されています",
);
}
if core_repo.to_string() != DEFAULT_CORE_REPO {
warn!(
"`--core-repo={core_repo}`が指定されていますが、`core`も`models`もダウンロード対象\
から除外されています",
);
}
}
if !targets.contains(&DownloadTarget::AdditionalLibraries) {
if additional_libraries_version != "latest" {
warn!(
"`--additional-libraries-version={additional_libraries_version}`が指定されています\
が、`additional-libraries-version`はダウンロード対象から除外されています",
);
}
if additional_libraries_repo.to_string() != DEFAULT_ADDITIONAL_LIBRARIES_REPO {
warn!(
"`--additional-libraries-repo={additional_libraries_repo}`が指定されていますが、\
`additional-libraries-version`はダウンロード対象から除外されています",
);
}
}

let octocrab = &octocrab()?;

let core = find_gh_asset(octocrab, &core_repo, &version, |tag| {
Expand Down Expand Up @@ -202,21 +277,23 @@ async fn main() -> anyhow::Result<()> {

let mut tasks = JoinSet::new();

tasks.spawn(download_and_extract_from_gh(
core,
Stripping::FirstDir,
&output,
&progresses,
)?);

if !min {
if targets.contains(&DownloadTarget::Core) {
tasks.spawn(download_and_extract_from_gh(
core,
Stripping::FirstDir,
&output,
&progresses,
)?);
}
if targets.contains(&DownloadTarget::Models) {
tasks.spawn(download_and_extract_from_gh(
model,
Stripping::FirstDir,
&output.join("model"),
&progresses,
)?);

}
if targets.contains(&DownloadTarget::AdditionalLibraries) {
if let Some(additional_libraries) = additional_libraries {
tasks.spawn(download_and_extract_from_gh(
additional_libraries,
Expand All @@ -225,7 +302,8 @@ async fn main() -> anyhow::Result<()> {
&progresses,
)?);
}

}
if targets.contains(&DownloadTarget::Dict) {
tasks.spawn(download_and_extract_from_url(
&OPEN_JTALK_DIC_URL,
Stripping::None,
Expand Down

0 comments on commit a3705e7

Please sign in to comment.