Problem with ort and WASM #260

Closed
all-c-a-p-s opened this issue Aug 13, 2024 · 1 comment

Comments

@all-c-a-p-s

I switched from the tract-onnx crate to ort to load a .onnx model in Rust. I followed the instructions at https://ort.pyke.io/setup/webassembly to integrate ort with WASM in my project. However, I still get an error in the browser console when I build and serve the project with trunk serve.

Here is the wasm-related error:

TypeError: Failed to execute 'decode' on 'TextDecoder': The encoded data was not valid.
    at getStringFromWasm0 (graffiti-12c00882611d1559.js:18:30)
    at imports.wbg.__wbg_error_d2d279fddc1936c2 (graffiti-12c00882611d1559.js:301:27)
    at graffiti-95270a5d58d49849.wasm.eframe::web::panic_handler::error::he4aea513b62c29d8 (graffiti-12c00882611d1559_bg.wasm:0xb9a778)
    at graffiti-95270a5d58d49849.wasm.eframe::web::panic_handler::PanicHandler::install::{{closure}}::hd32e90139075ff91 (graffiti-12c00882611d1559_bg.wasm:0x676f90)
    at graffiti-95270a5d58d49849.wasm.std::panicking::rust_panic_with_hook::h6731baa78621a747 (graffiti-12c00882611d1559_bg.wasm:0xac87d0)
    at graffiti-95270a5d58d49849.wasm.std::panicking::begin_panic_handler::{{closure}}::hb6cd8464ed39ae71 (graffiti-12c00882611d1559_bg.wasm:0xb5d383)
    at graffiti-95270a5d58d49849.wasm.std::sys_common::backtrace::__rust_end_short_backtrace::hbdf3ddeb21a1e747 (graffiti-12c00882611d1559_bg.wasm:0xcb007f)
    at graffiti-95270a5d58d49849.wasm.rust_begin_unwind (graffiti-12c00882611d1559_bg.wasm:0xc5d91f)
    at graffiti-95270a5d58d49849.wasm.core::panicking::panic_fmt::h5c7ce52813e94bcd (graffiti-12c00882611d1559_bg.wasm:0xc7026e)
    at graffiti-95270a5d58d49849.wasm.core::cell::panic_already_borrowed::h18b8189a0fdd8b58 (graffiti-12c00882611d1559_bg.wasm:0xc33747)
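
A note on reading this trace: the innermost Rust frame is core::cell::panic_already_borrowed, which is the panic the standard library raises when RefCell::borrow_mut is called while another borrow of the same cell is still alive. The TextDecoder TypeError appears to be thrown afterwards, while the eframe panic handler passes the panic message to JavaScript. The trace does not show which RefCell is involved, so the following is only a minimal, standard-library-only sketch of that panic class (the function names are hypothetical and nothing here is taken from ort or eframe):

```rust
use std::cell::RefCell;

/// Returns true if a second mutable borrow is refused while the first is alive.
fn second_borrow_is_rejected() -> bool {
    let cell = RefCell::new(vec![0.0f32; 3]);
    let _first = cell.borrow_mut();
    // Calling `cell.borrow_mut()` here would panic via `panic_already_borrowed`.
    // `try_borrow_mut` reports the same conflict as an `Err` instead of panicking.
    // The result is bound to a local so the temporary is dropped before `cell`.
    let rejected = cell.try_borrow_mut().is_err();
    rejected
}

/// Returns true if the cell can be borrowed again once the first borrow ends.
fn borrow_succeeds_after_release() -> bool {
    let cell = RefCell::new(vec![0.0f32; 3]);
    {
        let mut first = cell.borrow_mut();
        first[0] = 1.0;
    } // the first borrow is released at the end of this scope
    let ok = cell.try_borrow_mut().is_ok();
    ok
}
```

Both functions return true. In a real program the two borrows are usually far apart, for example a callback that re-enters code which already holds the borrow.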

Here is all of my ort-related code:

use std::collections::HashMap;

use ort::{Session, Tensor};

...

//include NN data at compile time
static GRADE_MODEL_BYTES: &[u8] = include_bytes!("/Users/seba/rs/graffiti/models/custom_model.ort");
static ROUTESET_MODEL_BYTES: &[u8] = include_bytes!("/Users/seba/rs/graffiti/models/routeset/routeset.ort");

pub fn run_model(
    start_holds: Vec<String>,
    finish_holds: Vec<String>,
    intermediate_holds: Vec<String>,
) -> ort::Result<String> {
    let mut holds_data: Vec<f32> = vec![0.0; 198];
    
    ...

    #[cfg(target_arch = "wasm32")]
    ort::wasm::initialize();

    let session = Session::builder()?.commit_from_memory(GRADE_MODEL_BYTES)?;

    let input_holds = Tensor::from_array(([1, 198], holds_data.clone().into_boxed_slice()))?;
    let mut inputs = HashMap::new();
    inputs.insert("input_layer", input_holds);

    let outputs = session.run(inputs)?;

    let mut probabilities: Vec<f32> = Vec::new();

    for (_, output_value) in outputs.iter() {
        probabilities = output_value
            .to_owned()
            .try_extract_tensor::<f32>()?
            .iter()
            .cloned()
            .collect::<Vec<f32>>();
    }

    let mut max: f32 = 0.0;
    let mut most_likely_grade = 4;

    ...
}

pub fn run_routeset_model(
    start_holds: &Vec<String>,
    finish_holds: &Vec<String>,
    intermediate_holds: &Vec<String>,
    grade: f32,
) -> ort::Result<Option<String>> {
    let mut holds_data: Vec<Vec<f32>> = vec![vec![0.0f32; 11]; 18];
    ...

    #[cfg(target_arch = "wasm32")]
    ort::wasm::initialize();

    let session = Session::builder()?.commit_from_memory(ROUTESET_MODEL_BYTES)?;

    let input_vector = holds_data.iter().flatten().cloned().collect::<Vec<f32>>();
    let input_holds = Tensor::from_array(([1, 18, 11], input_vector.clone().into_boxed_slice()))?;
    let input_grade = Tensor::from_array(([1, 1], vec![grade].into_boxed_slice()))?;

    let mut inputs = HashMap::new();
    inputs.insert("input_holds", input_holds);
    inputs.insert("input_grades", input_grade);

    let outputs = session.run(inputs)?;

    let mut probabilities: Vec<f32> = Vec::new();
    for (_, output_value) in outputs.iter() {
        probabilities = output_value
            .to_owned()
            .try_extract_tensor::<f32>()?
            .iter()
            .cloned()
            .collect::<Vec<f32>>();
    }

    ...
}

pub fn generate_route(
    start_holds: Vec<String>,
    finish_holds: Vec<String>,
    intermediate_holds: Vec<String>,
    grade: usize,
) -> (Vec<String>, Vec<String>, Vec<String>) {
    #[cfg(target_arch = "wasm32")]
    ort::wasm::initialize();
    ...
    let mut next_hold = run_routeset_model(
        &start_holds,
        &finish_holds,
        &intermediate_holds,
        grade as f32,
    )
    .expect("failed to run model");
    while next_hold.is_some() {
        ...
        next_hold = run_routeset_model(&s, &f, &i, grade as f32).expect("failed to run model");
    }
    ...
}
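
For readers checking the tensor shapes: run_routeset_model flattens an 18 x 11 grid into 198 values and passes them with shape [1, 18, 11]. The sketch below illustrates that layout, assuming the tensor data is interpreted in row-major order. The helper names flatten_row_major and flat_index are hypothetical and not part of the project or of ort:

```rust
/// Flattens a rows x cols grid into one vector, row after row.
fn flatten_row_major(grid: &[Vec<f32>]) -> Vec<f32> {
    grid.iter().flatten().copied().collect()
}

/// Position of element (row, col) in the flattened vector,
/// assuming every row has exactly `cols` entries.
fn flat_index(row: usize, col: usize, cols: usize) -> usize {
    row * cols + col
}
```

With an 18 x 11 grid, the value stored at grid[2][5] ends up at index flat_index(2, 5, 11) of the flattened vector, and the flattened length is 198.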

Here is my Cargo.toml (I don't think any of the other dependencies are incompatible with WASM):

[package]
name = "graffiti"
version = "0.1.0"
edition = "2021"
description = "AI for climbing routesetting"
authors = ["Sebastiano Rebonato-Scott"]
license = "MIT"
repository = "https://github.com/all-c-a-p-s/Graffiti"

[dependencies]
console_error_panic_hook = "0.1.7"
eframe = "0.28.1"
egui = "0.28.1"
egui_extras = { version = "0.28.1", features = ["all_loaders"] }
getrandom = { version = "0.2", features = ["js"] }
image = { version = "0.25.2", features = ["jpeg", "png"] }
log = "0.4.22"
ndarray = "0.16.0"
ort = "2.0.0-rc.4"
wasm-bindgen = "0.2.92"
web-sys = "0.3.69"

[package.metadata.bundle]
name = "graffiti"
identifier = "io.github.all-c-a-p-s.graffiti"

[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen-futures = "0.4"

[lints.clippy]
all = "warn"

I have also tried compiling to WASM with wasm-pack instead of trunk, but I got the same error. The WASM build worked fine before I switched from tract-onnx to ort, which is why I think ort is the cause of the problem.
Sorry for the very long post and thanks for your time.

@decahedron1
Member

I can no longer support WASM; it'll be removed soon. Sorry. I suggest you switch back to tract.

@decahedron1 closed this as not planned on Aug 13, 2024