[WASM] Implement concat embeddings #17404

Merged 2 commits on Sep 23, 2024
src/target/source/codegen_webgpu.cc (1 change: 1 addition, 0 deletions)
@@ -125,6 +125,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re
name_supply_->ReserveName("var");
name_supply_->ReserveName("let");
name_supply_->ReserveName("const");
name_supply_->ReserveName("std");

// skip the first underscore, so SSA variable starts from
name_supply_->FreshName("v_");
web/emcc/wasm_runtime.cc (46 changes: 46 additions, 0 deletions)
@@ -173,5 +173,51 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat").set_body([](TVMArgs args, TVMRe
}
*ret = Array<ObjectRef>(data);
});

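// Concatenate a list of 2-D embeddings, each of shape (m_i, hidden_size),
// into one (sum_i m_i, hidden_size) NDArray on the same device.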
NDArray ConcatEmbeddings(const std::vector<NDArray>& embeddings) {
// Get output shape
int64_t hidden_size = embeddings[0]->shape[1];
DLDataType dtype = embeddings[0]->dtype;
DLDevice device = embeddings[0]->device;
int64_t seqLen = 0;
for (size_t i = 0; i < embeddings.size(); ++i) {
ICHECK_EQ(embeddings[i]->ndim, 2);
ICHECK_EQ(embeddings[i]->shape[1], hidden_size);
seqLen += embeddings[i]->shape[0];
}

// Create output
std::vector<int64_t> shape;
shape.push_back(seqLen);
shape.push_back(hidden_size);
NDArray result = NDArray::Empty(shape, dtype, device);

// Copy each input into the output by pointing a DLTensor view at the
// destination rows: source shape, byte_offset advanced by the running row count.
int64_t offset = 0;
for (size_t i = 0; i < embeddings.size(); ++i) {
const DLTensor& copy_src = *(embeddings[i].operator->());
const DLTensor* p_copy_dst = result.operator->();
DLTensor copy_dst = *p_copy_dst;
copy_dst.shape = embeddings[i]->shape;
copy_dst.byte_offset =
offset * hidden_size * ((embeddings[i]->dtype.bits * embeddings[i]->dtype.lanes + 7) / 8);
NDArray::CopyFromTo(&copy_src, &copy_dst);
offset += embeddings[i]->shape[0];
}

return result;
}

// Concatenate n NDArrays
TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings").set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<NDArray> embeddings;
for (int i = 0; i < args.size(); ++i) {
ICHECK_EQ(args[i].type_code(), kTVMNDArrayHandle);
embeddings.push_back(args[i]);
}
NDArray result = ConcatEmbeddings(std::move(embeddings));
*ret = result;
});

} // namespace runtime
} // namespace tvm
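
For reference, a minimal sketch of how this registered global can be fetched and called from the JS side; the instance `tvm` and the NDArrays `a` and `b` are assumptions, not part of this change:

```ts
// Sketch only: assumes `tvm` is an initialized tvmjs Instance and `a`, `b`
// are 2-D NDArrays with matching hidden_size on the same device.
const concat = tvm.getGlobalFunc("tvmjs.runtime.ConcatEmbeddings");
const joined = concat(a, b); // shape: (a.shape[0] + b.shape[0], hidden_size)
concat.dispose(); // PackedFuncs obtained this way must be disposed explicitly
```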
web/src/runtime.ts (38 changes: 37 additions, 1 deletion)
@@ -174,6 +174,7 @@ class RuntimeContext implements Disposable {
applyRepetitionPenalty: PackedFunc;
applyPresenceAndFrequencyPenalty: PackedFunc;
applySoftmaxWithTemperature: PackedFunc;
concatEmbeddings: PackedFunc | undefined;

private autoDisposeScope: Array<Array<Disposable | undefined>> = [];

@@ -199,6 +200,11 @@
this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty");
this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty");
this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
try {
this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings");
} catch {
// TODO: remove soon. Older artifacts do not register this global; the
// try-catch keeps backward compatibility with them.
}
}

dispose(): void {
@@ -223,6 +229,7 @@
this.applyRepetitionPenalty.dispose();
this.applyPresenceAndFrequencyPenalty.dispose();
this.applySoftmaxWithTemperature.dispose();
this.concatEmbeddings?.dispose();
}

beginScope(): void {
@@ -575,7 +582,10 @@ export class NDArray implements Disposable {
* @param data The source data array.
* @returns this
*/
copyFrom(data: NDArray | Array<number> | Float32Array): this {
copyFrom(
data: NDArray | Array<number> | Float32Array | Float64Array |
Int32Array | Int8Array | Uint8Array | Uint8ClampedArray
): this {
if (data instanceof NDArray) {
this.lib.checkCall(
(this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)(
@@ -608,6 +618,8 @@
buffer = Int8Array.from(data).buffer;
} else if (this.dtype === "uint8") {
buffer = Uint8Array.from(data).buffer;
} else if (this.dtype === "uint32") {
buffer = Uint32Array.from(data).buffer;
} else {
throw new Error("Unsupported data type " + this.dtype);
}
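
A usage sketch for the widened dtype dispatch; the instance name and device are assumptions. Note that plain `Array<number>` data is accepted and routed through `Uint32Array.from` for the new "uint32" branch:

```ts
// Sketch: copy host data into a uint32 NDArray via the extended copyFrom.
// Assumes `tvm` is an initialized tvmjs Instance.
const ids = tvm.empty([4], "uint32", tvm.cpu());
ids.copyFrom([101, 7592, 2088, 102]); // dispatched through Uint32Array.from
```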
@@ -1906,6 +1918,30 @@ export class Instance implements Disposable {
return this.ctx.arrayConcat(...listOfArrays) as TVMArray;
}

/**
* Join a sequence of NDArrays that represent embeddings.
* @param embeddings A list of embeddings as NDArrays; array i has shape (m_i, hidden_size).
* @returns An NDArray of shape (\sum_i m_i, hidden_size).
*/
concatEmbeddings(embeddings: Array<NDArray>): NDArray {
// 1. Check shape validity
const hidden_size = embeddings[0].shape[1];
embeddings.forEach((input) => {
if (input.shape.length !== 2 || input.shape[1] !== hidden_size) {
throw new Error("Expected each embedding to have shape (m_i, hidden_size).");
}
});

// 2. Call global func
if (this.ctx.concatEmbeddings === undefined) {
throw new Error(
"Global function tvmjs.runtime.ConcatEmbeddings was not found, " +
"but concatEmbeddings was called."
);
}
return this.ctx.concatEmbeddings(...embeddings) as NDArray;
}

/**
* Create a {@link TVMString} that can be consumed by runtime.
*
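
Finally, a usage sketch for the new Instance.concatEmbeddings wrapper; the variable names and shapes are assumptions:

```ts
// Sketch: join two embeddings of shape (3, 768) and (5, 768) into (8, 768).
// Assumes `tvm` is an initialized tvmjs Instance and e1, e2 live on its device.
const out = tvm.concatEmbeddings([e1, e2]);
// out.shape -> [8, 768]
```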