From ad399acd9364e075e9a5f47b2a2054b74ef547d8 Mon Sep 17 00:00:00 2001
From: Bing Xu <antinucleon@gmail.com>
Date: Sun, 19 Apr 2020 22:15:37 -0700
Subject: [PATCH] [Blocksparse] Pipeline for lowering dense model to
 sparse-dense (#5377)

---
 python/setup.py                               |   1 +
 python/tvm/relay/__init__.py                  |   1 +
 python/tvm/relay/analysis/__init__.py         |   1 +
 python/tvm/relay/analysis/analysis.py         |  18 ++
 python/tvm/relay/analysis/sparse_dense.py     |  93 ++++++++++
 .../relay/data_dep_optimization/__init__.py   |  21 +++
 .../relay/data_dep_optimization/bsr_dense.py  |  57 +++++++
 .../simplify_fc_transpose.py                  |  60 +++++++
 .../tvm/relay/data_dep_optimization/utils.py  |  40 +++++
 python/tvm/relay/transform/transform.py       |  40 +++++
 src/relay/transforms/convert_sparse_dense.cc  | 159 ++++++++++++++++++
 src/relay/transforms/simplify_fc_transpose.cc | 154 +++++++++++++++++
 .../relay/test_simplify_fc_transpose.py       |  67 ++++++++
 .../python/relay/test_sparse_dense_convert.py |  86 ++++++++++
 14 files changed, 798 insertions(+)
 create mode 100644 python/tvm/relay/analysis/sparse_dense.py
 create mode 100644 python/tvm/relay/data_dep_optimization/__init__.py
 create mode 100644 python/tvm/relay/data_dep_optimization/bsr_dense.py
 create mode 100644 python/tvm/relay/data_dep_optimization/simplify_fc_transpose.py
 create mode 100644 python/tvm/relay/data_dep_optimization/utils.py
 create mode 100644 src/relay/transforms/convert_sparse_dense.cc
 create mode 100644 src/relay/transforms/simplify_fc_transpose.cc
 create mode 100644 tests/python/relay/test_simplify_fc_transpose.py
 create mode 100644 tests/python/relay/test_sparse_dense_convert.py

diff --git a/python/setup.py b/python/setup.py
index 62f374923714..fb126ec24e58 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -156,6 +156,7 @@ def get_package_data_files():
       zip_safe=False,
       install_requires=[
         'numpy',
+        'scipy',
         'decorator',
         'attrs',
         'psutil',
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index 4e520198664c..4663866b1452 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -53,6 +53,7 @@
 from . import frontend
 from . import backend
 from . import quantize
+from . import data_dep_optimization
 
 # Dialects
 from . import qnn
diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py
index a1833c3c08b2..e5b21cb107f5 100644
--- a/python/tvm/relay/analysis/__init__.py
+++ b/python/tvm/relay/analysis/__init__.py
@@ -28,3 +28,4 @@
 
 # Feature
 from . import feature
+from . import sparse_dense
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index 21f3edfb99eb..c237859eb987 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -333,3 +333,21 @@ def extract_fused_functions(mod):
     for hash_, func in ret_mod.functions.items():
         ret[hash_] = func
     return ret
+
+
+def search_fc_transpose(expr):
+    """Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0]))
+
+    This function is used in the data_dep_optimization.simplify_fc_transpose method
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+
+    Returns
+    -------
+    ret : Array[String]
+        Array of weight variable name in pattern y = nn.dense(x, transpose(w, [1, 0]))
+    """
+    ret = _ffi_api.search_fc_transpose(expr)
+    return ret
diff --git a/python/tvm/relay/analysis/sparse_dense.py b/python/tvm/relay/analysis/sparse_dense.py
new file mode 100644
index 000000000000..7e8f4345e336
--- /dev/null
+++ b/python/tvm/relay/analysis/sparse_dense.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return
+# pylint: disable=unidiomatic-typecheck
+"""
+This file contains helper functions for convert dense model
+to block sparse model
+"""
+from collections import namedtuple
+import numpy as np
+import scipy.sparse as sp
+import tvm
+from . import _ffi_api
+
+
+SparseAnalysisResult = namedtuple("SparseAnalysisResult", [
+    "weight_name",
+    "weight_shape",
+])
+
+def _search_dense_op_weight(expr):
+    """Search name of weight in all ```nn.dense``` operator
+       This is a helpful function to determine which param need
+       to be converted to sparse
+
+    Parameters
+    ----------
+    expr : relay.Expr
+        Expr will be searched
+
+    Returns
+    -------
+    ret : Array[String]
+        name of weight in all ``nn.dense``` operator
+    """
+    return _ffi_api.search_dense_op_weight(expr)
+
+
+def process_params(expr, params, block_size, sparsity_threshold):
+    """[summary]
+
+    Parameters
+    ----------
+    expr : Relay.Expr
+        Expr of the network
+    params : Dict[String, tvm.nd.array]
+        parameters of the network
+    block_size : Tuple(int, int)
+        Blocksize in BSR matrix
+    sparsity_threshold : float
+        Minimal sparsity requirement for converting to sparse operation
+
+    Returns
+    -------
+    ret : Namedtuple[weight_name: Array[String], weight_shape: Array[Array[IntImm]]]
+        return names of qualified dense weight and the shape in BSR format
+    """
+    memo = SparseAnalysisResult(weight_name=[], weight_shape=[])
+    weight_names = _search_dense_op_weight(expr)
+    for name in weight_names:
+        name = str(name)
+        w_np = params[name].asnumpy()
+        sparsity = 1.0 - (np.count_nonzero(w_np) / w_np.size)
+        if sparsity >= sparsity_threshold:
+            sparse_weight = sp.bsr_matrix(w_np, blocksize=block_size)
+            # remove dense weight
+            del params[name]
+            memo.weight_name.append(name)
+            memo.weight_shape.append(list(sparse_weight.data.shape) +
+                                     list(sparse_weight.indices.shape) +
+                                     list(sparse_weight.indptr.shape))
+            params[name + ".data"] = tvm.nd.array(sparse_weight.data)
+            params[name + ".indices"] = tvm.nd.array(sparse_weight.indices)
+            params[name + ".indptr"] = tvm.nd.array(sparse_weight.indptr)
+    ret = SparseAnalysisResult(
+        weight_name=tvm.runtime.convert(memo.weight_name),
+        weight_shape=tvm.runtime.convert(memo.weight_shape)
+    )
+    return ret
diff --git a/python/tvm/relay/data_dep_optimization/__init__.py b/python/tvm/relay/data_dep_optimization/__init__.py
new file mode 100644
index 000000000000..ab0caa20f0bb
--- /dev/null
+++ b/python/tvm/relay/data_dep_optimization/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=unused-argument, not-context-manager
+"""Optimizations involves changing of paramters"""
+
+from . import bsr_dense
+from . import simplify_fc_transpose
diff --git a/python/tvm/relay/data_dep_optimization/bsr_dense.py b/python/tvm/relay/data_dep_optimization/bsr_dense.py
new file mode 100644
index 000000000000..cc3e5deb302e
--- /dev/null
+++ b/python/tvm/relay/data_dep_optimization/bsr_dense.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=unused-argument, not-context-manager
+"""Automatic convert model from dense to block sparse"""
+
+from tvm import relay
+from tvm.relay.analysis.sparse_dense import process_params
+
+from .utils import _run_opt_pass
+
+def convert(func, params, blocksize, sparsity_threshold):
+    """Convert a dense func and according parameters to block sparse
+
+    Parameters
+    ----------
+    func : relay.Expr
+        Expr will be optimized to sparse operation
+    params : Dict[Srting, tvm.nd.array]
+        Parameters of the Expr
+    blocksize : Tuple(int, int)
+        Blocksize for BSR matrix
+    sparsity_threshold : float
+        Minimal sparsity requirement for converting.
+        If weight sparsity is lower than this threshold,
+        the dense operation will be kept.
+
+    Returns
+    -------
+    new_func: relay.Expr
+        Mutated Expr with sparse operations
+
+    params: Dict[Srting, tvm.nd.array]
+        New params with BSR matrix for mutated Expr
+    """
+    weight_info = process_params(func, params, blocksize, sparsity_threshold)
+    new_func = _run_opt_pass(
+        func,
+        relay.transform.DenseToSparse(
+            weight_info.weight_name,
+            weight_info.weight_shape
+        )
+    )
+    return new_func, params
diff --git a/python/tvm/relay/data_dep_optimization/simplify_fc_transpose.py b/python/tvm/relay/data_dep_optimization/simplify_fc_transpose.py
new file mode 100644
index 000000000000..345c579499f5
--- /dev/null
+++ b/python/tvm/relay/data_dep_optimization/simplify_fc_transpose.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=unused-argument, not-context-manager
+"""Automatic optimize fc tranpose"""
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.relay.analysis import search_fc_transpose
+
+from .utils import _run_opt_pass
+
+
+def convert(func, params):
+    """convert all ```y = nn.dense(x, transpose(w, [1, 0]))``` to
+        ```y = nn.dense(x, wt)```
+
+    Parameters
+    ----------
+    func : relay.Expr
+        Expr will be optimized
+    params : Dict[String, tvm.nd.array]
+        Parameters of Expr
+
+    Returns
+    -------
+    new_func : relay.Expr
+        Mutated Expr from ```y = nn.dense(x, transpose(w, [1, 0]))``` to
+        ```y = nn.dense(x, wt)```
+    params: Dict[String, tvm.nd.array]
+        Parameters of mutated Expr, with weights pre-transposed
+    """
+    weight_info = search_fc_transpose(func)
+    for item in weight_info:
+        name = str(item)
+        w_np = params[name].asnumpy()
+        new_w = np.transpose(w_np, axes=[1, 0])
+        params[name + ".T"] = tvm.nd.array(new_w)
+        del params[name]
+    new_func = _run_opt_pass(
+        func,
+        relay.transform.SimplifyFCTranspose(
+            weight_info,
+        )
+    )
+    return new_func, params
diff --git a/python/tvm/relay/data_dep_optimization/utils.py b/python/tvm/relay/data_dep_optimization/utils.py
new file mode 100644
index 000000000000..6b46f815474a
--- /dev/null
+++ b/python/tvm/relay/data_dep_optimization/utils.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#pylint: disable=unused-argument, not-context-manager
+"""Utils functions for optimizations"""
+
+import tvm
+
+def _run_opt_pass(expr, opt_pass):
+    """Helper function to run pass
+
+    Parameters
+    ----------
+    expr : relay.Expr
+        Expr will be optimized
+    opt_pass : relay.Pass
+        Optimization pass
+
+    Returns
+    -------
+    ret: relay.Expr
+        Optimized Expr by running opt_pass
+    """
+    assert isinstance(opt_pass, tvm.transform.Pass)
+    mod = tvm.IRModule.from_expr(expr)
+    mod = opt_pass(mod)
+    return mod["main"]
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 292c5fd39acb..647e999f647a 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -839,3 +839,43 @@ def visit_var(self, var):
                     return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype))
                 return var
         return ChangeBatchMutator().visit(func)
