diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 83079a9f0756..1d1df91dc4a4 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -125,6 +125,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re name_supply_->ReserveName("var"); name_supply_->ReserveName("let"); name_supply_->ReserveName("const"); + name_supply_->ReserveName("std"); // skip the first underscore, so SSA variable starts from name_supply_->FreshName("v_"); diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc index 2f7135595843..9744750b80db 100644 --- a/web/emcc/wasm_runtime.cc +++ b/web/emcc/wasm_runtime.cc @@ -173,5 +173,51 @@ TVM_REGISTER_GLOBAL("tvmjs.runtime.ArrayConcat").set_body([](TVMArgs args, TVMRe } *ret = Array<ObjectRef>(data); }); + +NDArray ConcatEmbeddings(const std::vector<NDArray>& embeddings) { + // Get output shape + int64_t hidden_size = embeddings[0]->shape[1]; + DLDataType dtype = embeddings[0]->dtype; + DLDevice device = embeddings[0]->device; + int seqLen = 0; + for (int i = 0; i < embeddings.size(); ++i) { + ICHECK_EQ(embeddings[i]->ndim, 2); + ICHECK_EQ(embeddings[i]->shape[1], hidden_size); + seqLen += embeddings[i]->shape[0]; + } + + // Create output + std::vector<int64_t> shape; + shape.push_back(seqLen); + shape.push_back(hidden_size); + NDArray result = NDArray::Empty(shape, dtype, device); + + // Copy + int offset = 0; + for (int i = 0; i < embeddings.size(); i++) { + const DLTensor& copy_src = *(embeddings[i].operator->()); + const DLTensor* p_copy_dst = result.operator->(); + DLTensor copy_dst = *p_copy_dst; + copy_dst.shape = embeddings[i]->shape; + copy_dst.byte_offset = + offset * hidden_size * ((embeddings[i]->dtype.bits * embeddings[i]->dtype.lanes + 7) / 8); + NDArray::CopyFromTo(&copy_src, &copy_dst); + offset += embeddings[i]->shape[0]; + } + + return result; +} + +// Concatenate n NDArrays +TVM_REGISTER_GLOBAL("tvmjs.runtime.ConcatEmbeddings").set_body([](TVMArgs 
args, TVMRetValue* ret) { + std::vector<NDArray> embeddings; + for (int i = 0; i < args.size(); ++i) { + ICHECK_EQ(args[i].type_code(), kTVMNDArrayHandle); + embeddings.push_back(args[i]); + } + NDArray result = ConcatEmbeddings(std::move(embeddings)); + *ret = result; +}); + } // namespace runtime } // namespace tvm diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 600a9b857f03..8546cab773ff 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -174,6 +174,7 @@ class RuntimeContext implements Disposable { applyRepetitionPenalty: PackedFunc; applyPresenceAndFrequencyPenalty: PackedFunc; applySoftmaxWithTemperature: PackedFunc; + concatEmbeddings: PackedFunc | undefined; private autoDisposeScope: Array<Array<Disposable | undefined>> = []; @@ -199,6 +200,11 @@ class RuntimeContext implements Disposable { this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty"); this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty"); this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature"); + try { + this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings"); + } catch { + // TODO: remove soon. Older artifacts do not have this, try-catch for backward compatibility. + } } dispose(): void { @@ -223,6 +229,7 @@ class RuntimeContext implements Disposable { this.applyRepetitionPenalty.dispose(); this.applyPresenceAndFrequencyPenalty.dispose(); this.applySoftmaxWithTemperature.dispose(); + this.concatEmbeddings?.dispose(); } beginScope(): void { @@ -575,7 +582,10 @@ export class NDArray implements Disposable { * @param data The source data array. 
* @returns this */ - copyFrom(data: NDArray | Array<number> | Float32Array): this { + copyFrom( + data: NDArray | Array<number> | Float32Array | Float64Array | + Int32Array | Int8Array | Uint8Array | Uint8ClampedArray + ): this { if (data instanceof NDArray) { this.lib.checkCall( (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)( @@ -608,6 +618,8 @@ export class NDArray implements Disposable { buffer = Int8Array.from(data).buffer; } else if (this.dtype === "int8") { buffer = Int8Array.from(data).buffer; } else if (this.dtype === "uint8") { buffer = Uint8Array.from(data).buffer; + } else if (this.dtype === "uint32") { + buffer = Uint32Array.from(data).buffer; } else { throw new Error("Unsupported data type " + this.dtype); } @@ -1906,6 +1918,30 @@ export class Instance implements Disposable { return this.ctx.arrayConcat(...listOfArrays) as TVMArray; } + /** + * Join a sequence of NDArrays that represent embeddings. + * @param embeddings A list of embeddings in NDArrays, each array i has shape (m_i, hidden_size). + * @returns An NDArray of shape (\sum_{i} {m}, hidden_size) + */ + concatEmbeddings(embeddings: Array<NDArray>): NDArray { + // 1. Check shape validity + const hidden_size = embeddings[0].shape[1]; + embeddings.forEach((input) => { + if (input.shape.length !== 2 || input.shape[1] !== hidden_size) { + throw new Error("Expect embeddings to concatenate have shape (m_i, hidden_size)."); + } + }) + + // 2. Call global func + if (this.ctx.concatEmbeddings === undefined) { + throw new Error( + "Global function tvmjs.runtime.ConcatEmbeddings was " + + "not found, but called concatEmbeddings." + ); + } + return this.ctx.concatEmbeddings(...embeddings) as NDArray; + } + /** + * Create a {@link TVMString} that can be consumed by runtime. *