Skip to content

[flang][OpenMP] Implement HAS_DEVICE_ADDR clause #128568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address review comments
  • Loading branch information
kparzysz committed Feb 28, 2025
commit f62de23dd9b4b2ce2eba61715f68dfe240375450
2 changes: 1 addition & 1 deletion flang/include/flang/Support/OpenMP-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ struct EntryBlockArgsEntry {
/// Structure holding the information needed to create and bind entry block
/// arguments associated to all clauses that can define them.
struct EntryBlockArgs {
llvm::ArrayRef<mlir::Value> hostEvalVars;
EntryBlockArgsEntry hasDeviceAddr;
llvm::ArrayRef<mlir::Value> hostEvalVars;
EntryBlockArgsEntry inReduction;
EntryBlockArgsEntry map;
EntryBlockArgsEntry priv;
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2291,19 +2291,19 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,

auto targetOp = firOpBuilder.create<mlir::omp::TargetOp>(loc, clauseOps);

llvm::SmallVector<mlir::Value> mapBaseValues, hasDeviceAddrBaseValues;
extractMappedBaseValues(clauseOps.mapVars, mapBaseValues);
llvm::SmallVector<mlir::Value> hasDeviceAddrBaseValues, mapBaseValues;
extractMappedBaseValues(clauseOps.hasDeviceAddrVars, hasDeviceAddrBaseValues);
extractMappedBaseValues(clauseOps.mapVars, mapBaseValues);

EntryBlockArgs args;
args.hasDeviceAddr.syms = hasDeviceAddrSyms;
args.hasDeviceAddr.vars = hasDeviceAddrBaseValues;
args.hostEvalVars = clauseOps.hostEvalVars;
// TODO: Add in_reduction syms and vars.
args.map.syms = mapSyms;
args.map.vars = mapBaseValues;
args.priv.syms = dsp.getDelayedPrivSymbols();
args.priv.vars = clauseOps.privateVars;
args.hasDeviceAddr.syms = hasDeviceAddrSyms;
args.hasDeviceAddr.vars = hasDeviceAddrBaseValues;

genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc,
queue, item, dsp);
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ class MapInfoFinalizationPass
/// Check if the mapOp is present in the HasDeviceAddr clause on
/// the userOp. Only applies to TargetOp.
bool isHasDeviceAddr(mlir::omp::MapInfoOp mapOp, mlir::Operation *userOp) {
assert(userOp && "Expecting non-null argument");
if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(userOp)) {
for (mlir::Value hda : targetOp.getHasDeviceAddrVars()) {
if (hda.getDefiningOp() == mapOp)
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1326,7 +1326,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
}] # clausesExtraClassDeclaration;

let assemblyFormat = clausesAssemblyFormat # [{
custom<HasDeviceAddrHostEvalInReductionMapPrivateRegion>(
custom<TargetOpRegion>(
$region, $has_device_addr_vars, type($has_device_addr_vars),
$host_eval_vars, type($host_eval_vars), $in_reduction_vars,
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
Expand Down
45 changes: 15 additions & 30 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,9 +808,9 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
return parser.parseRegion(region, entryBlockArgs);
}

// See custom<HasDeviceAddrHostEvalInReductionMapPrivateRegion> in the
// definition of TargetOp.
static ParseResult parseHasDeviceAddrHostEvalInReductionMapPrivateRegion(
// These parseXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
static ParseResult parseTargetOpRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hasDeviceAddrVars,
SmallVectorImpl<Type> &hasDeviceAddrTypes,
Expand All @@ -835,7 +835,6 @@ static ParseResult parseHasDeviceAddrHostEvalInReductionMapPrivateRegion(
return parseBlockArgRegion(parser, region, args);
}

// See custom<InReductionPrivateRegion> in the definition of TaskOp.
static ParseResult parseInReductionPrivateRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
Expand All @@ -850,8 +849,6 @@ static ParseResult parseInReductionPrivateRegion(
return parseBlockArgRegion(parser, region, args);
}

// See custom<InReductionPrivateReductionRegion> in the definition of
// TaskloopOp.
static ParseResult parseInReductionPrivateReductionRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
Expand All @@ -872,7 +869,6 @@ static ParseResult parseInReductionPrivateReductionRegion(
return parseBlockArgRegion(parser, region, args);
}

// See custom<PrivateRegion> in the definition of SingleOp.
static ParseResult parsePrivateRegion(
OpAsmParser &parser, Region &region,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
Expand All @@ -882,7 +878,6 @@ static ParseResult parsePrivateRegion(
return parseBlockArgRegion(parser, region, args);
}

// See custom<PrivateReductionRegion> in the definition of LoopOp.
static ParseResult parsePrivateReductionRegion(
OpAsmParser &parser, Region &region,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
Expand All @@ -898,7 +893,6 @@ static ParseResult parsePrivateReductionRegion(
return parseBlockArgRegion(parser, region, args);
}

// See custom<TaskReductionRegion> in the definition of TaskgroupOp.
static ParseResult parseTaskReductionRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
Expand All @@ -910,8 +904,6 @@ static ParseResult parseTaskReductionRegion(
return parseBlockArgRegion(parser, region, args);
}

// See custom<UseDeviceAddrUseDevicePtrRegion> in the definition of
// TargetDataOp.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
Expand Down Expand Up @@ -1073,17 +1065,18 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
p.printRegion(region, /*printEntryBlockArgs=*/false);
}

// See custom<HasDeviceAddrHostEvalInReductionMapPrivateRegion> in the
// definition of TargetOp.
static void printHasDeviceAddrHostEvalInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region &region,
ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
ValueRange hostEvalVars, TypeRange hostEvalTypes,
ValueRange inReductionVars, TypeRange inReductionTypes,
DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
TypeRange privateTypes, ArrayAttr privateSyms,
DenseI64ArrayAttr privateMaps) {
// These parseXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
static void
printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
ValueRange hostEvalVars, TypeRange hostEvalTypes,
ValueRange inReductionVars, TypeRange inReductionTypes,
DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange mapVars,
TypeRange mapTypes, ValueRange privateVars,
TypeRange privateTypes, ArrayAttr privateSyms,
DenseI64ArrayAttr privateMaps) {
AllRegionPrintArgs args;
args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
Expand All @@ -1094,7 +1087,6 @@ static void printHasDeviceAddrHostEvalInReductionMapPrivateRegion(
printBlockArgRegion(p, op, region, args);
}

// See custom<InReductionPrivateRegion> in the definition of TaskOp.
static void printInReductionPrivateRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
Expand All @@ -1108,8 +1100,6 @@ static void printInReductionPrivateRegion(
printBlockArgRegion(p, op, region, args);
}

// See custom<InReductionPrivateReductionRegion> in the definition of
// TaskloopOp.
static void printInReductionPrivateReductionRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
Expand All @@ -1127,7 +1117,6 @@ static void printInReductionPrivateReductionRegion(
printBlockArgRegion(p, op, region, args);
}

// See custom<PrivateRegion> in the definition of SingleOp.
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms) {
Expand All @@ -1137,7 +1126,6 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
printBlockArgRegion(p, op, region, args);
}

// See custom<PrivateReductionRegion> in the definition of LoopOp.
static void printPrivateReductionRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
TypeRange privateTypes, ArrayAttr privateSyms,
Expand All @@ -1152,7 +1140,6 @@ static void printPrivateReductionRegion(
printBlockArgRegion(p, op, region, args);
}

// See custom<TaskReductionRegion> in the definition of TaskgroupOp.
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
Region &region,
ValueRange taskReductionVars,
Expand All @@ -1165,8 +1152,6 @@ static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
printBlockArgRegion(p, op, region, args);
}

// See custom<UseDeviceAddrUseDevicePtrRegion> in the definition of
// TargetDataOp.
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
Region &region,
ValueRange useDeviceAddrVars,
Expand Down
Loading