+
+
+def DenseToSparse(weight_name, weight_shape):
+    """
+    Rewrite qualified ```nn.dense operation``` to ```nn.sparse_dense```
+    This pass is used in ```data_dep_optimization.bsr_dense```
+    Parameters of this pass is generated by ```analysis.sparse_dense.process_params```
+
+    Parameters
+    ----------
+    weight_name: Array[String]
+      Names of weights which qualified sparse contrains
+
+    weight_shape: Array[Array[IntImm]]
+      Weights shape in BSR format.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered DenseToSparse pass.
+    """
+    return _ffi_api.DenseToSparse(weight_name, weight_shape)
+
+def SimplifyFCTranspose(target_weight_name):
+    """
+    Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)```
+    This pass is used in ```data_dep_optimization.simplify_fc_transpose```
+
+    Parameters
+    ----------
+    weight_name: Array[String]
+      Names of weights which qualified ```y = nn.dense(x, transpose(w, [1, 0]))```
+      This parameter is generated by ```analysis.search_fc_transpose``` function
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered SimplifyFCTranspose pass.
+    """
+    return _ffi_api.SimplifyFCTranspose(target_weight_name)
diff --git a/src/relay/transforms/convert_sparse_dense.cc b/src/relay/transforms/convert_sparse_dense.cc
new file mode 100644
index 000000000000..1b83e7188df2
--- /dev/null
+++ b/src/relay/transforms/convert_sparse_dense.cc
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file convert_sparse_dense.cc
+ *
+ * \brief Mutate dense operator to sparse dense operator
+ */
+#include <tvm/ir/expr.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+namespace tvm {
+namespace relay {
+
+// Search dense op weight name from Expr
+class DenseOpWeightVisitor : private ExprVisitor {
+ public:
+  DenseOpWeightVisitor() : dense_op_(Op::Get("nn.dense")) {}
+
+  Array<String> Search(const Expr& expr) {
+    VisitExpr(expr);
+    return memo_;
+  }
+
+ private:
+  void VisitExpr_(const CallNode* n) final {
+    if (n->op == dense_op_) {
+      const auto weight = n->args[1].as<VarNode>();
+      if (weight) {
+        memo_.push_back(weight->name_hint());
+      }
+    }
+    for (const auto& arg : n->args) {
+      VisitExpr(arg);
+    }
+  }
+  // Cache op
+  const Op& dense_op_;
+
+  Array<String> memo_;
+};  // SearchDenseOpWeight
+
+Array<String> SearchDenseOpWeight(const Expr& e) { return DenseOpWeightVisitor().Search(e); }
+
+TVM_REGISTER_GLOBAL("relay.analysis.search_dense_op_weight").set_body_typed(SearchDenseOpWeight);
+
+// Mutate ```nn.dense``` to ```nn.sparse_dense```
+class DenseToSparseDenseMutator : public ExprRewriter {
+ public:
+  DenseToSparseDenseMutator(const Array<ObjectRef>& weight_name,
+                            const Array<Array<PrimExpr> >& weight_shape)
+      : dense_op_(Op::Get("nn.dense")), sparse_dense_op_(Op::Get("nn.sparse_dense")) {
+    CHECK_EQ(weight_name.size(), weight_shape.size());
+    for (size_t i = 0; i < weight_name.size(); ++i) {
+      CHECK(weight_name[i]->IsInstance<runtime::StringObj>());
+      std::string k = weight_name[i].as<runtime::StringObj>()->data;
+      const auto& ws = weight_shape[i];
+      std::vector<int> v(ws.size());
+      for (size_t j = 0; j < ws.size(); ++j) {
+        v[j] = ws[j].as<IntImmNode>()->value;
+      }
+      target_weights_.emplace(k, v);
+    }
+  }
+
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    if (pre->op == dense_op_) {
+      const auto weight = pre->args[1].as<VarNode>();
+      if (weight) {
+        if (target_weights_.count(weight->name_hint())) {
+          const auto& prefix = weight->name_hint();
+          const auto& ws = target_weights_.at(prefix);
+          const auto data = post.as<CallNode>()->args[0];
+          auto ws_data_type =
+              relay::TensorType({ws.at(0), ws.at(1), ws.at(2)}, DataType::Float(32));
+          auto ws_indices_type = relay::TensorType({ws.at(3)}, DataType::Int(32));
+          auto ws_indptr_type = relay::TensorType({ws.at(4)}, DataType::Int(32));
+          Var weight_data(prefix + ".data", ws_data_type);
+          Var weight_indices(prefix + ".indices", ws_indices_type);
+          Var weight_indptr(prefix + ".indptr", ws_indptr_type);
+
+          return Call(sparse_dense_op_, {data, weight_data, weight_indices, weight_indptr});
+        }
+      }
+    }
+    return post;
+  }
+
+ private:
+  // Cached op
+  const Op& dense_op_;
+  const Op& sparse_dense_op_;
+  std::unordered_map<std::string, std::vector<int> > target_weights_;
+};  // class DenseToSparseDenseAlter
+
+Expr DenseToSparse(const Expr& e, const Array<ObjectRef>& weight_name,
+                   const Array<Array<PrimExpr> >& weight_shape) {
+  auto rewriter = DenseToSparseDenseMutator(weight_name, weight_shape);
+  return PostOrderRewrite(e, &rewriter);
+}
+
+namespace transform {
+
+Pass DenseToSparse(const Array<ObjectRef>& weight_name,
+                   const Array<Array<PrimExpr> >& weight_shape) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        // Remove FreeVar warnings
+        auto f0 = Downcast<Function>(DenseToSparse(f, weight_name, weight_shape));
+        Array<Var> sparse_params = FreeVars(f0);
+        auto f1 = Function(sparse_params,
+                        f0->body,
+                        f0->ret_type,
+                        f0->type_params,
+                        f0->attrs);
+        Array<Var> params = FreeVars(f1);
+        for (const auto& var : sparse_params) {
+          params.push_back(var);
+        }
+        return Function(params,
+                        f1->body,
+                        f1->ret_type,
+                        f1->type_params,
+                        f1->attrs);
+      };
+  return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.DenseToSparse").set_body_typed(DenseToSparse);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/simplify_fc_transpose.cc b/src/relay/transforms/simplify_fc_transpose.cc
new file mode 100644
index 000000000000..6cd77f424d18
--- /dev/null
+++ b/src/relay/transforms/simplify_fc_transpose.cc
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file simplify_fc_transpose.cc
+ *
+ * \brief Mutate ```y = nn.dense(x, tranpose(w, [1, 0]))``` to
+ *        ```y = nn.dense(x, wt)```
+ */
+#include <tvm/ir/expr.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+namespace tvm {
+namespace relay {
+
+// Find name of weight in ```y = nn.dense(x, tranpose(w, [1, 0]))```
+class FCTransposeVisitor : private ExprVisitor {
+ public:
+  FCTransposeVisitor() : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) {}
+
+  Array<String> Search(const Expr& expr) {
+    VisitExpr(expr);
+    return memo_;
+  }
+
+ private:
+  void VisitExpr_(const CallNode* n) final {
+    if (n->op == dense_op_) {
+      const auto weight = n->args[1].as<CallNode>();
+      if (weight) {
+        if (weight->op == transpose_op_) {
+          if (weight->args[0].as<VarNode>()) {
+            const auto arg = weight->args[0].as<VarNode>();
+            memo_.push_back(arg->name_hint());
+          }
+        }
+      }
+    }
+    for (const auto& arg : n->args) {
+      VisitExpr(arg);
+    }
+  }
+
+  const Op& dense_op_;
+  const Op& transpose_op_;
+  Array<String> memo_;
+};  // SearchDenseOpWeight
+
+Array<String> SearchFCTranspose(const Expr& e) { return FCTransposeVisitor().Search(e); }
+
+TVM_REGISTER_GLOBAL("relay.analysis.search_fc_transpose").set_body_typed(SearchFCTranspose);
+
+// Mutate ```y = nn.dense(x, tranpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)```
+class FCTransposeMutator : public ExprRewriter {
+ public:
+  explicit FCTransposeMutator(const Array<ObjectRef>& target_weights)
+      : dense_op_(Op::Get("nn.dense")), transpose_op_(Op::Get("transpose")) {
+    for (size_t i = 0; i < target_weights.size(); ++i) {
+      CHECK(target_weights[i]->IsInstance<runtime::StringObj>());
+      std::string k = target_weights[i].as<runtime::StringObj>()->data;
+      target_weights_.emplace(k);
+    }
+  }
+
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    if (pre->op == dense_op_) {
+      const auto data = post.as<CallNode>()->args[0];
+      const auto weight = pre->args[1].as<CallNode>();
+      if (weight) {
+        if (weight->op == transpose_op_) {
+          const auto arg = weight->args[0];
+          if (arg.as<VarNode>()) {
+            const auto& arg_node = arg.as<VarNode>();
+            CHECK_GT(target_weights_.count(arg_node->name_hint()), 0);
+            const auto& tt = arg_node->type_annotation.as<TensorTypeNode>();
+            auto wt_type = TensorType({tt->shape[1], tt->shape[0]}, tt->dtype);
+            Var wt(arg_node->name_hint() + ".T", wt_type);
+            return Call(dense_op_, {data, wt}, pre->attrs, pre->type_args);
+          }
+        }
+      }
+    }
+    return post;
+  }
+
+ private:
+  // Cached op
+  const Op& dense_op_;
+  const Op& transpose_op_;
+  std::unordered_set<std::string> target_weights_;
+};  // class DenseToSparseDenseAlter
+
+Expr SimplifyFCTranspose(const Expr& e, const Array<ObjectRef>& target_weights) {
+  auto rewriter = FCTransposeMutator(target_weights);
+  return PostOrderRewrite(e, &rewriter);
+}
+
+namespace transform {
+
+Pass SimplifyFCTranspose(const Array<ObjectRef>& target_weights) {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        // Remove FreeVar warning
+        auto f0 = Downcast<Function>(SimplifyFCTranspose(f, target_weights));
+        Array<Var> wt_params = FreeVars(f0);
+        auto f1 = Function(wt_params,
+                        f0->body,
+                        f0->ret_type,
+                        f0->type_params,
+                        f0->attrs);
+        Array<Var> params = FreeVars(f1);
+        for (const auto& var : wt_params) {
+          params.push_back(var);
+        }
+        return Function(params,
+                        f1->body,
+                        f1->ret_type,
+                        f1->type_params,
+                        f1->attrs);
+      };
+  return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.SimplifyFCTranspose").set_body_typed(SimplifyFCTranspose);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_simplify_fc_transpose.py b/tests/python/relay/test_simplify_fc_transpose.py
new file mode 100644
index 000000000000..537a5a29348c
--- /dev/null
+++ b/tests/python/relay/test_simplify_fc_transpose.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import itertools
+
+import numpy as np
+import scipy.sparse as sp
+
+
+import tvm
+from tvm.ir import IRModule
+from tvm import relay
+from tvm.relay.data_dep_optimization import simplify_fc_transpose
+
+def run_func(func, params, x):
+    with relay.build_config(opt_level=3):
+        graph, lib, new_params = relay.build(func, "llvm", params=params)
+
+    from tvm.contrib import graph_runtime
+    ctx = tvm.cpu(0)
+    dtype = 'float32'
+    m = graph_runtime.create(graph, lib, ctx)
+    # set inputs
+    m.set_input('data', tvm.nd.array(x.astype(dtype)))
+    m.set_input(**new_params)
+    # execute
+    m.run()
+    # get outputs
+    tvm_output = m.get_output(0)
+    return tvm_output.asnumpy()
+
+def test_simplify_fc_transpose():
+    data = relay.var("data", shape=(1, 32), dtype="float32")
+    x = relay.nn.relu(data)
+    w1 = relay.var("w1", shape=(32, 64), dtype="float32")
+    y = relay.nn.dense(x, relay.transpose(w1, axes=[1, 0]))
+    z = relay.nn.relu(y)
+    w2 = relay.var("w2", shape=(64, 16), dtype="float32")
+    zz = relay.nn.dense(z, relay.transpose(w2, axes=[1, 0]))
+    func = relay.Function(relay.analysis.free_vars(zz), zz)
+    params = {
+        "w1": tvm.nd.array(np.random.uniform(-1, 1, (32, 64)).astype("float32")),
+        "w2": tvm.nd.array(np.random.uniform(-1, 1, (64, 16)).astype("float32"))
+    }
+    x_np = np.random.randn(1, 32).astype("float32")
+    old_result = run_func(func, params, x_np)
+
+    new_func, new_params = simplify_fc_transpose.convert(func, params)
+    new_result = run_func(new_func, new_params, x_np)
+    np.testing.assert_allclose(old_result, new_result, atol=1e-5, rtol=1e-5)
+
+if __name__ == "__main__":
+    test_simplify_fc_transpose()
diff --git a/tests/python/relay/test_sparse_dense_convert.py b/tests/python/relay/test_sparse_dense_convert.py
new file mode 100644
index 000000000000..c4f0572c0482
--- /dev/null
+++ b/tests/python/relay/test_sparse_dense_convert.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import itertools
+
+import numpy as np
+import scipy.sparse as sp
+
+
+import tvm
+from tvm.ir import IRModule
+from tvm import relay
+
+
+def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype="float32"):
+    Y = np.zeros((M, N), dtype=dtype)
+    assert M % BS_R == 0
+    assert N % BS_C == 0
+    nnz = int(density * M * N)
+    num_blocks = int(nnz / (BS_R * BS_C)) + 1
+    candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
+    assert candidate_blocks.shape[0] == M // BS_R * N // BS_C
+    chosen_blocks = candidate_blocks[np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)]
+    for i in range(len(chosen_blocks)):
+        r, c = chosen_blocks[i]
+        Y[r:r+BS_R,c:c+BS_C] = np.random.randn(BS_R, BS_C)
+    s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C))
+    assert s.data.shape == (num_blocks, BS_R, BS_C)
+    assert s.data.size >= nnz
+    assert s.indices.shape == (num_blocks, )
+    assert s.indptr.shape == (M // BS_R + 1, )
+    return s
+
+def run_func(func, params, x):
+    with relay.build_config(opt_level=3):
+        graph, lib, new_params = relay.build(func, "llvm", params=params)
+
+    from tvm.contrib import graph_runtime
+    ctx = tvm.cpu(0)
+    dtype = 'float32'
+    m = graph_runtime.create(graph, lib, ctx)
+    # set inputs
+    m.set_input('data', tvm.nd.array(x.astype(dtype)))
+    m.set_input(**new_params)
+    # execute
+    m.run()
+    # get outputs
+    tvm_output = m.get_output(0)
+    return tvm_output.asnumpy()
+
+def test_bsr_sparse_dense():
+    data = relay.var("data", shape=(1, 128), dtype="float32")
+    x = relay.nn.relu(data)
+    w = relay.var("weight", shape=(768, 128), dtype="float32")
+    y = relay.nn.dense(x, w)
+    z = relay.nn.relu(y)
+    func = relay.Function(relay.analysis.free_vars(z), z)
+
+    params = {
+        "weight": tvm.nd.array(random_bsr_matrix(768, 128, 32, 1, 0.1).todense())
+    }
+
+    x_np = np.random.randn(1, 128).astype("float32")
+    # dense output
+    dense_output = run_func(func, params, x_np)
+    # sparse
+    sparse_func, params = relay.data_dep_optimization.bsr_dense.convert(func, params, (32, 1), 0.2)
+    sparse_output = run_func(sparse_func, params, x_np)
+    np.testing.assert_allclose(sparse_output, dense_output, atol=1e-5, rtol=1e-5)
+
+if __name__ == "__main__":
+    test_bsr_sparse_dense()