Skip to content

Commit

Permalink
Add mini web demo
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilaravi committed Apr 11, 2023
1 parent 7fa17d7 commit 426edb2
Show file tree
Hide file tree
Showing 23 changed files with 766 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@ model_ts*.txt
.idea
.vscode
_darcs

# demo
**/node_modules
yarn.lock
package-lock.json
Binary file added assets/minidemo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
95 changes: 95 additions & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
## Segment Anything Simple Web demo

This **front-end only** demo shows how to load a fixed image and `.npy` file of the SAM image embedding, and run the SAM ONNX model in the browser using Web Assembly with mulithreading enabled by `SharedArrayBuffer`, Web Worker, and SIMD128.

<img src="https://github.com/facebookresearch/segment-anything/raw/main/assets/minidemo.gif" width="500"/>

## Run the app

```
yarn && yarn start
```

Navigate to [`http://localhost:8081/`](http://localhost:8081/)

Move your cursor around to see the mask prediction update in real time.

## Change the image, embedding and ONNX model

In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb) upload the image of your choice and generate and save corresponding embedding.

Initialize the predictor

```python
checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device='cuda')
predictor = SamPredictor(sam)
```

Set the new image and export the embedding

```
image = cv2.imread('src/assets/dogs.jpg')
predictor.set_image(image)
image_embedding = predictor.get_image_embedding().cpu().numpy()
np.save("dogs_embedding.npy", image_embedding)
```

Save the new image and embedding in `/assets/data`and update the following paths to the files at the top of`App.tsx`:

```py
const IMAGE_PATH = "/assets/data/dogs.jpg";
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
```

Optionally you can also export the ONNX model. Currently the example ONNX model from the notebook is saved at `/model/sam_onnx_quantized_example.onnx`.

**NOTE: if you change the ONNX model by using a new checkpoint you need to also re-export the embedding.**

## ONNX multithreading with SharedArrayBuffer

To use multithreading, the appropriate headers need to be set to create a cross origin isolation state which will enable use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details)

The headers below are set in `configs/webpack/dev.js`:

```js
headers: {
"Cross-Origin-Opener-Policy": "same-origin",
"Cross-Origin-Embedder-Policy": "credentialless",
}
```

## Structure of the app

**`App.tsx`**

- Initializes ONNX model
- Loads image embedding and image
- Runs the ONNX model based on input prompts

**`Stage.tsx`**

- Handles mouse move interaction to update the ONNX model prompt

**`Tool.tsx`**

- Renders the image and the mask prediction

**`helpers/maskUtils.tsx`**

- Conversion of ONNX model output from array to an HTMLImageElement

**`helpers/onnxModelAPI.tsx`**

- Formats the inputs for the ONNX model

**`helpers/scaleHelper.tsx`**

- Handles image scaling logic for SAM (longest size 1024)

**`hooks/`**

- Handle shared state for the app
78 changes: 78 additions & 0 deletions demo/configs/webpack/common.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
const { resolve } = require("path");
const HtmlWebpackPlugin = require("html-webpack-plugin");
const FriendlyErrorsWebpackPlugin = require("friendly-errors-webpack-plugin");
const CopyPlugin = require("copy-webpack-plugin");
const webpack = require("webpack");

module.exports = {
entry: "./src/index.tsx",
resolve: {
extensions: [".js", ".jsx", ".ts", ".tsx"],
},
output: {
path: resolve(__dirname, "dist"),
},
module: {
rules: [
{
test: /\.mjs$/,
include: /node_modules/,
type: "javascript/auto",
resolve: {
fullySpecified: false,
},
},
{
test: [/\.jsx?$/, /\.tsx?$/],
use: ["ts-loader"],
exclude: /node_modules/,
},
{
test: /\.css$/,
use: ["style-loader", "css-loader"],
},
{
test: /\.(scss|sass)$/,
use: ["style-loader", "css-loader", "postcss-loader"],
},
{
test: /\.(jpe?g|png|gif|svg)$/i,
use: [
"file-loader?hash=sha512&digest=hex&name=img/[contenthash].[ext]",
"image-webpack-loader?bypassOnDebug&optipng.optimizationLevel=7&gifsicle.interlaced=false",
],
},
{
test: /\.(woff|woff2|ttf)$/,
use: {
loader: "url-loader",
},
},
],
},
plugins: [
new CopyPlugin({
patterns: [
{
from: "node_modules/onnxruntime-web/dist/*.wasm",
to: "[name][ext]",
},
{
from: "model",
to: "model",
},
{
from: "src/assets",
to: "assets",
},
],
}),
new HtmlWebpackPlugin({
template: "./src/assets/index.html",
}),
new FriendlyErrorsWebpackPlugin(),
new webpack.ProvidePlugin({
process: "process/browser",
}),
],
};
19 changes: 19 additions & 0 deletions demo/configs/webpack/dev.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// development config
const { merge } = require("webpack-merge");
const commonConfig = require("./common");

