Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WASI-NN] Add support for a PyTorch backend for wasi-nn #9234

Merged
merged 8 commits into from
Oct 21, 2024

Conversation

rahulchaphalkar
Copy link
Contributor

This change adds a PyTorch backend for wasi-nn.
tch crate is used for Libtorch bindings. I have added an image classification example to demonstrate its usage, which uses a torchscript model.
This backend is currently gated behind a wasi-nn feature flag --features pytorch as due to dynamic linking, a Libtorch v2.4.0 installation on the system (specified by LIBTORCH=/path/to/libtorch) is needed for building.

@rahulchaphalkar rahulchaphalkar requested review from alexcrichton and removed request for a team September 12, 2024 18:18
@abrown abrown self-assigned this Sep 12, 2024
@alexcrichton alexcrichton requested review from abrown and removed request for a team and alexcrichton September 12, 2024 20:09
Copy link
Contributor

@abrown abrown left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good start. The main thing to fix is the handling of the input and output tensors.

crates/wasi-nn/src/backend/pytorch.rs Outdated Show resolved Hide resolved
crates/wasi-nn/src/backend/pytorch.rs Outdated Show resolved Hide resolved
crates/wasi-nn/src/backend/pytorch.rs Outdated Show resolved Hide resolved
crates/wasi-nn/src/backend/pytorch.rs Outdated Show resolved Hide resolved
crates/wasi-nn/src/backend/pytorch.rs Outdated Show resolved Hide resolved
crates/wasi-nn/src/backend/pytorch.rs Outdated Show resolved Hide resolved
crates/wasi-nn/src/backend/pytorch.rs Outdated Show resolved Hide resolved
@abrown
Copy link
Contributor

abrown commented Sep 20, 2024

The cargo vet situation is a bit much:

    cargo vet diff zstd-safe 5.0.1+zstd.1.5.2 5.0.2+zstd.1.5.2
                                                          gyscos         zstd                        2 files changed, 4 insertions(+), 4 deletions(-)
    cargo vet diff zstd 0.11.1+zstd.1.5.2 0.11.2+zstd.1.5.2
                                                          gyscos         zip                         3 files changed, 5 insertions(+), 5 deletions(-)
    cargo vet diff num-complex 0.4.2 0.4.6                cuviper        ndarray                     6 files changed, 188 insertions(+), 48 deletions(-)
      NOTE: this project trusts Josh Stone (cuviper) - consider cargo vet trust num-complex or cargo vet trust --all cuviper
    cargo vet inspect constant_time_eq 0.1.5              cesarb         zip                         311 lines
    cargo vet diff sha1 0.10.5 0.10.6                     newpavlov      zip                         7 files changed, 302 insertions(+), 20 deletions(-)
    cargo vet inspect rawpointer 0.2.1                    bluss          ndarray and matrixmultiply  559 lines
    cargo vet diff zip 0.6.4 0.6.6                        Plecra         tch and torch-sys           14 files changed, 604 insertions(+), 109 deletions(-)
    cargo vet inspect inout 0.1.3                         newpavlov      cipher                      1112 lines
      NOTE: cargo vet import zcash would eliminate this
    cargo vet inspect pbkdf2 0.9.0                        tarcieri       zip                         1120 lines
    cargo vet inspect bzip2 0.4.4                         alexcrichton   zip                         2094 lines
      NOTE: this project trusts Alex Crichton (alexcrichton) - consider cargo vet trust bzip2 or cargo vet trust --all alexcrichton
    cargo vet inspect safetensors 0.3.3                   Narsil         tch                         2200 lines
    cargo vet inspect cipher 0.4.4                        newpavlov      aes                         2635 lines
      NOTE: cargo vet import zcash would reduce this to a [130](https://github.com/bytecodealliance/wasmtime/actions/runs/10836457564/job/30070281197?pr=9234#step:6:131)0-line diff
    cargo vet inspect password-hash 0.3.2                 tarcieri       pbkdf2                      3139 lines
    cargo vet inspect base64ct 1.6.0                      tarcieri       password-hash               3381 lines
    cargo vet diff half 1.8.2 2.4.1                       starkat99      tch                         19 files changed, 2546 insertions(+), 958 deletions(-)
    cargo vet inspect time 0.1.44                         jhpratt        zip                         3915 lines
    cargo vet inspect aes 0.7.5                           tarcieri       zip                         6822 lines
    cargo vet inspect matrixmultiply 0.3.8                bluss          ndarray                     7934 lines
    cargo vet inspect ndarray 0.15.6                      jturner314     tch                         41996 lines
    cargo vet inspect torch-sys 0.17.0                    LaurentMazare  tch                         52119 lines
    cargo vet inspect bzip2-sys 0.1.11+1.0.8              alexcrichton   bzip2                       264[133](https://github.com/bytecodealliance/wasmtime/actions/runs/10836457564/job/30070281197?pr=9234#step:6:134) lines
      NOTE: this project trusts Alex Crichton (alexcrichton) - consider cargo vet trust bzip2-sys or cargo vet trust --all alexcrichton
    cargo vet inspect tch 0.17.0                          LaurentMazare  wasmtime-wasi-nn            2287297 lines

@rahulchaphalkar
Copy link
Contributor Author

This is a good start. The main thing to fix is the handling of the input and output tensors.

Thanks for the review, Andrew. I've marked smaller Nits as resolved, and I've addressed other comments as well, but kept them 'unresolved' as of now until you take a look.

&input_tensor.data,
&dimensions,
kind,
));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is almost there: we're still ignoring the index, though. What might need to happen is that we set up inputs as a Vec<Option<TchTensor>> filled with None and then set the right item based on the index. We are able to retrieve the number of inputs, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The index is being ignored because jit-compiled pytorch models expect tensors in order, similar to what is achieved by vector of tensors in set_input(). Do we need to use index here to keep things consistent with wasi-nn? I don't think I saw a direct way to retrieve the number of inputs to a model, I can look into that further. However, someone using pytorch would probably not intend on using index.
This was my previous comment on ignoring the index:

