Skip to content

Commit

Permalink
bug fix: triton importing error (#3799)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephen Youn <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
  • Loading branch information
3 people committed Jun 23, 2023
1 parent aebdfb3 commit bafaf3c
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions deepspeed/ops/transformer/inference/triton/matmul_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
import deepspeed.ops.transformer.inference.triton.triton_matmul_kernel as triton_matmul_kernel
import pickle
from io import open
import deepspeed
from pathlib import Path
import atexit


# -----------------------------------------------------------------------------
# util class/functions for triton
def _default_cache_dir():
return os.path.join(os.environ["HOME"], ".triton", "autotune")
return os.path.join(Path.home(), ".triton", "autotune")


def bias_add_activation(C, bias=None, activation=""):
Expand Down Expand Up @@ -421,16 +424,21 @@ def _update_autotune_table():

# -----------------------------------------------------------------------------
# mapping
matmul = MatmulExt.forward
fp16_matmul = Fp16Matmul()
matmul_4d = fp16_matmul._matmul_4d
score_4d_matmul = fp16_matmul._score_4d_matmul
context_4d_matmul = fp16_matmul._context_4d_matmul

#
import atexit
if deepspeed.HAS_TRITON:
fp16_matmul = Fp16Matmul()
matmul = MatmulExt.forward
matmul_4d = fp16_matmul._matmul_4d
score_4d_matmul = fp16_matmul._score_4d_matmul
context_4d_matmul = fp16_matmul._context_4d_matmul
else:
fp16_matmul = None
matmul = None
matmul_4d = None
score_4d_matmul = None
context_4d_matmul = None


@atexit.register
def matmul_ext_update_autotune_table():
fp16_matmul._update_autotune_table()
if deepspeed.HAS_TRITON:
fp16_matmul._update_autotune_table()

0 comments on commit bafaf3c

Please sign in to comment.