Skip to content

Commit

Permalink
[mlir][sparse][xla] Legalize sparse_tensor::Pack/UnpackOp to custom c…
Browse files Browse the repository at this point in the history
…alls before the translation to HLO.

PiperOrigin-RevId: 565749097
  • Loading branch information
tensorflower-gardener authored and TensorFlow MLIR Team committed Sep 15, 2023
1 parent e5c675a commit 3ef7ad5
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 18 deletions.
45 changes: 29 additions & 16 deletions mhlo/transforms/legalize_sparse_ops/legalize_sparse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "mhlo/IR/hlo_ops.h"
#include "mhlo/transforms/passes.h"
#include "mhlo/transforms/rewriters.h"
#include "mhlo/utils/legalize_to_linalg_utils.h"
Expand Down Expand Up @@ -68,32 +69,44 @@ struct LegalizeSparseOpsPass
ConversionTarget target(*ctx);
mhlo::RemoveSignTypeConverter typeConverter;
if (legalize_to_custom_calls_) {
mhlo::populateLegalizeSparseOpsToCustomCallPatterns(ctx, typeConverter,
&patterns);
setupLegalizeToCustomCallPatterns(ctx, &patterns, typeConverter, target);
} else {
mhlo::populateLegalizeSparseCHLOPatterns(ctx, typeConverter, &patterns);
setupLegalizeSparseCHLOPatterns(ctx, &patterns, typeConverter, target);
}
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
}
}

private:
static bool isNotSparseOp(Operation* op) {
return !sparse_tensor::hasAnySparseOperandOrResult(op);
}

static void setupLegalizeToCustomCallPatterns(MLIRContext* ctx,
RewritePatternSet* patterns,
TypeConverter& typeConverter,
ConversionTarget& target) {
mhlo::populateLegalizeSparseOpsToCustomCallPatterns(ctx, typeConverter,
patterns);
target.addIllegalDialect<sparse_tensor::SparseTensorDialect>();
target.addLegalOp<mhlo::CustomCallOp>();
}

static void setupLegalizeSparseCHLOPatterns(MLIRContext* ctx,
RewritePatternSet* patterns,
TypeConverter& typeConverter,
ConversionTarget& target) {
mhlo::populateLegalizeSparseCHLOPatterns(ctx, typeConverter, patterns);
target.addLegalDialect<bufferization::BufferizationDialect,
linalg::LinalgDialect, tensor::TensorDialect,
sparse_tensor::SparseTensorDialect>();
/// The unary operation is sparse computation if either the input or the
/// result is a sparse tensor.
/// TODO(bixia): Remove the convert of such sparse CHLO ops from
/// chlo_legalize_to_hlo.
auto isNotSparseOp = [](Operation* op) {
auto encDst =
sparse_tensor::getSparseTensorEncoding(op->getResult(0).getType());
auto encSrc =
sparse_tensor::getSparseTensorEncoding(op->getOperand(0).getType());
return !encDst && !encSrc;
};
target.addDynamicallyLegalOp<chlo::AsinOp, chlo::AsinhOp, chlo::AtanOp,
chlo::AtanhOp, chlo::BesselI1eOp, chlo::SinhOp,
chlo::TanOp>(isNotSparseOp);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
return signalPassFailure();
}
}
};

Expand Down
40 changes: 38 additions & 2 deletions mhlo/transforms/legalize_sparse_ops/sparse_ops_to_custom_calls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,53 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "mhlo/IR/hlo_ops.h"
#include "mhlo/transforms/rewriters.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace {

StringAttr getOperationTargetName(Operation* op) {
// Strips off `dialect` from `dialect.opName`.
StringRef opName = op->getName().getIdentifier().strref().split(".").second;
return StringAttr::get(op->getContext(), "sparse_tensor_" + opName);
}

} // namespace
namespace mhlo {

template <typename OpTy>
class SparseOpToCustomCallConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<OpTy>::OpConversionPattern;

LogicalResult matchAndRewrite(
OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
NamedAttribute callTargetName =
rewriter.getNamedAttr("call_target_name", getOperationTargetName(op));
rewriter.replaceOpWithNewOp<mhlo::CustomCallOp>(op, op->getResultTypes(),
adaptor.getOperands(),
ArrayRef{callTargetName});
return success();
}
};

void populateLegalizeSparseOpsToCustomCallPatterns(
MLIRContext* /*context*/, TypeConverter& /*typeConverter*/,
RewritePatternSet* /*patterns*/) {}
MLIRContext* context, TypeConverter& typeConverter,
RewritePatternSet* patterns) {
patterns->add<SparseOpToCustomCallConverter<sparse_tensor::PackOp>,
SparseOpToCustomCallConverter<sparse_tensor::UnpackOp>>(
typeConverter, context);
}

} // namespace mhlo
} // namespace mlir

0 comments on commit 3ef7ad5

Please sign in to comment.