From 5a4f0c9af5c53c7f72d2a20610b4c1ae827fdf2e Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 22 Jul 2020 22:58:55 +0000 Subject: [PATCH] CBLAS_OFFSET only available for MKL --- python/tvm/relay/grammar/py3/RelayLexer.py | 4 ++-- python/tvm/relay/grammar/py3/RelayParser.py | 9 ++++++--- python/tvm/relay/grammar/py3/RelayVisitor.py | 2 +- src/runtime/contrib/cblas/cblas.cc | 6 ++++++ tests/python/contrib/test_cblas.py | 2 +- 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/grammar/py3/RelayLexer.py b/python/tvm/relay/grammar/py3/RelayLexer.py index 76e988b454180..4dafc9ddfc7b1 100644 --- a/python/tvm/relay/grammar/py3/RelayLexer.py +++ b/python/tvm/relay/grammar/py3/RelayLexer.py @@ -1,4 +1,4 @@ -# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +# Generated from /home/ubuntu/workplace/tvm/t1/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.8 from antlr4 import * from io import StringIO from typing.io import TextIO @@ -248,7 +248,7 @@ class RelayLexer(Lexer): def __init__(self, input=None, output:TextIO = sys.stdout): super().__init__(input, output) - self.checkVersion("4.7.2") + self.checkVersion("4.8") self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) self._actions = None self._predicates = None diff --git a/python/tvm/relay/grammar/py3/RelayParser.py b/python/tvm/relay/grammar/py3/RelayParser.py index f24eed4be92f7..75fe48b9de7f5 100644 --- a/python/tvm/relay/grammar/py3/RelayParser.py +++ b/python/tvm/relay/grammar/py3/RelayParser.py @@ -1,9 +1,12 @@ -# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +# Generated from /home/ubuntu/workplace/tvm/t1/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.8 # encoding: utf-8 from antlr4 import * from io import StringIO -from typing.io import TextIO import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO def serializedATN(): @@ -387,7 +390,7 @@ class RelayParser ( Parser ): def __init__(self, input:TokenStream, output:TextIO = sys.stdout): super().__init__(input, output) - self.checkVersion("4.7.2") + self.checkVersion("4.8") self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) self._predicates = None diff --git a/python/tvm/relay/grammar/py3/RelayVisitor.py b/python/tvm/relay/grammar/py3/RelayVisitor.py index c6a7b7a0558c9..ee5d1d8752c48 100644 --- a/python/tvm/relay/grammar/py3/RelayVisitor.py +++ b/python/tvm/relay/grammar/py3/RelayVisitor.py @@ -1,4 +1,4 @@ -# Generated from /Users/doobs/Code/repo/sampl/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.7.2 +# Generated from /home/ubuntu/workplace/tvm/t1/tvm/python/tvm/relay/grammar/Relay.g4 by ANTLR 4.8 from antlr4 import * if __name__ is not None and "." in __name__: from .RelayParser import RelayParser diff --git a/src/runtime/contrib/cblas/cblas.cc b/src/runtime/contrib/cblas/cblas.cc index 3b230455e8f3e..e84ee1127fdb2 100644 --- a/src/runtime/contrib/cblas/cblas.cc +++ b/src/runtime/contrib/cblas/cblas.cc @@ -44,6 +44,7 @@ using namespace runtime; inline CBLAS_TRANSPOSE BooleanToTranspose(bool trans) { return trans ? CblasTrans : CblasNoTrans; } +#if USE_MKL_BLAS == 1 inline CBLAS_OFFSET StringToOffset(const std::string offset_type) { if (offset_type != "CblasFixOffset" && offset_type != "CblasColOffset" && offset_type != "CblasRowOffset") { @@ -56,6 +57,7 @@ inline CBLAS_OFFSET StringToOffset(const std::string offset_type) { } return CblasRowOffset; } +#endif inline char BooleanToTransposeChar(bool trans) { return trans ? 'T' : 'N'; } @@ -63,9 +65,13 @@ struct CblasGemmU8S8S32Op { void operator()(bool ta, bool tb, int M, int N, int K, float alpha, const void* A, int lda, int offset_a, const void* B, int ldb, int offset_b, float beta, int* C, int ldc, const std::string offset_ctype, int* offset_c) { +#if USE_MKL_BLAS == 1 cblas_gemm_s8u8s32(CblasColMajor, BooleanToTranspose(ta), BooleanToTranspose(tb), StringToOffset(offset_ctype), M, N, K, alpha, A, lda, offset_a, B, ldb, offset_b, beta, C, ldc, offset_c); +#else + LOG(FATAL) << "Quantized Gemm is supported with MKL Blas only"; +#endif } }; diff --git a/tests/python/contrib/test_cblas.py b/tests/python/contrib/test_cblas.py index 21286575d4a7a..5b9dd3f4ce280 100644 --- a/tests/python/contrib/test_cblas.py +++ b/tests/python/contrib/test_cblas.py @@ -89,7 +89,7 @@ def verify(target="llvm"): if not tvm.runtime.enabled(target): print("skip because %s is not enabled..." % target) return - if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): + if not tvm.get_global_func("tvm.contrib.cblas.matmul_u8s8s32", True): print("skip because extern function is not available") return ctx = tvm.cpu(0)