This is one of the differences for this backend. The module's forward method should handle multiple inputs appropriately if it does support multiple inputs. The vector of input tensors being passed to forward should be sufficient, no index or name is needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean use the index as the position/index for the vector of tensors? That sounds good. I'll see if there's a way to determine the max number of inputs to the given model.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, so here's what I've changed -

  • set_input() now uses the id passed to it. If it is u32, it assigns the tensor to the vector inputs[id], assigning None if there are any empty spots below the given id. This will give the user flexibility to set input tensors non sequentially.
  • compute() will check the vector for any None values, and give an error if present, before calling forward_ts().
  • There is no reliable way at this time to get max inputs for a model. However, assigning more inputs than available returns an error message detailing the max inputs the model expects at run time, so this is helpful for the user. The expectation, similar to other backends, is for the user to be aware of input size/shape etc.
  • If id is a string, I've currently permitted only single input to be set, effectively ignoring the index. If more inputs are assigned, or if u32 and String indexes are used together, we give out an error.

Let me know if this looks fine. Thanks.

Copy link
Contributor

@abrown abrown Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, looking at 727a4aa, I think that makes more sense. I was kind of hoping that CModule::named_parameters would return the list of inputs by name, but, if that's not the case, then let's just return an error for the Id::Name side of things. In any case, we probably don't need id_type: if we get an Id::Index we know were to put it in the vector and if we get an Id::Name then we either (a) look up its index in named_parameters or (b) if that's not possible, return an error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had added the Id:Name case because the wit test cases (nn_wit_image_classification_pytorch in this case) need a Name instead of Index to pass - nn.rs and wasi-nn.wit. Although it looks like set-input might be going away anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we find a smaller model or download this instead? Not all Wasmtime users probably want to download this file... twice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made the following changes -

