diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index bd3108141695..843f022d67a8 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -299,13 +299,13 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage set_current_device(device) if stream is None and not warmup: stream = get_cuda_stream(device) - try: - bin = cache[device][key] + bin = cache[device].get(key, None) + if bin is not None: if not warmup: bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args}) return bin # kernel not cached -- compile - except KeyError: + else: # build dict of constant values args = [{args}] all_args = {', '.join([f'{arg}' for arg in self.arg_names])},