Skip to content

Commit

Permalink
[flang][cuda] Make launch configuration optional for cuf kernel (#115947
Browse files Browse the repository at this point in the history
)
  • Loading branch information
clementval authored Nov 13, 2024
1 parent 01d233f commit 37143fe
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 43 deletions.
1 change: 1 addition & 0 deletions flang/include/flang/Parser/dump-parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ class ParseTreeDumper {
NODE(parser, CUFKernelDoConstruct)
NODE(CUFKernelDoConstruct, StarOrExpr)
NODE(CUFKernelDoConstruct, Directive)
NODE(CUFKernelDoConstruct, LaunchConfiguration)
NODE(parser, CUFReduction)
NODE(parser, CycleStmt)
NODE(parser, DataComponentDefStmt)
Expand Down
11 changes: 8 additions & 3 deletions flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -4527,12 +4527,17 @@ struct CUFReduction {
struct CUFKernelDoConstruct {
TUPLE_CLASS_BOILERPLATE(CUFKernelDoConstruct);
WRAPPER_CLASS(StarOrExpr, std::optional<ScalarIntExpr>);
struct LaunchConfiguration {
TUPLE_CLASS_BOILERPLATE(LaunchConfiguration);
std::tuple<std::list<StarOrExpr>, std::list<StarOrExpr>,
std::optional<ScalarIntExpr>>
t;
};
struct Directive {
TUPLE_CLASS_BOILERPLATE(Directive);
CharBlock source;
std::tuple<std::optional<ScalarIntConstantExpr>, std::list<StarOrExpr>,
std::list<StarOrExpr>, std::optional<ScalarIntExpr>,
std::list<CUFReduction>>
std::tuple<std::optional<ScalarIntConstantExpr>,
std::optional<LaunchConfiguration>, std::list<CUFReduction>>
t;
};
std::tuple<Directive, std::optional<DoConstruct>> t;
Expand Down
69 changes: 38 additions & 31 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2862,14 +2862,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (nestedLoops > 1)
n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);

const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
std::get<1>(dir.t);
const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &block =
std::get<2>(dir.t);
const std::optional<Fortran::parser::ScalarIntExpr> &stream =
std::get<3>(dir.t);
const auto &launchConfig = std::get<std::optional<
Fortran::parser::CUFKernelDoConstruct::LaunchConfiguration>>(dir.t);

const std::list<Fortran::parser::CUFReduction> &cufreds =
std::get<4>(dir.t);
std::get<2>(dir.t);

llvm::SmallVector<mlir::Value> reduceOperands;
llvm::SmallVector<mlir::Attribute> reduceAttrs;
Expand Down Expand Up @@ -2913,35 +2910,45 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->createIntegerConstant(loc, builder->getI32Type(), 0);

llvm::SmallVector<mlir::Value> gridValues;
if (!isOnlyStars(grid)) {
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
grid) {
if (expr.v) {
gridValues.push_back(fir::getBase(
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
} else {
gridValues.push_back(zero);
llvm::SmallVector<mlir::Value> blockValues;
mlir::Value streamValue;

if (launchConfig) {
const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
std::get<0>(launchConfig->t);
const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr>
&block = std::get<1>(launchConfig->t);
const std::optional<Fortran::parser::ScalarIntExpr> &stream =
std::get<2>(launchConfig->t);
if (!isOnlyStars(grid)) {
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
grid) {
if (expr.v) {
gridValues.push_back(fir::getBase(
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
} else {
gridValues.push_back(zero);
}
}
}
}
llvm::SmallVector<mlir::Value> blockValues;
if (!isOnlyStars(block)) {
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
block) {
if (expr.v) {
blockValues.push_back(fir::getBase(
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
} else {
blockValues.push_back(zero);
if (!isOnlyStars(block)) {
for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
block) {
if (expr.v) {
blockValues.push_back(fir::getBase(
genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
} else {
blockValues.push_back(zero);
}
}
}

if (stream)
streamValue = builder->createConvert(
loc, builder->getI32Type(),
fir::getBase(
genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));
}
mlir::Value streamValue;
if (stream)
streamValue = builder->createConvert(
loc, builder->getI32Type(),
fir::getBase(
genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));

const auto &outerDoConstruct =
std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);
Expand Down
10 changes: 7 additions & 3 deletions flang/lib/Parser/executable-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,11 +563,15 @@ TYPE_PARSER(("REDUCTION"_tok || "REDUCE"_tok) >>
parenthesized(construct<CUFReduction>(Parser<CUFReduction::Operator>{},
":" >> nonemptyList(scalar(variable)))))

TYPE_PARSER("<<<" >>
construct<CUFKernelDoConstruct::LaunchConfiguration>(gridOrBlock,
"," >> gridOrBlock,
maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>"))

TYPE_PARSER(sourced(beginDirective >> "$CUF KERNEL DO"_tok >>
construct<CUFKernelDoConstruct::Directive>(
maybe(parenthesized(scalarIntConstantExpr)), "<<<" >> gridOrBlock,
"," >> gridOrBlock,
maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>",
maybe(parenthesized(scalarIntConstantExpr)),
maybe(Parser<CUFKernelDoConstruct::LaunchConfiguration>{}),
many(Parser<CUFReduction>{}) / endDirective)))
TYPE_CONTEXT_PARSER("!$CUF KERNEL DO construct"_en_US,
extension<LanguageFeature::CUDA>(construct<CUFKernelDoConstruct>(
Expand Down
16 changes: 10 additions & 6 deletions flang/lib/Parser/unparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2932,11 +2932,9 @@ class UnparseVisitor {
Word("*");
}
}
void Unparse(const CUFKernelDoConstruct::Directive &x) {
Word("!$CUF KERNEL DO");
Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
void Unparse(const CUFKernelDoConstruct::LaunchConfiguration &x) {
Word(" <<<");
const auto &grid{std::get<1>(x.t)};
const auto &grid{std::get<0>(x.t)};
if (grid.empty()) {
Word("*");
} else if (grid.size() == 1) {
Expand All @@ -2945,18 +2943,24 @@ class UnparseVisitor {
Walk("(", grid, ",", ")");
}
Word(",");
const auto &block{std::get<2>(x.t)};
const auto &block{std::get<1>(x.t)};
if (block.empty()) {
Word("*");
} else if (block.size() == 1) {
Walk(block.front());
} else {
Walk("(", block, ",", ")");
}
if (const auto &stream{std::get<3>(x.t)}) {
if (const auto &stream{std::get<2>(x.t)}) {
Word(",STREAM="), Walk(*stream);
}
Word(">>>");
}
void Unparse(const CUFKernelDoConstruct::Directive &x) {
Word("!$CUF KERNEL DO");
Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
Walk(std::get<std::optional<CUFKernelDoConstruct::LaunchConfiguration>>(
x.t));
Walk(" ", std::get<std::list<CUFReduction>>(x.t), " ");
Word("\n");
}
Expand Down
3 changes: 3 additions & 0 deletions flang/test/Parser/cuf-sanity-common
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ module m
!$cuf kernel do <<<1, (2, 3), stream = 1>>>
do j = 1, 10
end do
!$cuf kernel do
do j = 1, 10
end do
!$cuf kernel do <<<*, *>>> reduce(+:x,y) reduce(*:z)
do j = 1, 10
x = x + a(j)
Expand Down

0 comments on commit 37143fe

Please sign in to comment.