crates/wasi-nn/Cargo.toml Outdated Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the identical 44.7MB file checked in again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Downloading from external repo, same as tests

rahulchaphalkar and others added 7 commits October 14, 2024 08:18
As described in the [contribution guidelines], Wasmtime will exempt
dependencies from vetting that receive at least 10,000 downloads a day.
This substantially reduces the burden for vetting this PR, so I've
tallied up daily downloads (across all versions) for the crates in this
PR, listed below. This change then exempts the new dependencies that
meet the 10K+ criteria.

[contribution guidelines]: https://docs.wasmtime.dev/contributing-coding-guidelines.html#policy-for-adding-cargo-vet-entries

```
> aes
2024-10-02      111734
2024-10-03      107324
2024-10-04      104299
2024-10-05      32397
2024-10-06      29507
2024-10-07      123368
2024-10-08      125732

> base64ct
2024-10-02      179848
2024-10-03      157938
2024-10-04      149495
2024-10-05      48118
2024-10-06      43389
2024-10-07      183254
2024-10-08      175378

> bzip2
2024-10-02      89309
2024-10-03      85112
2024-10-04      76573
2024-10-05      27152
2024-10-06      24124
2024-10-07      90228
2024-10-08      93314

> bzip2-sys
2024-10-02      109664
2024-10-03      102677
2024-10-04      94485
2024-10-05      33196
2024-10-06      28417
2024-10-07      110195
2024-10-08      110951

> cipher
2024-10-02      1119
2024-10-03      377
2024-10-04      270
2024-10-05      178
2024-10-06      271
2024-10-07      2105
2024-10-08      1777

> constant_time_eq
2024-10-02      137462
2024-10-03      126300
2024-10-04      121927
2024-10-05      169156
2024-10-06      139559
2024-10-07      304529
2024-10-08      246533

> crunchy
2024-10-02      197832
2024-10-03      176586
2024-10-04      172053
2024-10-05      187875
2024-10-06      153647
2024-10-07      359240
2024-10-08      304777

> deranged
2024-10-02      319691
2024-10-03      285298
2024-10-04      267760
2024-10-05      104537
2024-10-06      92306
2024-10-07      309831
2024-10-08      308869

> digest
2024-10-02      2128
2024-10-03      1335
2024-10-04      1474
2024-10-05      594
2024-10-06      726
2024-10-07      3079
2024-10-08      2855

> half
2024-10-02      161525
2024-10-03      144013
2024-10-04      137296
2024-10-05      49246
2024-10-06      42437
2024-10-07      157366
2024-10-08      165013

> hmac
2024-10-02      1254
2024-10-03      394
2024-10-04      322
2024-10-05      230
2024-10-06      424
2024-10-07      2068
2024-10-08      1907

> inout
2024-10-02      1114
2024-10-03      366
2024-10-04      281
2024-10-05      184
2024-10-06      285
2024-10-07      2000
2024-10-08      1782

> matrixmultiply
2024-10-02      52273
2024-10-03      49931
2024-10-04      48408
2024-10-05      17219
2024-10-06      13950
2024-10-07      53916
2024-10-08      52644

> ndarray
2024-10-02      28922
2024-10-03      29354
2024-10-04      27397
2024-10-05      10480
2024-10-06      9074
2024-10-07      30988
2024-10-08      32344

> num-complex
2024-10-02      178444
2024-10-03      159144
2024-10-04      146722
2024-10-05      48522
2024-10-06      39138
2024-10-07      171363
2024-10-08      172915

> num-conv
2024-10-02      298495
2024-10-03      267134
2024-10-04      250350
2024-10-05      97809
2024-10-06      87399
2024-10-07      293150
2024-10-08      290661

> num-integer
2024-10-02      333731
2024-10-03      300418
2024-10-04      287516
2024-10-05      227416
2024-10-06      190413
2024-10-07      487348
2024-10-08      433744

> password-hash
2024-10-02      22429
2024-10-03      20702
2024-10-04      21550
2024-10-05      9061
2024-10-06      8660
2024-10-07      25743
2024-10-08      22404

> pbkdf2
2024-10-02      77885
2024-10-03      76192
2024-10-04      72278
2024-10-05      148944
2024-10-06      119322
2024-10-07      248354
2024-10-08      190649

> powerfmt
2024-10-02      310293
2024-10-03      277178
2024-10-04      259885
2024-10-05      101195
2024-10-06      89789
2024-10-07      302058
2024-10-08      300192

> rawpointer
2024-10-02      53917
2024-10-03      50649
2024-10-04      48439
2024-10-05      17375
2024-10-06      14761
2024-10-07      56228
2024-10-08      55013

> safetensors
2024-10-02      2253
2024-10-03      1737
2024-10-04      1798
2024-10-05      1085
2024-10-06      1544
2024-10-07      1742
2024-10-08      2024

> sha1
2024-10-02      1410
2024-10-03      673
2024-10-04      772
2024-10-05      230
2024-10-06      416
2024-10-07      2125
2024-10-08      2204

> tch
2024-10-02      1930
2024-10-03      2295
2024-10-04      2834
2024-10-05      1274
2024-10-06      455
2024-10-07      2290
2024-10-08      2181

> time
2024-10-02      303042
2024-10-03      271434
2024-10-04      255795
2024-10-05      100194
2024-10-06      88810
2024-10-07      297807
2024-10-08      295315

> time-core
2024-10-02      334979
2024-10-03      302165
2024-10-04      282918
2024-10-05      109319
2024-10-06      96522
2024-10-07      324779
2024-10-08      322102

> torch-sys
2024-10-02      1911
2024-10-03      2300
2024-10-04      2843
2024-10-05      1271
2024-10-06      452
2024-10-07      2292
2024-10-08      2177

> zip
2024-10-02      22520
2024-10-03      23201
2024-10-04      20946
2024-10-05      9067
2024-10-06      8470
2024-10-07      24674
2024-10-08      24870

> zstd
2024-10-02      175155
2024-10-03      167766
2024-10-04      157489
2024-10-05      52753
2024-10-06      44844
2024-10-07      177411
2024-10-08      173785

> zstd-safe
2024-10-02      179288
2024-10-03      170379
2024-10-04      159352
2024-10-05      52820
2024-10-06      45835
2024-10-07      180535
2024-10-08      177703
```
For dependencies that did not have clear 10k+ daily downloads, this
change audits them for `safe-to-deploy`.
This adds external audits pulled in automatically by `cargo vet` for the
remainder of the dependencies not covered by previous commits.
@abrown
Copy link
Contributor

