Skip to content

Commit

Permalink
implement module download
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Feb 10, 2024
1 parent c8f214c commit 4326b45
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 21 deletions.
9 changes: 9 additions & 0 deletions aicirt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ macro_rules! bail_user {
};
}

#[macro_export]
macro_rules! ensure_user {
($cond:expr, $($tt:tt)*) => {
if !$cond {
return Err($crate::UserError::anyhow(format!($($tt)*)))
}
};
}

pub fn is_hex_string(s: &str) -> bool {
s.chars().all(|c| c.is_digit(16))
}
Expand Down
78 changes: 58 additions & 20 deletions aicirt/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use clap::Parser;
use hex;
use hostimpl::GlobalInfo;
use regex::Regex;
use serde::Serialize;
use serde_json::{json, Value};
use sha2::{Digest, Sha256};
use std::{
Expand Down Expand Up @@ -162,6 +163,16 @@ fn hex_hash_string(s: &str) -> String {
hex::encode(hasher.finalize())
}

fn read_json(filename: &PathBuf) -> Result<Value> {
let bytes = fs::read(filename)?;
Ok(serde_json::from_slice(&bytes)?)
}

fn write_json<T: Serialize>(filename: &PathBuf, json: &T) -> Result<()> {
fs::write(filename, serde_json::to_vec(json)?)?;
Ok(())
}

impl ModuleRegistry {
pub fn new(wasm_ctx: WasmContext, shm: Shm) -> Result<Self> {
let forker = WorkerForker::new(wasm_ctx.clone(), shm);
Expand Down Expand Up @@ -254,7 +265,7 @@ impl ModuleRegistry {
Ok(self.elf_path(module_id))
}

fn create_module(&self, wasm_bytes: Vec<u8>, auth: AuthInfo) -> Result<Value> {
fn create_module(&self, wasm_bytes: Vec<u8>, auth: AuthInfo) -> Result<MkModuleResp> {
let timer = Instant::now();

let mut hasher = <Sha256 as Digest>::new();
Expand Down Expand Up @@ -285,12 +296,12 @@ impl ModuleRegistry {
time
);

Ok(serde_json::to_value(MkModuleResp {
Ok(MkModuleResp {
module_id: module_id.to_string(),
wasm_size: wasm_bytes.len(),
compiled_size,
time,
})?)
})
}

fn write_and_compile(
Expand All @@ -304,12 +315,12 @@ impl ModuleRegistry {
Ok(
if meta.is_err() || meta.unwrap().len() != wasm_bytes.len() as u64 {
fs::write(self.wasm_path(module_id), wasm_bytes)?;
fs::write(
self.sys_meta_path(module_id),
serde_json::to_vec(&json!({
write_json(
&self.sys_meta_path(module_id),
&json!({
"created": get_unix_time(),
"auth": auth,
}))?,
}),
)?;
self.compile_module(module_id, true)?
} else {
Expand All @@ -320,7 +331,9 @@ impl ModuleRegistry {

fn mk_module(&self, req: MkModuleReq, auth: AuthInfo) -> Result<Value> {
let wasm_bytes = base64::engine::general_purpose::STANDARD.decode(req.binary)?;
self.create_module(wasm_bytes, auth)
Ok(serde_json::to_value(
&self.create_module(wasm_bytes, auth)?,
)?)
}

fn set_tags(&self, req: SetTagsReq, auth: AuthInfo) -> Result<Value> {
Expand Down Expand Up @@ -366,7 +379,7 @@ impl ModuleRegistry {
log::info!("tag {} -> {} by {}", tagname, req.module_id, auth.user);
let mut info = info.clone();
info.tag = tagname.clone();
fs::write(self.tag_path(tagname), &serde_json::to_vec(&info)?)?;
write_json(&self.tag_path(tagname), &info)?;
resp.tags.push(info)
}

Expand Down Expand Up @@ -440,18 +453,18 @@ impl ModuleRegistry {
if !(meta.is_ok()
&& meta.unwrap().modified()? > SystemTime::now().sub(Duration::from_secs(120)))
{
log::info!("fetching {}", url);
log::info!("fetching {} to {:?}", url, cache_path);
let resp = ureq::get(&url)
.set("User-Agent", "AICI")
.set("Accept", "application/vnd.github+json")
.set("X-GitHub-Api-Version", "2022-11-28")
.call()
.map_err(|e| anyhow!("gh: fetch failed: {}", e))?;
fs::create_dir_all(&self.cache_path)?;
std::fs::write(cache_path.clone(), resp.into_string()?)?;
}
let json: serde_json::Value = serde_json::from_slice(&std::fs::read(cache_path)?)?;

let wasm_files = json["assets"]
let release = read_json(&cache_path)?;
let wasm_files = release["assets"]
.as_array()
.ok_or_else(|| anyhow!("no assets"))?
.iter()
Expand All @@ -463,21 +476,46 @@ impl ModuleRegistry {
})
.collect::<Vec<_>>();

ensure!(wasm_files.len() > 0, "no wasm files found");
ensure!(wasm_files.len() == 1, "too many wasm files found");
ensure_user!(wasm_files.len() > 0, "no wasm files found");
ensure_user!(wasm_files.len() == 1, "too many wasm files found");

let wasm_file = wasm_files[0];
let _upd = wasm_file["updated_at"]
let upd = wasm_file["updated_at"]
.as_str()
.ok_or_else(|| anyhow!("no updated_at"))?;
let _wasm_url = wasm_file["browser_download_url"]
let wasm_url = wasm_file["browser_download_url"]
.as_str()
.ok_or_else(|| anyhow!("no browser_download_url"))?;

todo!()
let link_path = self.url_path(&format!("{}---{}", upd, wasm_url));
if link_path.exists() {
let link = read_json(&link_path)?;
return Ok(link["module_id"]
.as_str()
.ok_or_else(|| anyhow!("invalid json"))?
.to_string());
}
log::info!("downloading {}", wasm_url);
let mut wasm_bytes = vec![];
ureq::get(wasm_url)
.set("User-Agent", "AICI")
.call()
.map_err(|e| anyhow!("gh: download failed: {}", e))?
.into_reader()
.read_to_end(&mut wasm_bytes)?;
log::info!("downloaded {} bytes", wasm_bytes.len());
let resp = self.create_module(
wasm_bytes,
AuthInfo {
user: wasm_url.to_string(),
is_admin: true,
},
)?;
write_json(&link_path, &resp)?;
Ok(resp.module_id)
}

fn instantiate(&mut self, mut req: InstantiateReq) -> Result<Value> {
req.module_id = self.resolve_gh_module(&req.module_id)?;
if valid_tagname(&req.module_id) {
let taginfo = self.read_tag(&req.module_id)?;
req.module_id = taginfo.module_id;
Expand Down Expand Up @@ -1078,7 +1116,7 @@ fn install_from_cmdline(cli: &Cli, wasm_ctx: WasmContext, shm: Shm) {
let json = reg
.create_module(wasm_bytes, AuthInfo::local_user())
.unwrap();
json["module_id"].as_str().unwrap().to_string()
json.module_id
};

println!("{}", module_id);
Expand Down
2 changes: 1 addition & 1 deletion rllm/rllm-base/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ impl SamplingParams {

fn _verify_args(&self) -> Result<()> {
if let Some(mod_id) = self.controller.as_ref() {
if !valid_module_or_tag(mod_id) {
if !valid_module_or_tag(mod_id) && !mod_id.starts_with("gh:") {
bail_user!(
"'controller' must be a 64-char hex string or tag name, got {}.",
mod_id
Expand Down

0 comments on commit 4326b45

Please sign in to comment.