Skip to content

Commit

Permalink
Merge pull request iree-org#5083 from google/benvanik-flow-passes
Browse files Browse the repository at this point in the history
  • Loading branch information
benvanik authored Mar 12, 2021
2 parents 0668b55 + bb819df commit 1ef76ba
Show file tree
Hide file tree
Showing 20 changed files with 343 additions and 1,195 deletions.
3 changes: 1 addition & 2 deletions iree/compiler/Conversion/LinalgToSPIRV/VectorizeMemref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ static bool getUsesIfAllTransferOp(Value v,
return true;
}


/// Returns the bitwidth of a scalar or vector type.
static Optional<unsigned> getBitWidth(Type type) {
if (type.isIntOrFloat()) {
Expand Down Expand Up @@ -370,7 +369,7 @@ class ProcessPlaceHolder final
auto vecMemRef = getVectorizedMemRefType(rewriter, placeholder.getResult());
if (!vecMemRef) return failure();
rewriter.replaceOpWithNewOp<IREE::PlaceholderOp>(
placeholder, *vecMemRef, ValueRange(), placeholder.getAttrs());
placeholder, *vecMemRef, ValueRange(), placeholder->getAttrs());
return success();
}
};
Expand Down
45 changes: 24 additions & 21 deletions iree/compiler/Dialect/Flow/Conversion/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,40 +23,43 @@ namespace iree_compiler {
FlowTypeConverter::FlowTypeConverter() {
// Allow types through by default.
addConversion([](Type type) { return type; });
// addConversion([](IndexType type) {
// // Always treat as 32-bit.
// return IntegerType::get(32, type.getContext());
// });

addConversion([this](RankedTensorType tensorType) -> Optional<Type> {
auto convertedElementType = convertType(tensorType.getElementType());
if (!convertedElementType) {
return llvm::None;
}
return RankedTensorType::get(tensorType.getShape(), convertedElementType);
});

addConversion([](UnrankedTensorType tensorType) {
// We only support ranked tensors. We could convert unranked to ranked
// here for certain cases (such as * on the LHS).
return Type();
});

// UNSAFE: narrow 64-bit integers to 32-bit ones.
// This is a workaround for lower levels of the stack not always supporting
// 64-bit types natively.
// TODO(benvanik): make whether to narrow integers an option.
addConversion([](IntegerType integerType) -> Optional<Type> {
if (integerType.isSignlessInteger() && integerType.getWidth() > 32) {
// Don't support 64-bit types in general. Rewrite to i32 (if desired).
// TODO(benvanik): split to i32+i32? allow and use availability?
// TODO(benvanik): make an option.
return IntegerType::get(integerType.getContext(), 32);
}
return llvm::None;
});

// UNSAFE: narrow 64-bit floats to 32-bit ones.
// This is a workaround for lower levels of the stack not always supporting
// 64-bit types natively.
// TODO(benvanik): make whether to narrow floats an option.
addConversion([](FloatType floatType) -> Optional<Type> {
if (floatType.getWidth() > 32) {
// Don't support 64-bit types in general. Rewrite to f32 (if desired).
// TODO(benvanik): make an option.
return FloatType::getF32(floatType.getContext());
}
return llvm::None;
});
addConversion([this](RankedTensorType tensorType) -> Optional<Type> {
auto convertedElementType = convertType(tensorType.getElementType());
if (!convertedElementType) {
return llvm::None;
}
return RankedTensorType::get(tensorType.getShape(), convertedElementType);
});
addConversion([](UnrankedTensorType tensorType) {
// We only support ranked tensors. We could convert unranked to ranked
// here for certain cases (such as * on the LHS).
return Type();
});
// TODO(b/145876978): add conversion materializer
}

} // namespace iree_compiler
Expand Down
4 changes: 1 addition & 3 deletions iree/compiler/Dialect/Flow/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ cc_library(
"FormStreams.cpp",
"HLOToHLOPreprocessing.cpp",
"HoistUnstreamableOps.cpp",
"IdentifyDispatchRegions.cpp",
"IdentifyDispatchRegions2.cpp",
"InjectDispatchTracing.cpp",
"LegalizeInputTypes.cpp",
"MaterializeExportedReflection.cpp",
"MergeExportedReflection.cpp",
"MaterializeReflectionAttrs.cpp",
"OutlineDispatchRegions.cpp",
"OutlineDispatchRegions2.cpp",
"OutlineLargeConstants.cpp",
Expand Down
4 changes: 1 addition & 3 deletions iree/compiler/Dialect/Flow/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,10 @@ iree_cc_library(
"FormStreams.cpp"
"HLOToHLOPreprocessing.cpp"
"HoistUnstreamableOps.cpp"
"IdentifyDispatchRegions.cpp"
"IdentifyDispatchRegions2.cpp"
"InjectDispatchTracing.cpp"
"LegalizeInputTypes.cpp"
"MaterializeExportedReflection.cpp"
"MergeExportedReflection.cpp"
"MaterializeReflectionAttrs.cpp"
"OutlineDispatchRegions.cpp"
"OutlineDispatchRegions2.cpp"
"OutlineLargeConstants.cpp"
Expand Down
Loading

0 comments on commit 1ef76ba

Please sign in to comment.