diff --git a/aicirt/src/lib.rs b/aicirt/src/lib.rs index aa9c137f..b1d65b40 100644 --- a/aicirt/src/lib.rs +++ b/aicirt/src/lib.rs @@ -118,6 +118,15 @@ macro_rules! bail_user { }; } +#[macro_export] +macro_rules! ensure_user { + ($cond:expr, $($tt:tt)*) => { + if !$cond { + return Err($crate::UserError::anyhow(format!($($tt)*))) + } + }; +} + pub fn is_hex_string(s: &str) -> bool { s.chars().all(|c| c.is_digit(16)) } diff --git a/aicirt/src/main.rs b/aicirt/src/main.rs index 157ec22b..ab4a817f 100644 --- a/aicirt/src/main.rs +++ b/aicirt/src/main.rs @@ -21,6 +21,7 @@ use clap::Parser; use hex; use hostimpl::GlobalInfo; use regex::Regex; +use serde::Serialize; use serde_json::{json, Value}; use sha2::{Digest, Sha256}; use std::{ @@ -162,6 +163,16 @@ fn hex_hash_string(s: &str) -> String { hex::encode(hasher.finalize()) } +fn read_json(filename: &PathBuf) -> Result { + let bytes = fs::read(filename)?; + Ok(serde_json::from_slice(&bytes)?) +} + +fn write_json(filename: &PathBuf, json: &T) -> Result<()> { + fs::write(filename, serde_json::to_vec(json)?)?; + Ok(()) +} + impl ModuleRegistry { pub fn new(wasm_ctx: WasmContext, shm: Shm) -> Result { let forker = WorkerForker::new(wasm_ctx.clone(), shm); @@ -254,7 +265,7 @@ impl ModuleRegistry { Ok(self.elf_path(module_id)) } - fn create_module(&self, wasm_bytes: Vec, auth: AuthInfo) -> Result { + fn create_module(&self, wasm_bytes: Vec, auth: AuthInfo) -> Result { let timer = Instant::now(); let mut hasher = ::new(); @@ -285,12 +296,12 @@ impl ModuleRegistry { time ); - Ok(serde_json::to_value(MkModuleResp { + Ok(MkModuleResp { module_id: module_id.to_string(), wasm_size: wasm_bytes.len(), compiled_size, time, - })?) + }) } fn write_and_compile( @@ -304,12 +315,12 @@ impl ModuleRegistry { Ok( if meta.is_err() || meta.unwrap().len() != wasm_bytes.len() as u64 { fs::write(self.wasm_path(module_id), wasm_bytes)?; - fs::write( - self.sys_meta_path(module_id), - serde_json::to_vec(&json!({ + write_json( + &self.sys_meta_path(module_id), + &json!({ "created": get_unix_time(), "auth": auth, - }))?, + }), )?; self.compile_module(module_id, true)? } else { @@ -320,7 +331,9 @@ impl ModuleRegistry { fn mk_module(&self, req: MkModuleReq, auth: AuthInfo) -> Result { let wasm_bytes = base64::engine::general_purpose::STANDARD.decode(req.binary)?; - self.create_module(wasm_bytes, auth) + Ok(serde_json::to_value( + &self.create_module(wasm_bytes, auth)?, + )?) } fn set_tags(&self, req: SetTagsReq, auth: AuthInfo) -> Result { @@ -366,7 +379,7 @@ impl ModuleRegistry { log::info!("tag {} -> {} by {}", tagname, req.module_id, auth.user); let mut info = info.clone(); info.tag = tagname.clone(); - fs::write(self.tag_path(tagname), &serde_json::to_vec(&info)?)?; + write_json(&self.tag_path(tagname), &info)?; resp.tags.push(info) } @@ -440,18 +453,18 @@ impl ModuleRegistry { if !(meta.is_ok() && meta.unwrap().modified()? > SystemTime::now().sub(Duration::from_secs(120))) { - log::info!("fetching {}", url); + log::info!("fetching {} to {:?}", url, cache_path); let resp = ureq::get(&url) .set("User-Agent", "AICI") .set("Accept", "application/vnd.github+json") .set("X-GitHub-Api-Version", "2022-11-28") .call() .map_err(|e| anyhow!("gh: fetch failed: {}", e))?; + fs::create_dir_all(&self.cache_path)?; std::fs::write(cache_path.clone(), resp.into_string()?)?; } - let json: serde_json::Value = serde_json::from_slice(&std::fs::read(cache_path)?)?; - - let wasm_files = json["assets"] + let release = read_json(&cache_path)?; + let wasm_files = release["assets"] .as_array() .ok_or_else(|| anyhow!("no assets"))? .iter() @@ -463,21 +476,46 @@ impl ModuleRegistry { }) .collect::>(); - ensure!(wasm_files.len() > 0, "no wasm files found"); - ensure!(wasm_files.len() == 1, "too many wasm files found"); + ensure_user!(wasm_files.len() > 0, "no wasm files found"); + ensure_user!(wasm_files.len() == 1, "too many wasm files found"); let wasm_file = wasm_files[0]; - let _upd = wasm_file["updated_at"] + let upd = wasm_file["updated_at"] .as_str() .ok_or_else(|| anyhow!("no updated_at"))?; - let _wasm_url = wasm_file["browser_download_url"] + let wasm_url = wasm_file["browser_download_url"] .as_str() .ok_or_else(|| anyhow!("no browser_download_url"))?; - - todo!() + let link_path = self.url_path(&format!("{}---{}", upd, wasm_url)); + if link_path.exists() { + let link = read_json(&link_path)?; + return Ok(link["module_id"] + .as_str() + .ok_or_else(|| anyhow!("invalid json"))? + .to_string()); + } + log::info!("downloading {}", wasm_url); + let mut wasm_bytes = vec![]; + ureq::get(wasm_url) + .set("User-Agent", "AICI") + .call() + .map_err(|e| anyhow!("gh: download failed: {}", e))? + .into_reader() + .read_to_end(&mut wasm_bytes)?; + log::info!("downloaded {} bytes", wasm_bytes.len()); + let resp = self.create_module( + wasm_bytes, + AuthInfo { + user: wasm_url.to_string(), + is_admin: true, + }, + )?; + write_json(&link_path, &resp)?; + Ok(resp.module_id) } fn instantiate(&mut self, mut req: InstantiateReq) -> Result { + req.module_id = self.resolve_gh_module(&req.module_id)?; if valid_tagname(&req.module_id) { let taginfo = self.read_tag(&req.module_id)?; req.module_id = taginfo.module_id; @@ -1078,7 +1116,7 @@ fn install_from_cmdline(cli: &Cli, wasm_ctx: WasmContext, shm: Shm) { let json = reg .create_module(wasm_bytes, AuthInfo::local_user()) .unwrap(); - json["module_id"].as_str().unwrap().to_string() + json.module_id }; println!("{}", module_id); diff --git a/rllm/rllm-base/src/config.rs b/rllm/rllm-base/src/config.rs index 966976bf..6cf5a436 100644 --- a/rllm/rllm-base/src/config.rs +++ b/rllm/rllm-base/src/config.rs @@ -158,7 +158,7 @@ impl SamplingParams { fn _verify_args(&self) -> Result<()> { if let Some(mod_id) = self.controller.as_ref() { - if !valid_module_or_tag(mod_id) { + if !valid_module_or_tag(mod_id) && !mod_id.starts_with("gh:") { bail_user!( "'controller' must be a 64-char hex string or tag name, got {}.", mod_id