abrown commented Oct 14, 2024

I think 51930fa and da1f467 have unintentionally updated some crates. Can you remove those commits?

@rahulchaphalkar
Copy link
Contributor Author

I rolled back 2 commits, but there's an issue with the lock file. I had previously attempted to fix this by deleting my lockfile, rebasing off of latest main, and then doing a cargo build to generate any of my lock file changes on top of that.

@rahulchaphalkar
Copy link
Contributor Author

The failing tests fail due to

error: failed retrieving file 'mingw-w64-x86_64-headers-git-12.0.0.r329.g8f7b5ce36-1-any.pkg.tar.zst.sig' from mirror.clarkson.edu : Operation too slow. Less than 1 bytes/sec transferred the last 10 seconds
 mingw-w64-x86_64-gcc-14.2.0-1-any downloading...
error: failed retrieving file 'mingw-w64-x86_64-mpfr-4.2.1-2-any.pkg.tar.zst' from mirror.msys2.org : Operation too slow. Less than 1 bytes/sec transferred the last 10 seconds

I'm hoping a rerun of CI would help.

@rahulchaphalkar
Copy link
Contributor Author

@abrown can you take a look

@abrown abrown added this pull request to the merge queue Oct 21, 2024
Merged via the queue into bytecodealliance:main with commit a5412ca Oct 21, 2024
127 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants