diff --git a/gradlib/gemm_runner.py b/gradlib/gemm_runner.py index 07bddde61190d..39a6911e138d1 100644 --- a/gradlib/gemm_runner.py +++ b/gradlib/gemm_runner.py @@ -37,7 +37,8 @@ def mm(self, inp, weights): n=inp.shape[0], k=inp.shape[1]) if soltype == 1: - out = torch.ops._gradlib_C.hipb_mm(inp, weights.t(), solidx, None, None, None, None, None) + out = torch.ops._gradlib_C.hipb_mm(inp, weights.t(), solidx, None, + None, None, None, None) elif soltype == 2: out = torch.ops._gradlib_C.rocb_mm(inp, weights.t(), solidx) else: diff --git a/gradlib/gemm_tuner.py b/gradlib/gemm_tuner.py index 08b4edf0839ac..bfde36eed4ec5 100644 --- a/gradlib/gemm_tuner.py +++ b/gradlib/gemm_tuner.py @@ -4,7 +4,7 @@ from pathlib import Path import torch # isort: split -import vllm._gradlib_C +import vllm._gradlib_C # noqa: F401 import pandas as pd from gradlib.GemmTuner import GemmTuner diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 7ec92f5dc66ef..a441ca5def07b 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -64,6 +64,7 @@ def create_ds(self): soltype = 2 solds[key] = (soltype, int(ds['solidx'])) self.solids = solds + def query_sol(self, m, n, k, bias, dtype): return self.solids.get((m, n, k, bias, str(dtype)), (0, 0))