module.exports = merge(commonConfig, {
mode: "development",
devServer: {
hot: true, // enable HMR on the server
open: true,
// These headers enable the cross origin isolation state
// needed to enable use of SharedArrayBuffer for ONNX
// multithreading.
headers: {
"Cross-Origin-Opener-Policy": "same-origin",
"Cross-Origin-Embedder-Policy": "credentialless",
},
},
devtool: "cheap-module-source-map",
});
16 changes: 16 additions & 0 deletions demo/configs/webpack/prod.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// production config
const { merge } = require("webpack-merge");
const { resolve } = require("path");
const Dotenv = require("dotenv-webpack");
const commonConfig = require("./common");

module.exports = merge(commonConfig, {
mode: "production",
output: {
filename: "js/bundle.[contenthash].min.js",
path: resolve(__dirname, "../../dist"),
publicPath: "/",
},
devtool: "source-map",
plugins: [new Dotenv()],
});
64 changes: 64 additions & 0 deletions demo/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{
"name": "se-demo",
"version": "0.1.0",
"license": "MIT",
"scripts": {
"build": "yarn run clean-dist && webpack --config=configs/webpack/prod.js && mv dist/*.wasm dist/js && cp -R dataset dist",
"clean-dist": "rimraf dist/*",
"lint": "eslint './src/**/*.{js,ts,tsx}' --quiet",
"start": "yarn run start-dev",
"test": "yarn run start-model-test",
"start-dev": "webpack serve --config=configs/webpack/dev.js"
},
"devDependencies": {
"@babel/core": "^7.18.13",
"@babel/preset-env": "^7.18.10",
"@babel/preset-react": "^7.18.6",
"@babel/preset-typescript": "^7.18.6",
"@pmmmwh/react-refresh-webpack-plugin": "^0.5.7",
"@testing-library/react": "^13.3.0",
"@types/node": "^18.7.13",
"@types/react": "^18.0.17",
"@types/react-dom": "^18.0.6",
"@types/underscore": "^1.11.4",
"@typescript-eslint/eslint-plugin": "^5.35.1",
"@typescript-eslint/parser": "^5.35.1",
"babel-loader": "^8.2.5",
"copy-webpack-plugin": "^11.0.0",
"css-loader": "^6.7.1",
"dotenv": "^16.0.2",
"dotenv-webpack": "^8.0.1",
"eslint": "^8.22.0",
"eslint-plugin-react": "^7.31.0",
"file-loader": "^6.2.0",
"fork-ts-checker-webpack-plugin": "^7.2.13",
"friendly-errors-webpack-plugin": "^1.7.0",
"html-webpack-plugin": "^5.5.0",
"image-webpack-loader": "^8.1.0",
"postcss-loader": "^7.0.1",
"postcss-preset-env": "^7.8.0",
"process": "^0.11.10",
"rimraf": "^3.0.2",
"sass": "^1.54.5",
"sass-loader": "^13.0.2",
"style-loader": "^3.3.1",
"tailwindcss": "^3.1.8",
"ts-loader": "^9.3.1",
"typescript": "^4.8.2",
"webpack": "^5.74.0",
"webpack-cli": "^4.10.0",
"webpack-dev-server": "^4.10.0",
"webpack-dotenv-plugin": "^2.1.0",
"webpack-merge": "^5.8.0"
},
"dependencies": {
"konva": "^8.3.12",
"npyjs": "^0.4.0",
"onnxruntime-web": "^1.14.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-konva": "^18.2.1",
"underscore": "^1.13.6",
"react-refresh": "^0.14.0"
}
}
4 changes: 4 additions & 0 deletions demo/postcss.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
const tailwindcss = require("tailwindcss");
module.exports = {
plugins: ["postcss-preset-env", 'tailwindcss/nesting', tailwindcss],
};
Loading

0 comments on commit 426edb2

Please sign in to comment.