Sync internal repo to external Dec 12 2024 #103

Status: Open. Wants to merge 139 commits into base: main.

Commits (139)
2e807a6
[generation] Fix on-device generation without filtering
aws-patlange Apr 24, 2024
3d965f6
[sos] multi-bucket and bsh attn layout support
May 10, 2024
1e0bbe1
[generation_demo] fixes, add flags
May 15, 2024
d7725a7
[sos] seq mismatch fix for testing
May 16, 2024
18f5eab
sequence parallel with fully unrolled and on_device embedding
aws-amerrez May 16, 2024
f080045
[compiler] Enabled writing HLO file metadata when dumping is configured
jluntamazon May 15, 2024
bda234f
[hlo] switch to the fast path cumsum NKI kernel for floats
hannanjgaws May 17, 2024
4365742
[speculation] Fix 2D cache_id pad for speculative_forward
May 17, 2024
d578c19
Enabling output bucketing for continuous batching with SBH cache layout.
May 19, 2024
6b7b980
[config/attention/llama] Transposed attention output projection weight
May 22, 2024
0fa7550
make weight tiling permute order configurable
aws-bowencc May 23, 2024
f9bc7c7
make weight tiling permute order configurable
aws-bowencc May 23, 2024
6e2b374
[attention] Add support for masking exp values in context
hannanjgaws May 24, 2024
c038af9
[generation_demo] Add naive continuous batching
aws-patlange May 17, 2024
008ed9a
switch back to gelu_new_legacy
gevadyla May 29, 2024
f7316ff
Enable continuous batching for multi-layer models
aws-amulyaab May 16, 2024
78e39e4
Fix on-device generation during continuous batching
aws-patlange Apr 21, 2024
695b5d6
Fix on-device embedding and continuous batching
aws-patlange Apr 21, 2024
3913740
Use collectives_layout=BSH as default and add deprecation warning
aws-patlange May 23, 2024
478714e
Remove internal code name
akhil-aws May 29, 2024
4a0f913
stop criteria fix to cover generated tokens
aws-amerrez May 30, 2024
cca53a1
[attn] Fix NKI kernel import in attention utils.
May 30, 2024
8ad9cc4
Revert "stop criteria fix to cover generated tokens"
aws-patlange May 30, 2024
c0c54cc
Revert "[generation_demo] Add naive continuous batching"
aws-patlange May 30, 2024
1ba1ba7
[sos] node interleave qkv support
May 30, 2024
86a6dd9
Revert "Use collectives_layout=BSH as default and add deprecation war…
aws-patlange May 31, 2024
e9a0159
Allow saving and loading of presharded weights
aws-yishanm May 3, 2024
7fb318c
Extend sequence parallel normalization support in the MLP and LM Head…
hannanjgaws Jun 4, 2024
ef6b434
[sos] unit_2 bug fix
Jun 4, 2024
4f829a0
[decoder] Fix string comparison for 'None' params
hannanjgaws Jun 4, 2024
7b71068
Do not use flash attention for batch > 1 left padded
May 30, 2024
df848fd
[generation_demo] fix generation_demo merge errors
May 24, 2024
9005805
Add sliding window masks
Jun 7, 2024
ad4cfb5
update nki kernel import path
aws-jiyoua May 29, 2024
a53eb59
Add support for per batch-line dynamic sampling
aws-aymahg Jun 11, 2024
b3d29e0
[transformer] Do not AllGather lm head output when on-device generati…
hannanjgaws Jun 12, 2024
7c828a5
initial support for concatenated prompt encoding
liangfu Jun 14, 2024
c403ae6
add block-wise attention mask for paged attention
liangfu Jun 17, 2024
ae3a4ff
Add token tree utils for speculation
sssrijan-amazon Apr 26, 2024
ab77fea
Loading token tree attention as part of model weights
sssrijan-amazon May 22, 2024
e14802f
Support on device sampling with token tree attention
sssrijan-amazon May 24, 2024
bead732
Cache reordering with Tree speculative generator
sssrijan-amazon Jun 13, 2024
99ba588
Wrap logits decoding in a function for timing
aws-aymahg Jun 17, 2024
2533b34
fix on-device embedding via cast after gather
liangfu Jun 21, 2024
da61bd7
Enabling custom RMSNorm across the board, and support sequence parall…
CptTZ Jun 17, 2024
f0fada0
add CB unit test and assertion for concat prompt encoding
liangfu Jun 18, 2024
d22daad
fix mixtral model CB support
liangfu Jun 19, 2024
49d2fef
Vectorize input padding for token gen with continuous batching
aws-patlange Jun 7, 2024
7d38741
Fix logits shape for CB with VLLM
aws-aymahg Jun 24, 2024
b17d742
[llama] Generalize forward to allow embeddings as input
Jun 19, 2024
90e51ea
initial paged attention support via gather all blocks
liangfu Jun 27, 2024
5161bbe
Initial commit for fused speculation
Jun 11, 2024
9bdbc8b
fixed missing logic for updating draft cache
Jun 20, 2024
284b0a5
Add assertions for input k
Jun 28, 2024
a399b0f
[llama3][optimization] fused_rmsnorm_qkv flag to enable fused kernel …
aws-shchung Jun 26, 2024
faaddef
Add fp8 kv cache quant support
aws-amulyaab Jun 25, 2024
192c101
[Speculation] Adding batch_size>1 support
aws-amulyaab Mar 15, 2024
e44938d
make mlp down proj configurable
aws-bowencc Jun 20, 2024
fda0c6c
[generation_demo] Add shard_over_sequence config
Jul 8, 2024
f89cbf9
Enable right padding for fused speculation
Jul 15, 2024
f36f504
[llama3][optimization] fused_rmsnorm_qkv flag, fixing merge error wit…
aws-shchung Jul 3, 2024
60eb797
Add continuous batching to generation demo
aws-amulyaab Jul 16, 2024
b18c025
fix tree spec test failure in pipe
aws-yishanm Jul 16, 2024
3bf4137
Use smaller all-gather in LMHead when using sequence parallel Norm
Jul 16, 2024
b57b51d
[sos] 2D cache_id masking support
Jul 17, 2024
db313c6
commit: be57ec6d2956c8d14dcbe515306afea9ff3099b9
gevadyla Jul 18, 2024
1c17a21
llama fuse mlp support
Jul 18, 2024
93d642c
Add scaled rope support
sssrijan-amazon Jul 15, 2024
8db726f
Revert "commit: be57ec6d2956c8d14dcbe515306afea9ff3099b9"
gevadyla Jul 19, 2024
c946284
Rope scaling fix
sssrijan-amazon Jul 19, 2024
467af9b
Fix sample_loop_llama
aws-amulyaab Jul 18, 2024
3565e88
initial support for optimized paged attention (without output bucketing)
liangfu Jul 19, 2024
fb12e3e
Adding output bucketing for paged attention blocks based on n_positions.
Jul 8, 2024
15a36c8
Rope scaling inline with transformers impl
sssrijan-amazon Jul 23, 2024
b75a4de
Fix on_device_sampling to work with Continuous Batching
aws-amulyaab Jul 23, 2024
ec1a5d1
Add args to generation demo
aws-amulyaab Jul 24, 2024
5eccbbc
Fix KV head full-replication when tp is not a multiple of num_kv_heads
desudit Jul 23, 2024
6dc4305
add block-wise from bottom-right mask
liangfu Jul 30, 2024
67fa88c
Fix KV head full replication logic
desudit Jul 30, 2024
c8cc7c3
enable continuous batching iteration with fused speculative decoding
Jul 31, 2024
64df0b3
Fix sequence parallel bug
Aug 2, 2024
125ddfe
fix cache selection logic for multi prompt with speculative decoding
Aug 3, 2024
39fd37f
make deterministic_threshold configurable for fused speculation
Aug 5, 2024
a131c7c
fix tree speculation failure
aws-yishanm Aug 6, 2024
5ef6014
setting execute_repetition to default 1 or num layers when not fully …
aws-amerrez Aug 7, 2024
5583909
Fix the handling for finished seqs as part of static batching in fuse…
sssrijan-amazon Aug 5, 2024
0603c61
commit: be57ec6d2956c8d14dcbe515306afea9ff3099b9
gevadyla Jul 18, 2024
43907f5
config dump to subdirs and perform dump at beginning of compilation
gevadyla Jul 31, 2024
09c155f
Add invalid token_id fix to hlo.multinomial
aws-aymahg Aug 8, 2024
2613961
Handle seq_ids consistently across vLLM versions
aws-patlange Aug 5, 2024
7d601ed
Add support for multiple eos_token_ids
Aug 12, 2024
af3aca7
update mixtral/model.py, retrieve sliding window start location from …
mike-zhang-aws Jun 12, 2024
5b6187a
output logits from speculative sample
gevadyla Aug 8, 2024
7b5514a
update speculative sample docs
gevadyla Aug 12, 2024
61ded08
Adding FP8 weight quantization support
Aug 13, 2024
e6ca307
add padding when request batch size is smaller than neff size
Aug 13, 2024
63e4e54
Set execute-repetition flag properly for both CE and TG
Aug 14, 2024
c797eab
[Llama3][BIR native kernels] MLP and QKV BIR kernels for llama
aws-shchung Aug 14, 2024
c0564b4
[Llama3] fixing number of arguments in token_tree_layer
aws-shchung Aug 19, 2024
ef67517
Flash decoding support for speculation
Aug 22, 2024
cb5659a
Add on-device generation support in spec forward
Aug 22, 2024
e74e49e
Enable on-device generation in spec_forward
Aug 22, 2024
9c107d1
Add update generation config support for fused speculation
sssrijan-amazon Aug 22, 2024
169db0f
[continuous-batching] Only _postprocess during continuous batching to…
aws-patlange Aug 22, 2024
2caf530
Fix sequence parallel norm bug when executor + speculation/windowed CE
Aug 29, 2024
a3d562f
Flash decoding support for speculation
Aug 29, 2024
50daee6
Parallel Vocab support
Aug 29, 2024
76376f8
Update version to 0.13
hannanjgaws Aug 28, 2024
171bd38
[Sequence Parallel][Kernels] Adding MLP BIR kernel support for sequen…
aws-shchung Sep 6, 2024
a029427
Parallel Vocab + sequence parallel in llama
Sep 10, 2024
923e8c1
Add Eagle spec decoding support
Sep 11, 2024
80a3e12
Add greedy sampling for eagle spec decoding
Sep 12, 2024
3b56f7a
merge with mainline
Sep 17, 2024
3bdbb60
bucket changes for shard over seq with fused speculation
Sep 18, 2024
e05fc47
fix greedy generation
Sep 20, 2024
9d27625
make sure we have the right masking first
Sep 20, 2024
916e0cb
Fix eagle accuracy issue
Sep 21, 2024
e03976c
Chunked Prefill
liangfu Sep 24, 2024
4248af8
add option to skip allgather with duplicate q weights during sos
aws-bowencc Sep 25, 2024
b3c588e
Fix Eagle SD error due to additional inputs from chunked prefill
Sep 28, 2024
68007ff
Merge branch 'jf-sos-chunked-pa' into mainline
Oct 1, 2024
8b9112a
enable flashattention-style comm, duplicate query, topoaware-comm and
Oct 8, 2024
aad25fe
Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransfo…
Oct 8, 2024
e861b55
reorder argument input for shraded_kv_indexing
Oct 8, 2024
c4d0b2e
fix softmax, topo-aware sharding for llama, optimize pre-layer mask g…
Oct 9, 2024
5aad935
remove topo-aware for shard-over-sequence
Oct 15, 2024
67d762b
Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransfo…
Oct 15, 2024
4df0cb0
Merge branch 'mainline' of ssh://git.amazon.com:2222/pkg/KaenaTransfo…
Oct 15, 2024
dbeb8b8
slice the kv cache correctly when using sos
aws-bowencc Oct 21, 2024
2214b99
Fix SD token selection logic
Oct 26, 2024
feef47c
fix eagle sd on corner cases when input length is 1 or 2
aws-bowencc Oct 28, 2024
5d0c91e
Use f32 rng for SD token selection
Oct 28, 2024
4a8eefa
TP 128 draft fix
Oct 28, 2024
cedd4e7
[generation] Fix per_batch_line sampling param lookup for CB
aws-patlange Oct 29, 2024
76c569d
[generation] Enforce desired dtype for on_device_generation sampling …
aws-patlange Oct 30, 2024
38027d4
fix duplicate q when q/tp!=1
aws-bowencc Oct 31, 2024
d2e795b
[load_weights] Nullify materialized parameters
aws-patlange Nov 6, 2024
a391dd9
Support bf16 rms norm and new neuron config bf16_rms_norm=True to ena…
Nov 15, 2024
997bf9d
enforce max context_length_estimate <= max n_positions
aws-yishanm Nov 27, 2024
334 changes: 280 additions & 54 deletions src/transformers_neuronx/base.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/transformers_neuronx/bloom/config.py
@@ -44,3 +44,4 @@ def __init__(
self.batch_size = batch_size
self.amp = amp
self.tp_degree = tp_degree
self.model_type = 'bloom'
7 changes: 4 additions & 3 deletions src/transformers_neuronx/bloom/hlo.py
@@ -30,7 +30,7 @@ def inputs(self, scribe, dtype, n_active_tokens, batch_size):
)
return tensors, dims

def embedding(self, input_ids, cache_ids, start_ids, last_token_id, slopes, word_embeddings, ln_weight, ln_bias):
def embedding(self, input_ids, cache_ids, start_ids, last_token_id, block_tables, context_lens, slopes, word_embeddings, ln_weight, ln_bias):
dtype = getattr(input_ids.scribe, self.config.amp)
hidden = hlo.embedding(word_embeddings, input_ids, tp_degree=self.config.tp_degree, dtype=dtype)
if self.config.hidden_size % self.config.tp_degree != 0:
@@ -41,9 +41,10 @@ def embedding(self, input_ids, cache_ids, start_ids, last_token_id, slopes, word
return hlo.layer_norm_bsh(hidden, ln_weight, ln_bias) if is_bsh \
else hlo.layer_norm(hidden, ln_weight, ln_bias)

def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, *pre_layer_weights):
def pre_layer(self, hidden, cache_ids, start_ids, last_token_id, block_tables, context_lens, *pre_layer_weights):
slopes, *rest = pre_layer_weights
mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions)
mask, active_mask = hlo.attention_mask(cache_ids, start_ids, self.n_positions,
last_token_id=last_token_id, neuron_config=self.neuron_config)
prior_alibi, active_alibi = alibi.alibi(slopes, mask, active_mask)
return hidden, last_token_id, cache_ids, mask, active_mask, prior_alibi, active_alibi

27 changes: 15 additions & 12 deletions src/transformers_neuronx/bloom/model.py
@@ -12,15 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch
import math
import warnings

from transformers_neuronx import decoder
from transformers_neuronx import module
from transformers_neuronx import ops
from transformers_neuronx import sampling
from transformers_neuronx import utils
from transformers_neuronx import bucket
from transformers_neuronx import base
from transformers_neuronx.constants import LAYOUT_BSH, LAYOUT_HSB
@@ -67,11 +60,7 @@ def __init__(self, config, *, n_positions=2048, batch_size=1, amp='f32', tp_degr
self.decoder_lm_head = self.decoder_param_set.init_token_decoder(unroll=self.unroll, buckets=self.token_buckets, model_obj=self)

def load_weights(self):
# Materialize the embedding to CPU
self.chkpt_model.transformer.word_embeddings.materialize()
self.chkpt_model.transformer.word_embeddings_layernorm.materialize()

ops.init()
self.materialize_embeddings()

n_head = self.config.n_head
hidden_size = self.config.hidden_size
@@ -142,6 +131,7 @@ def load_weights(self):
ln_f = self.chkpt_model.transformer.ln_f
ln_f.materialize()
self.decoder_lm_head.add_final_layer_norm(ln_f.weight.detach(), ln_f.bias.detach())
ln_f.nullify()

lm_head = self.chkpt_model.lm_head
lm_head.materialize()
@@ -154,7 +144,20 @@ def load_weights(self):
self.decoder_lm_head.add_pre_layer_parameter(self.chkpt_model.transformer.word_embeddings_layernorm.weight)
self.decoder_lm_head.add_pre_layer_parameter(self.chkpt_model.transformer.word_embeddings_layernorm.bias)
self.decoder_lm_head.to_neuron()
self.init_rest_of_model()
self.maybe_nullify_embeddings()

def materialize_embeddings(self):
# Materialize the embedding to CPU
self.chkpt_model.transformer.word_embeddings.materialize()
self.chkpt_model.transformer.word_embeddings_layernorm.materialize()

def maybe_nullify_embeddings(self):
if self.neuron_config.on_device_embedding:
self.chkpt_model.transformer.word_embeddings.nullify()
self.chkpt_model.transformer.word_embeddings_layernorm.nullify()

def init_rest_of_model(self):
if self.context_buckets:
for context_length_estimate in self.context_buckets:
for batch_size in self.batch_sizes:
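Reviewer note on the bloom/model.py change above: load_weights() now delegates to materialize_embeddings() before building the decoder and to maybe_nullify_embeddings() afterwards, so the host copies of the embedding tables can be dropped once they are no longer needed. A minimal sketch of that lifecycle, using a hypothetical FakeParam stand-in rather than the real lazily materialized parameters:

import torch

class FakeParam:
    """Hypothetical stand-in for a lazily materialized checkpoint parameter."""
    def __init__(self, shape):
        self.shape = shape
        self.tensor = None

    def materialize(self):
        # Load the weight into host memory (here: just allocate a tensor).
        self.tensor = torch.zeros(self.shape)

    def nullify(self):
        # Release the host copy once it is no longer needed.
        self.tensor = None

word_embeddings = FakeParam((1000, 64))
word_embeddings.materialize()            # materialize_embeddings()
# ... the host tensor would be consumed by decoder_lm_head.to_neuron() here ...
on_device_embedding = True               # assumed config flag
if on_device_embedding:
    word_embeddings.nullify()            # maybe_nullify_embeddings()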
69 changes: 52 additions & 17 deletions src/transformers_neuronx/compiler.py
@@ -14,17 +14,19 @@
# ==============================================================================
import os
import shlex
import shutil
import subprocess
import hashlib
import tarfile
import tempfile
from contextlib import contextmanager
import numpy as np
from contextlib import contextmanager, nullcontext
from textwrap import dedent
import torch
import logging
import json
import math

import numpy as np
import torch

from torch_neuronx.pyhlo import xla_data_pb2
from torch_neuronx.pyhlo.scribe import HloScribe
from torch_neuronx.pyhlo.constant.serialize_torch import serialize_torch
@@ -35,6 +37,7 @@
from libneuronxla.neuron_cc_cache import CacheUrl, create_compile_cache
from neuronxcc import __version__ as compiler_version


def get_hash_module(hlo_module, flags):
# Hashing is pretty fast and neglegible compared to compilation time
hash_gen = hashlib.sha256()
@@ -45,8 +48,29 @@ def get_hash_module(hlo_module, flags):
hash = str(hash_gen.hexdigest())[:20]
return hash


@contextmanager
def envvar(key, value):
prior = os.environ.pop(key, None)
if value is not None:
os.environ[key] = value
try:
yield
finally:
os.environ.pop(key, None)
if prior is not None:
os.environ[key] = prior


def compile_py_func(py_func):
return HloScribe(serialize_torch)(py_func).module_proto

# Adds file/scope metadata during debug dump
context = nullcontext()
if "NEURONX_DUMP_TO" in os.environ and 'ENABLE_PYHLO_FILE_METADATA' not in os.environ:
context = envvar('ENABLE_PYHLO_FILE_METADATA', '1')

with context:
return HloScribe(serialize_torch)(py_func).module_proto


def build_kernel(py_func, tp_degree):
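Reviewer note: the new envvar context manager above temporarily sets an environment variable and restores the prior state on exit; compile_py_func uses it to enable ENABLE_PYHLO_FILE_METADATA only while NEURONX_DUMP_TO is configured. A small usage sketch (the helper body is copied from the diff, the variable value is arbitrary):

import os
from contextlib import contextmanager

@contextmanager
def envvar(key, value):
    # Temporarily set (or leave unset) an environment variable, restoring the prior value on exit.
    prior = os.environ.pop(key, None)
    if value is not None:
        os.environ[key] = value
    try:
        yield
    finally:
        os.environ.pop(key, None)
        if prior is not None:
            os.environ[key] = prior

with envvar('ENABLE_PYHLO_FILE_METADATA', '1'):
    assert os.environ['ENABLE_PYHLO_FILE_METADATA'] == '1'
# Outside the block the variable is back to its previous state.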
@@ -81,9 +105,11 @@ def get_compiler_flags() -> str:
return ' '.join(flags)


def compile_hlo_module(hlo_module, tag=None):
def compile_hlo_module(hlo_module, tag=None, num_exec_repetition=1):

flags = get_compiler_flags()
flags = f'{flags} --execute-repetition={num_exec_repetition}'

module_flag_hash = get_hash_module(hlo_module, flags)
module_hash = get_hash_module(hlo_module, None)

@@ -97,10 +123,8 @@
hlo_module_name = f'{tag}-{hlo_module.name}.{compiler_version}.{module_flag_hash}'

if dump:


dump_to = os.environ.get('NEURONX_DUMP_TO', '/tmp')
dump_to = os.path.join(dump_to, hlo_module_name)
dump_to_parent = os.environ.get('NEURONX_DUMP_TO', '/tmp')
dump_to = os.path.join(dump_to_parent, hlo_module_name)
os.makedirs(dump_to, exist_ok=True)
hlo_module_path = os.path.join(dump_to, f'{hlo_module_name}.pb')
hlo_module_path = os.path.realpath(hlo_module_path)
@@ -115,6 +139,10 @@
subprocess.check_call(command_line, cwd=dump_to)
with open(neff_path, 'rb') as f:
neff_bytes = f.read()
try:
shutil.copyfile(os.path.join(dump_to_parent, 'neuron_model_config.json'), os.path.join(dump_to, 'neuron_model_config.json'))
except FileNotFoundError:
pass
else:
module_bytes = hlo_module.SerializeToString()
try:
@@ -201,7 +229,10 @@ def __init__(self):
F32 FLOAT float32
F64 DOUBLE float64
BF16 BFLOAT16 bfloat16
F8E4M3FN INT8 float8_e4m3fn
'''
# Note that for FP8 we map metaneff datatype to int8, since from the runtime perspective these datatypes are functionally equivalent (for fp8 storage only)
# Within Tnx, we no longer use the metaneff flow, so this would not matter anyway.
name_mapping = dedent(name_mapping)
name_mapping = name_mapping.lstrip().strip()
self.hlo2metaneff_mapping = {}
@@ -211,6 +242,8 @@
for line in name_mapping.split('\n'):
line = line.lstrip().strip()
pname, dname, tname = line.split()
if not hasattr(torch, tname):
continue
primitive_type = getattr(xla_data_pb2.PrimitiveType, pname)
metaneff_dtype = getattr(metaneff_pb2.MetaTensor.DataType, dname)
torch_dtype = getattr(torch, tname)
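Reviewer note: the hasattr(torch, tname) guard added above makes the dtype table tolerant of torch builds that do not expose newer dtypes such as float8_e4m3fn. A standalone sketch of the same parsing pattern (table contents abbreviated):

from textwrap import dedent
import torch

name_mapping = '''
    F32      FLOAT    float32
    BF16     BFLOAT16 bfloat16
    F8E4M3FN INT8     float8_e4m3fn
'''

hlo_to_torch = {}
for line in dedent(name_mapping).strip().split('\n'):
    pname, dname, tname = line.split()
    if not hasattr(torch, tname):
        # Skip rows (e.g. float8_e4m3fn) that this torch build does not define.
        continue
    hlo_to_torch[pname] = getattr(torch, tname)

print(hlo_to_torch)   # e.g. {'F32': torch.float32, 'BF16': torch.bfloat16, ...}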
@@ -355,10 +388,10 @@ def __call__(self, inputs, return_ranks: int = -1):
result: The output tensors from each rank concatenated along dim 0.
"""
casted = []
for cpu, buf in zip(inputs, self.inputs):
for i, (cpu, buf) in enumerate(zip(inputs, self.inputs)):
if cpu.shape != buf.shape:
raise AssertionError(
f"Input shape mismatch. Expected {buf.shape}, but got {cpu.shape}"
f"{i+1}th input shape mismatch. Expected {buf.shape}, but got {cpu.shape}"
)
if cpu.dtype != buf.dtype:
cpu = cpu.to(buf.dtype)
@@ -444,7 +477,7 @@ def io_ring_cache_context(size):

class ParallelKernel:
hlo_snapshot_iter = 0
def __init__(self, hlo_module, tp_degree, g_start_device_id=0, g_device_count=None, tag=None):
def __init__(self, hlo_module, tp_degree, g_start_device_id=0, g_device_count=None, tag=None, num_exec_repetition=1):
self.hlo_module = hlo_module
self.tp_degree = tp_degree
self.neff_bytes = None
Expand All @@ -459,6 +492,7 @@ def __init__(self, hlo_module, tp_degree, g_start_device_id=0, g_device_count=No
self.tag = tag
self.g_device_count = g_device_count
self.memories = []
self.num_exec_repetition = num_exec_repetition
self.total_input_tensors_size = get_total_input_tensors_size(self.hlo_module)
logging.debug(f"Total input tensor size of the module (per rank): {self.total_input_tensors_size / (10**9)} G, whole (all ranks): {self.total_input_tensors_size * tp_degree / (10**9)} G")

@@ -467,15 +501,15 @@ def build_memory(self):
self.memories.append(memory)
return memory

def compile(self):
self.build()
def compile(self, num_exec_repetition=1):
self.build(num_exec_repetition)
return self.neff_bytes

def build(self):
def build(self, num_exec_repetition=1):
# Avoid rebuilding NEFF. This path occurs during deserialization
if self.neff_bytes is not None:
return
self.neff_bytes = compile_hlo_module(self.hlo_module, self.tag)
self.neff_bytes = compile_hlo_module(self.hlo_module, self.tag, num_exec_repetition)

def load(self, io_ring_cache_size=1):
assert self.neff_bytes is not None, f"Try to load with neff bytes as None, might due to compilation failure"
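Reviewer note: num_exec_repetition now flows from the ParallelKernel constructor through build()/compile() into compile_hlo_module(), which appends it to the compiler flags. A hedged sketch of just the flag construction (the base flags below are placeholders; only --execute-repetition comes from the diff):

def make_compiler_flags(base_flags: str, num_exec_repetition: int = 1) -> str:
    # Append the repetition count to whatever get_compiler_flags() produced.
    return f'{base_flags} --execute-repetition={num_exec_repetition}'

print(make_compiler_flags('--placeholder-base-flag', num_exec_repetition=2))
# --placeholder-base-flag --execute-repetition=2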
@@ -700,3 +734,4 @@ def setup(self, nc_input_buffers, nc_output_buffers, output_count=None):

def run(self):
self.kernel(self.memories)
