[mlir][OpenMP] inscan reduction modifier and scan op mlir support #114737
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-flang-openmp
Author: Anchu Rajendran S (anchuraj)
Changes: The scan directive allows specifying scan reductions within a worksharing loop, worksharing loop simd, or simd directive, which should have an InScan modifier associated with it. This change adds the MLIR support for the same.
Patch is 26.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114737.diff
6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 886554f66afffc..b45d89463639c5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -283,6 +283,34 @@ class OpenMP_DoacrossClauseSkip<
def OpenMP_DoacrossClause : OpenMP_DoacrossClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V5.2: [5.4.7] `exclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_ExclusiveClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ Variadic<AnyType>:$exclusive_vars
+ );
+
+ let optAssemblyFormat = [{
+ `exclusive` `(` $exclusive_vars `:` type($exclusive_vars) `)`
+ }];
+
+ let description = [{
+ The exclusive clause is used on a separating directive that separates a
+ structured block into two structured block sequences. If it
+ is specified, the input phase excludes the preceding structured block
+ sequence and instead includes the following structured block sequence,
+ while the scan phase includes the preceding structured block sequence.
+ }];
+}
+
+def OpenMP_ExclusiveClause : OpenMP_ExclusiveClauseSkip<>;
+
//===----------------------------------------------------------------------===//
// V5.2: [10.5.1] `filter` clause
//===----------------------------------------------------------------------===//
@@ -393,6 +421,34 @@ class OpenMP_HasDeviceAddrClauseSkip<
def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V5.2: [5.4.7] `inclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_InclusiveClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ Variadic<AnyType>:$inclusive_vars
+ );
+
+ let optAssemblyFormat = [{
+ `inclusive` `(` $inclusive_vars `:` type($inclusive_vars) `)`
+ }];
+
+ let description = [{
+ The inclusive clause is used on a separating directive that separates a
+ structured block into two structured block sequences. If it is specified,
+ the input phase includes the preceding structured block sequence and the
+ scan phase includes the following structured block sequence.
+ }];
+}
+
+def OpenMP_InclusiveClause : OpenMP_InclusiveClauseSkip<>;
+
+
//===----------------------------------------------------------------------===//
// V5.2: [15.1.2] `hint` clause
//===----------------------------------------------------------------------===//
@@ -983,6 +1039,7 @@ class OpenMP_ReductionClauseSkip<
];
let arguments = (ins
+ OptionalAttr<ReductionModifierAttr>:$reduction_mod,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<DenseBoolArrayAttr>:$reduction_byref,
OptionalAttr<SymbolRefArrayAttr>:$reduction_syms
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index b1a9e3330522b2..23086556bbb2f5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -178,6 +178,27 @@ def OrderModifier
def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
"order_mod">;
+//===----------------------------------------------------------------------===//
+// reduction_modifier enum.
+//===----------------------------------------------------------------------===//
+
+def ReductionModifierInScan : I32EnumAttrCase<"InScan", 0>;
+def ReductionModifierTask : I32EnumAttrCase<"Task", 1>;
+def ReductionModifierDefault : I32EnumAttrCase<"Default", 2>;
+
+def ReductionModifier : OpenMP_I32EnumAttr<
+ "ReductionModifier",
+ "reduction modifier", [
+ ReductionModifierInScan,
+ ReductionModifierTask,
+ ReductionModifierDefault
+ ]>;
+
+def ReductionModifierAttr : OpenMP_EnumAttr<ReductionModifier,
+ "reduction_modifier"> {
+ let assemblyFormat = "`(` $value `)`";
+}
+
//===----------------------------------------------------------------------===//
// sched_mod enum.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 626539cb7bde42..a03f18a816c39e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -170,7 +170,7 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -215,7 +215,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -274,7 +274,7 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -422,7 +422,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -476,7 +476,7 @@ def SimdOp : OpenMP_Op<"simd", traits = [
let assemblyFormat = clausesAssemblyFormat # [{
custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
- $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+ $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
$reduction_syms) attr-dict
}];
@@ -680,7 +680,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
custom<InReductionPrivateReductionRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $private_vars,
- type($private_vars), $private_syms, $reduction_vars,
+ type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
}];
@@ -1560,6 +1560,26 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
let hasVerifier = 1;
}
+def ScanOp : OpenMP_Op<"scan", [
+ AttrSizedOperandSegments, RecipeInterface, IsolatedFromAbove
+ ], clauses = [
+ OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
+ let summary = "scan directive";
+ let description = [{
+ The scan directive allows specifying a scan reduction. The scan directive
+ must be enclosed within a parent directive that also specifies a reduction
+ clause with the `InScan` modifier. The scan directive separates the code in
+ the region enclosed by the parent into an input phase and a scan phase.
+ }] # clausesDescription;
+
+ let builders = [
+ OpBuilder<(ins CArg<"const ScanOperands &">:$clauses)>
+ ];
+
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// 2.19.5.7 declare reduction Directive
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index aa241b91d758ca..233739e1d6d917 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -451,6 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
+ /* reduction_mod = */ nullptr,
/* reduction_vars = */ llvm::SmallVector<Value>{},
/* reduction_byref = */ DenseBoolArrayAttr{},
/* reduction_syms = */ ArrayAttr{});
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e1df647d6a3c71..0ad7fe2c2cf243 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -491,16 +491,27 @@ struct PrivateParseArgs {
SmallVectorImpl<Type> &types, ArrayAttr &syms)
: vars(vars), types(types), syms(syms) {}
};
+
+static ReductionModifierAttr nullReductionMod = nullptr;
struct ReductionParseArgs {
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
SmallVectorImpl<Type> &types;
DenseBoolArrayAttr &byref;
ArrayAttr &syms;
+ ReductionModifierAttr &reductionMod;
+ ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
+ SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
+ ArrayAttr &syms, ReductionModifierAttr &redMod)
+ : vars(vars), types(types), byref(byref), syms(syms),
+ reductionMod(redMod) {}
ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
ArrayAttr &syms)
- : vars(vars), types(types), byref(byref), syms(syms) {}
+ : vars(vars), types(types), byref(byref), syms(syms),
+ reductionMod(nullReductionMod) {}
};
+
+// Specifies the arguments needed for the `reduction` clause.
struct AllRegionParseArgs {
std::optional<ReductionParseArgs> inReductionArgs;
std::optional<MapParseArgs> mapArgs;
@@ -517,7 +528,8 @@ static ParseResult parseClauseWithRegionArgs(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types,
SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs,
- ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+ ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr,
+ ReductionModifierAttr &reductionMod = nullReductionMod) {
SmallVector<SymbolRefAttr> symbolVec;
SmallVector<bool> isByRefVec;
unsigned regionArgOffset = regionPrivateArgs.size();
@@ -525,6 +537,16 @@ static ParseResult parseClauseWithRegionArgs(
if (parser.parseLParen())
return failure();
+ StringRef enumStr;
+ if (succeeded(parser.parseOptionalKeyword("type"))) {
+ if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
+ parser.parseComma())
+ return failure();
+ std::optional<ReductionModifier> enumValue =
+ symbolizeReductionModifier(enumStr);
+ reductionMod = ReductionModifierAttr::get(parser.getContext(), *enumValue);
+ }
+
if (parser.parseCommaSeparatedList([&]() {
if (byref)
isByRefVec.push_back(
@@ -615,15 +637,14 @@ static ParseResult parseBlockArgClause(
if (succeeded(parser.parseOptionalKeyword(keyword))) {
if (!reductionArgs)
return failure();
-
if (failed(parseClauseWithRegionArgs(
parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
- &reductionArgs->syms, &reductionArgs->byref)))
+ &reductionArgs->syms, &reductionArgs->byref,
+ reductionArgs->reductionMod)))
return failure();
}
return success();
}
-
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion,
AllRegionParseArgs args) {
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
@@ -704,6 +725,7 @@ static ParseResult parseInReductionPrivateReductionRegion(
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+ ReductionModifierAttr &reductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
@@ -712,7 +734,7 @@ static ParseResult parseInReductionPrivateReductionRegion(
inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ reductionSyms, reductionMod);
return parseBlockArgRegion(parser, region, args);
}
@@ -729,13 +751,14 @@ static ParseResult parsePrivateReductionRegion(
OpAsmParser &parser, Region ®ion,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+ ReductionModifierAttr &reductionMod,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
AllRegionParseArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ reductionSyms, reductionMod);
return parseBlockArgRegion(parser, region, args);
}
@@ -784,9 +807,12 @@ struct ReductionPrintArgs {
TypeRange types;
DenseBoolArrayAttr byref;
ArrayAttr syms;
+ ReductionModifierAttr reductionMod;
ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
- ArrayAttr syms)
- : vars(vars), types(types), byref(byref), syms(syms) {}
+ ArrayAttr syms,
+ ReductionModifierAttr reductionMod = nullReductionMod)
+ : vars(vars), types(types), byref(byref), syms(syms),
+ reductionMod(reductionMod) {}
};
struct AllRegionPrintArgs {
std::optional<ReductionPrintArgs> inReductionArgs;
@@ -799,17 +825,21 @@ struct AllRegionPrintArgs {
};
} // namespace
-static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
- StringRef clauseName,
- ValueRange argsSubrange,
- ValueRange operands, TypeRange types,
- ArrayAttr symbols = nullptr,
- DenseBoolArrayAttr byref = nullptr) {
+static void printClauseWithRegionArgs(
+ OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
+ ValueRange argsSubrange, ValueRange operands, TypeRange types,
+ ArrayAttr symbols = nullptr, DenseBoolArrayAttr byref = nullptr,
+ ReductionModifierAttr reductionMod = nullptr) {
if (argsSubrange.empty())
return;
p << clauseName << "(";
+ if (reductionMod) {
+ p << "type: " << stringifyReductionModifier(reductionMod.getValue())
+ << ", ";
+ }
+
if (!symbols) {
llvm::SmallVector<Attribute> values(operands.size(), nullptr);
symbols = ArrayAttr::get(ctx, values);
@@ -859,7 +889,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
if (reductionArgs)
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
reductionArgs->vars, reductionArgs->types,
- reductionArgs->syms, reductionArgs->byref);
+ reductionArgs->syms, reductionArgs->byref,
+ reductionArgs->reductionMod);
}
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
@@ -916,14 +947,15 @@ static void printInReductionPrivateReductionRegion(
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
- ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes,
+ ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
+ ValueRange reductionVars, TypeRange reductionTypes,
DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ reductionSyms, reductionMod);
printBlockArgRegion(p, op, region, args);
}
@@ -937,13 +969,14 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
static void printPrivateReductionRegion(
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars,
- TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars,
+ TypeRange privateTypes, ArrayAttr privateSyms,
+ ReductionModifierAttr reductionMod, ValueRange reductionVars,
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
- reductionSyms);
+ reductionSyms, reductionMod);
printBlockArgRegion(p, op, region, args);
}
@@ -1700,7 +1733,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
- /*reduction_vars=*/ValueRange(),
+ /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
/*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
state.addAttributes(attributes);
}
@@ -1711,7 +1744,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.ifExpr, clauses.numThreads, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
- clauses.procBindKind, clauses.reductionVars,
+ clauses.procBindKind, clauses.reductionMod,
+ clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms));
}
@@ -1810,12 +1844,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
const TeamsOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms.
- TeamsOp::build(
- builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
- /*private_vars=*/{}, /*private_syms=*/nullptr, cla...
[truncated]
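Before the review discussion, here is a brief hedged sketch of how the new reduction modifier surfaces in the assembly format. It is inferred from the custom parser/printer changes above (which read and print a "type: <modifier>," prefix inside the reduction clause); the reduction symbol and operand names are invented for illustration and not copied from the patch's tests.

// Illustrative only; exact modifier spelling follows the ReductionModifier enum above.
omp.wsloop reduction(type: InScan, @add_i32 %sum -> %arg0 : !llvm.ptr) {
  ...
}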
LGTM
Thank you Anchu for working on this. I have some small implementation comments, but I also have some concerns about your proposed representation for omp.scan.
Other than that, I'd also suggest adding a not-yet-implemented error in the MLIR to LLVM IR translation for the new reduction modifier attributes, and updating mlir/test/Target/LLVMIR/openmp-todo.mlir to check for those and for the error that should already be triggered for omp.scan.
@@ -1560,6 +1560,26 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
  let hasVerifier = 1;
}

def ScanOp : OpenMP_Op<"scan", [
If I understand the behavior of the corresponding OpenMP directive correctly, it basically splits a loop body at the point it appears. Everything before it is the input phase, and everything after it is the scan phase. In that case, shouldn't the corresponding MLIR operation reflect this by defining either one or two regions?
// 2-region alternative.
omp.loop_nest ... {
  omp.scan ...
  input {
    // Everything before '!$omp scan ...'
    ...
    omp.terminator
  } scan {
    // Everything after '!$omp scan ...'
    ...
    omp.terminator
  }
  // Nothing else other than omp.yield allowed here.
  omp.yield
}

// 1-region alternative.
omp.loop_nest ... {
  // Everything before '!$omp scan ...'
  ...
  omp.scan ... {
    // Everything after '!$omp scan ...'
    ...
    omp.terminator
  }
  // Nothing else other than omp.yield allowed here.
  omp.yield
}
Great question. There are a few facts to be considered here. For simplicity, I will consider the body of the for loop as:
{
/* pre scan block */ b1;
scan
/* post scan block */ b2;
}
- As per the definition of scan, b1 or b2 can be the input or the scan phase, depending on whether the exclusive or inclusive clause is used with the directive.
- My first attempt was to define this structure during parsing; however, I failed to parse the scan directive as <structured sequence> scan <structured sequence> because it made the grammar ambiguous and I could not get it to parse successfully. I dug into clang and saw that the blocks (b1, b2) are not associated with the directive (scan is a stand-alone directive). It is the LLVM lowering that separates the two blocks: when lowered, the part of the for-loop body before scan is emitted as one basic block and the part after scan as another. On encountering scan, these blocks are treated as the input or scan phase based on the clauses scan carries. Since there are only two blocks, and since the scan directive always appears in loops with the inscan reduction modifier, handling it during lowering would be simple.
- When parsing does not separate it easily, representing it as an MLIR op would be very complicated to do in the frontend and would result in code that is not very reusable if we go with that representation.
These are why the current representation was chosen.
My understanding is that clang is able to base the implementation of scan on the position of the directive within the loop body (just like how it's defined at the source level by the spec) because the AST is not modified by moving statements around, possibly crossing the scan split point. That would also be the reason why you are able to parse it that way in Flang.
However, MLIR transformation passes are able to move operations within a region in a way that would potentially break scan. We must make sure the MLIR representation we choose prevents this from happening, and at the moment the best I can think of would be the 2-region alternative I illustrated above.
But I see the problem you point to about the naming of these regions. However, if instead of calling them input and scan, we call them pre and post, then we should be able to lower them to LLVM IR properly based on whether it's an inclusive or exclusive scan. That would mean something like this:
omp.wsloop reduction(inscan @reduction %0 -> %red0 : !llvm.ptr) {
  omp.loop_nest ... {
    omp.scan exclusive(%red0 : !llvm.ptr)
    pre {
      // Everything before '!$omp scan ...'
      ...
      omp.terminator
    } post {
      // Everything after '!$omp scan ...'
      ...
      omp.terminator
    }
    omp.yield
  }
}
The questions that I guess would remain for this approach are whether we can let common sub-expressions, constants, etc. be hoisted out of the pre and post regions, whether we should allow other non-terminator operations after omp.scan inside of the loop, and whether we need to make values defined in pre accessible to post by having an omp.yield terminator for the former and corresponding entry block arguments for the latter.
CC: @kiranchandramohan, @tblah.
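To make the last of those questions concrete, here is a hypothetical sketch (illustrative only, not syntax proposed verbatim in this thread) of how a value defined in the pre region could be made accessible to the post region through an omp.yield terminator and a matching entry block argument:

// Hypothetical extension of the pre/post proposal above.
omp.scan inclusive(%red0 : !llvm.ptr)
pre {
  // A value computed in the input phase that the scan phase also needs.
  %idx = arith.addi %i, %c1 : i32
  ...
  omp.yield(%idx : i32)    // forward %idx to the post region
} post {
^bb0(%idx_arg: i32):       // entry block argument receiving the forwarded value
  ...
  omp.terminator
}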
Overall I like @skatrak's proposal
The examples for inclusive scans here seem to indicate that we will need to allow yielding values from the pre region to the post region.
As for whether we should allow non-terminators after (or before) the omp.scan, the language used in the standard is "structured block sequence". I think that would mean that even if there are multiple structured blocks, all of those would still become logically part of the scan, and so all operations before the scan should be in the pre region and all after should be in the post region.
As for common subexpression hoisting, I think this is safe if an expression is hoisted all of the way out of the loop (because then it must not have any side effects or depend upon the loop iteration). I'm unsure about putting it before the scan inside of the loop. It would probably be easier to generate the LLVM IR if this was not allowed.
The 2-region option proposed by @skatrak would be easier to interface with the OpenMPIRBuilder during translation.
The 1-region option might be easier for the lowering in Flang (parse-tree to MLIR).
But just want to understand whether the no-region option will definitely lead to an issue.
"However, MLIR transformation passes are able to move operations within a region in a way that would potentially break scan."
Are these operations moving across the scan operation? Would these be only constants or something else? Assuming the scan operation is side-effecting, will operations cross it?
Generally, from a Flang point of view, we typically have loading and storing from memory/variables. Since all variables are from outside the loop region, these will be visible in both pre and post. This will change if we have performed load-store forwarding or mem2reg. Also, I wonder about the case where all the contents of the block are inside a block construct. In this case the alloca for j will not be visible in the post region and would need yielding.
program mn
  do i=1,10
    block
      integer :: j
      scan directive
      = j
    end block
  end do
end program
Thank you @tblah and @kiranchandramohan for sharing your thoughts. @anchuraj and I had a call about this the other day, and at the time the 2-region approach made the most sense. But if we can somehow ensure that no MLIR transformations are able to move operations across the scan
split point in a way that breaks semantics, it is my understanding that the translation to LLVM IR wouldn't be much more difficult than having explicit regions, and it would potentially make the representation and Flang lowering simpler in general.
With regards to both approaches introducing regions, they have the disadvantage that the current PFT to MLIR lowering system doesn't seem well-suited to lowering only parts of a block of code. So we thought that, if we followed that approach, we'd probably want to first produce an omp.scan with both regions empty, while lowering the rest of the loop normally. After that, we'd introduce another Flang-specific MLIR pass (similar to function filtering, map info finalization, etc.) that would sink operations before the omp.scan into its pre region and operations after it into its post region. If there were values defined in pre and used in post, that pass would be able to detect them and produce the corresponding omp.yield and post entry block arguments.
If a common subexpression doesn't depend on the loop index, I agree that this would result in no problems for any of the approaches, since it should be hoisted out of the loop body. Same thing with constants, for example. However, one case I had in mind, which I think could potentially result in MLIR optimizations moving operations across the split point, would be the i + 1 expression below:
subroutine foo(v1, v2, size)
  implicit none
  integer, intent(in) :: size, v1(size)
  integer, intent(out) :: v2(size)
  integer :: sum, i
  sum = 0
  !$omp parallel do reduction(+:sum)
  do i = 0, size - 1
    sum = sum + v1(i + 1)
    !$omp scan inclusive(sum)
    v2(i + 1) = sum
  end do
  !$omp end parallel do
end subroutine foo
At the moment that results in separate fir.loads for i and separate arith.constants for 1, so the only values that would be used both inside of pre and post would be the hlfir.declare for the sum reduction variable (which we may want to leave out of both regions or put in pre). But I can't see any reason why, if we don't split the loop body, MLIR passes wouldn't be able to realize there are two loads of the same unmodified variable and calculate i + 1 only once at the beginning. That kind of transformation doesn't seem like it would break the region-less approach, but it would be one case introducing the need for passing values from pre to post.
So I'm thinking that perhaps it would indeed be enough to mark omp.scan as reading from and writing to its inclusive/exclusive arguments. That should ensure no ops related to these values will be reordered across the split point, and that omp.scan won't itself be moved in a way that results in such a reordering. Maybe that's enough to ensure valid passes won't break this operation.
Thank you for your suggestions @tblah, @skatrak and @kiranchandramohan.
Looking at the standard, the following is a restriction on the scan directive:
Intra-iteration dependences from a statement in the structured block sequence that precede a scan directive to a statement in the structured block sequence that follows a scan directive must not exist, except for dependences for the list items specified in an inclusive or exclusive clause.
I believe the above restriction makes examples like the one using the block construct mentioned by @kiranchandramohan invalid.
From the example @skatrak used, we would need to transform the program in a way where yielding is required from the input phase to the scan phase. That transforms the program into a structure which is not compliant with the above OpenMP restriction (which need not hold at the MLIR transformation stage). I also believe the current LLVM lowering logic would have to change if this needs to be implemented. At a high level, the LLVM IR currently generated has the following structure: the for loop is split into two (one with the input phase and one with the scan phase) and a reduction loop is added in the middle. (In detail: a new buffer of size equal to the number of iterations is declared, and the update to the reduction variable in the input phase of each iteration is copied to the buffer. After the input-phase for loop is executed, a reduction for loop is executed, after which the content at each index of the buffer corresponds to the result of the scan reduction at iteration number = index. Then the scan-phase for loop is executed, which first updates the reduction variable to the value from the buffer corresponding to the index and then executes the scan phase.)
// declare a buffer
buffer[num_iterations]

// First for loop
for (int i = 0; ...) {
  input phase code;
  buffer[i] = red_var;
}

// Scan reduction for loop
for (...) {
  do scan reduction over buffer
}

// Second for loop
for (int i = 0; ...) {
  red = InclusiveScan ? buffer[i] : buffer[i-1];
  scan phase;
}
With yielding, we would need to make the updates from the first for loop available in the second for loop, which might require additional transformations.
Circling back to @kiranchandramohan's question: is it possible to prevent such optimizations inside a WsloopOp with the inscan modifier, considering scan as side-effecting? Please let me know your thoughts, @skatrak and @tblah.
Yeah I guess giving the scan memory effects could work. I can't think of a reason off the top of my head why @skatrak's suggestion to use the memory effects only on the exclusive/inclusive arguments wouldn't work.
But just to be safe I would probably do the initial implementation saying it has some broadly specified memory effects. If we follow the recommendations for modelling syscalls given here, I think that should ensure that any observable effects before the scan must stay before (and vice versa).
So it would be modeled a bit like
input phase code;
write(1, "I am a scan statement.\n", 23);
output phase code;
Thank you. I discussed this offline with @skatrak. Based on the discussions and inputs, I am proceeding with implementing the no-region approach, with ScanOp having memory side effects. Please let me know if there are any further concerns.
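For readers following the thread, a rough sketch of what this region-less form looks like inside the loop (illustrative only; operand and symbol names are invented, and the reduction-modifier keyword follows the enum added in the patch rather than any test output):

omp.wsloop reduction(type: InScan, @add_i32 %sum -> %red0 : !llvm.ptr) {
  omp.loop_nest (%i) : i32 = (%lb) to (%ub) step (%step) {
    // Input phase: everything before the separating omp.scan.
    ...
    // Stand-alone separator; its memory effects on %red0 keep it from being
    // reordered relative to reads/writes of the reduction variable.
    omp.scan inclusive(%red0 : !llvm.ptr)
    // Scan phase: everything after the separator.
    ...
    omp.yield
  }
}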
Updated the operation with MemWrite effects.
@skatrak, I have removed the check to compare
Thank you Anchu for this work. I think this representation should work; my remaining comments are mostly just about properly communicating currently unsupported cases to users and some minor nits.
Thank you, this LGTM. Just minor comments, but no need for a second review by me before merging.
      return failure();
  }
  return success();
}
Remember to add back this line before merging.
Thank you @skatrak for reviewing my PR! I have addressed the review comments.
The scan directive allows specifying scan reductions within a worksharing loop, worksharing loop simd, or simd directive, which should have an InScan modifier associated with it. This change adds the MLIR support for the same.
Related PR: Parsing and Semantic Support for scan