diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index e048015298461..3109e1e1758d7 100644 --- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -174,6 +174,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" @@ -190,6 +191,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include #include @@ -198,6 +200,8 @@ using namespace llvm; using namespace llvm::PatternMatch; +#define DEBUG_TYPE "separate-offset-gep" + static cl::opt DisableSeparateConstOffsetFromGEP( "disable-separate-const-offset-from-gep", cl::init(false), cl::desc("Do not separate the constant offset from a GEP instruction"), @@ -486,6 +490,42 @@ class SeparateConstOffsetFromGEP { DenseMap> DominatingSubs; }; +/// A helper class that aims to convert xor operations into or operations when +/// their operands are disjoint and the result is used in a GEP's index. This +/// can then enable further GEP optimizations by effectively turning BaseVal | +/// Const into BaseVal + Const when they are disjoint, which +/// SeparateConstOffsetFromGEP can then process. This is a common pattern that +/// sets up a grid of memory accesses across a wave where each thread acesses +/// data at various offsets. +class XorToOrDisjointTransformer { +public: + XorToOrDisjointTransformer(Function &F, DominatorTree &DT, + const DataLayout &DL) + : F(F), DT(DT), DL(DL) {} + + bool run(); + +private: + Function &F; + DominatorTree &DT; + const DataLayout &DL; + /// Maps a common operand to all Xor instructions + using XorOpList = SmallVector, 8>; + using XorBaseValInst = DenseMap; + XorBaseValInst XorGroups; + + /// Checks if the given value has at least one GetElementPtr user + static bool hasGEPUser(const Value *V); + + /// Helper function to check if BaseXor dominates all XORs in the group + bool dominatesAllXors(BinaryOperator *BaseXor, const XorOpList &XorsInGroup); + + /// Processes a group of XOR instructions that share the same non-constant + /// base operand. Returns true if this group's processing modified the + /// function. + bool processXorGroup(Instruction *OriginalBaseInst, XorOpList &XorsInGroup); +}; + } // end anonymous namespace char SeparateConstOffsetFromGEPLegacyPass::ID = 0; @@ -1162,6 +1202,154 @@ bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) { return true; } +// Helper function to check if an instruction has at least one GEP user +bool XorToOrDisjointTransformer::hasGEPUser(const Value *V) { + return llvm::any_of(V->users(), [](const User *U) { + return isa(U); + }); +} + +bool XorToOrDisjointTransformer::dominatesAllXors( + BinaryOperator *BaseXor, const XorOpList &XorsInGroup) { + return llvm::all_of(XorsInGroup, [&](const auto &XorEntry) { + BinaryOperator *XorInst = XorEntry.first; + // Do not evaluate the BaseXor, otherwise we end up cloning it. + return XorInst == BaseXor || DT.dominates(BaseXor, XorInst); + }); +} + +bool XorToOrDisjointTransformer::processXorGroup(Instruction *OriginalBaseInst, + XorOpList &XorsInGroup) { + bool Changed = false; + if (XorsInGroup.size() <= 1) + return false; + + // Sort XorsInGroup by the constant offset value in increasing order. + llvm::sort(XorsInGroup, [](const auto &A, const auto &B) { + return A.second.slt(B.second); + }); + + // Dominance check + // The "base" XOR for dominance purposes is the one with the smallest + // constant. + BinaryOperator *XorWithSmallConst = XorsInGroup[0].first; + + if (!dominatesAllXors(XorWithSmallConst, XorsInGroup)) { + LLVM_DEBUG(dbgs() << DEBUG_TYPE + << ": Cloning and inserting XOR with smallest constant (" + << *XorWithSmallConst + << ") as it does not dominate all other XORs" + << " in function " << F.getName() << "\n"); + + BinaryOperator *ClonedXor = + cast(XorWithSmallConst->clone()); + ClonedXor->setName(XorWithSmallConst->getName() + ".dom_clone"); + ClonedXor->insertAfter(OriginalBaseInst); + LLVM_DEBUG(dbgs() << " Cloned Inst: " << *ClonedXor << "\n"); + Changed = true; + XorWithSmallConst = ClonedXor; + } + + SmallVector InstructionsToErase; + const APInt SmallestConst = + cast(XorWithSmallConst->getOperand(1))->getValue(); + + // Main transformation loop: Iterate over the original XORs in the sorted + // group. + for (const auto &XorEntry : XorsInGroup) { + BinaryOperator *XorInst = XorEntry.first; // Original XOR instruction + const APInt ConstOffsetVal = XorEntry.second; + + // Do not process the one with smallest constant as it is the base. + if (XorInst == XorWithSmallConst) + continue; + + // Disjointness Check 1 + APInt NewConstVal = ConstOffsetVal - SmallestConst; + if ((NewConstVal & SmallestConst) != 0) { + LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Cannot transform XOR in function " + << F.getName() << ":\n" + << " New Const: " << NewConstVal + << " Smallest Const: " << SmallestConst + << " are not disjoint \n"); + continue; + } + + // Disjointness Check 2 + if (MaskedValueIsZero(XorWithSmallConst, NewConstVal, SimplifyQuery(DL), + 0)) { + LLVM_DEBUG(dbgs() << DEBUG_TYPE + << ": Transforming XOR to OR (disjoint) in function " + << F.getName() << ":\n" + << " Xor: " << *XorInst << "\n" + << " Base Val: " << *XorWithSmallConst << "\n" + << " New Const: " << NewConstVal << "\n"); + + auto *NewOrInst = BinaryOperator::CreateDisjointOr( + XorWithSmallConst, + ConstantInt::get(OriginalBaseInst->getType(), NewConstVal), + XorInst->getName() + ".or_disjoint", XorInst->getIterator()); + + NewOrInst->copyMetadata(*XorInst); + XorInst->replaceAllUsesWith(NewOrInst); + LLVM_DEBUG(dbgs() << " New Inst: " << *NewOrInst << "\n"); + InstructionsToErase.push_back(XorInst); // Mark original XOR for deletion + + Changed = true; + } else { + LLVM_DEBUG( + dbgs() << DEBUG_TYPE + << ": Cannot transform XOR (not proven disjoint) in function " + << F.getName() << ":\n" + << " Xor: " << *XorInst << "\n" + << " Base Val: " << *XorWithSmallConst << "\n" + << " New Const: " << NewConstVal << "\n"); + } + } + + for (Instruction *I : InstructionsToErase) + I->eraseFromParent(); + + return Changed; +} + +// Try to transform XOR(A, B+C) in to XOR(A,C) + B where XOR(A,C) becomes +// the base for memory operations. This transformation is true under the +// following conditions +// Check 1 - B and C are disjoint. +// Check 2 - XOR(A,C) and B are disjoint. +// +// This transformation is beneficial particularly for GEPs because: +// 1. OR operations often map better to addressing modes than XOR +// 2. Disjoint OR operations preserve the semantics of the original XOR +// 3. This can enable further optimizations in the GEP offset folding pipeline +bool XorToOrDisjointTransformer::run() { + bool Changed = false; + + // Collect all candidate XORs + for (Instruction &I : instructions(F)) { + Instruction *Op0 = nullptr; + ConstantInt *C1 = nullptr; + BinaryOperator *MatchedXorOp = nullptr; + + // Attempt to match the instruction 'I' as XOR operation. + if (match(&I, m_CombineAnd(m_Xor(m_Instruction(Op0), m_ConstantInt(C1)), + m_BinOp(MatchedXorOp))) && + hasGEPUser(MatchedXorOp)) + XorGroups[Op0].emplace_back(MatchedXorOp, C1->getValue()); + } + + if (XorGroups.empty()) + return false; + + // Process each group of XORs + for (auto &[OriginalBaseInst, XorsInGroup] : XorGroups) + if (processXorGroup(OriginalBaseInst, XorsInGroup)) + Changed = true; + + return Changed; +} + bool SeparateConstOffsetFromGEPLegacyPass::runOnFunction(Function &F) { if (skipFunction(F)) return false; @@ -1181,6 +1369,11 @@ bool SeparateConstOffsetFromGEP::run(Function &F) { DL = &F.getDataLayout(); bool Changed = false; + + // Decompose xor in to "or disjoint" if possible. + XorToOrDisjointTransformer XorTransformer(F, *DT, *DL); + Changed |= XorTransformer.run(); + for (BasicBlock &B : F) { if (!DT->isReachableFromEntry(&B)) continue; diff --git a/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-to-or-disjoint.ll b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-to-or-disjoint.ll new file mode 100644 index 0000000000000..825227292fe14 --- /dev/null +++ b/llvm/test/Transforms/SeparateConstOffsetFromGEP/AMDGPU/xor-to-or-disjoint.ll @@ -0,0 +1,204 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=separate-const-offset-from-gep \ +; RUN: -S < %s | FileCheck %s + + +; Test a simple case of xor to or disjoint transformation +define half @test_basic_transformation(ptr %ptr, i64 %input) { +; CHECK-LABEL: define half @test_basic_transformation( +; CHECK-SAME: ptr [[PTR:%.*]], i64 [[INPUT:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[BASE:%.*]] = and i64 [[INPUT]], -8192 +; CHECK-NEXT: [[ADDR1:%.*]] = xor i64 [[BASE]], 32 +; CHECK-NEXT: [[ADDR2_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 2048 +; CHECK-NEXT: [[ADDR3_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 4096 +; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1]] +; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR2_OR_DISJOINT]] +; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR3_OR_DISJOINT]] +; CHECK-NEXT: [[VAL1:%.*]] = load half, ptr [[GEP1]], align 2 +; CHECK-NEXT: [[VAL2:%.*]] = load half, ptr [[GEP2]], align 2 +; CHECK-NEXT: [[VAL3:%.*]] = load half, ptr [[GEP3]], align 2 +; CHECK-NEXT: [[VAL1_F:%.*]] = fpext half [[VAL1]] to float +; CHECK-NEXT: [[VAL2_F:%.*]] = fpext half [[VAL2]] to float +; CHECK-NEXT: [[VAL3_F:%.*]] = fpext half [[VAL3]] to float +; CHECK-NEXT: [[SUM1_F:%.*]] = fadd float [[VAL1_F]], [[VAL2_F]] +; CHECK-NEXT: [[SUM_TOTAL_F:%.*]] = fadd float [[SUM1_F]], [[VAL3_F]] +; CHECK-NEXT: [[RESULT_H:%.*]] = fptrunc float [[SUM_TOTAL_F]] to half +; CHECK-NEXT: ret half [[RESULT_H]] +; +entry: + %base = and i64 %input, -8192 ; Clear low bits + %addr1 = xor i64 %base, 32 + %addr2 = xor i64 %base, 2080 + %addr3 = xor i64 %base, 4128 + %gep1 = getelementptr i8, ptr %ptr, i64 %addr1 + %gep2 = getelementptr i8, ptr %ptr, i64 %addr2 + %gep3 = getelementptr i8, ptr %ptr, i64 %addr3 + %val1 = load half, ptr %gep1 + %val2 = load half, ptr %gep2 + %val3 = load half, ptr %gep3 + %val1.f = fpext half %val1 to float + %val2.f = fpext half %val2 to float + %val3.f = fpext half %val3 to float + %sum1.f = fadd float %val1.f, %val2.f + %sum_total.f = fadd float %sum1.f, %val3.f + %result.h = fptrunc float %sum_total.f to half + ret half %result.h +} + + +; Test the decreasing order of offset xor to or disjoint transformation +define half @test_descending_offset_transformation(ptr %ptr, i64 %input) { +; CHECK-LABEL: define half @test_descending_offset_transformation( +; CHECK-SAME: ptr [[PTR:%.*]], i64 [[INPUT:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[BASE:%.*]] = and i64 [[INPUT]], -8192 +; CHECK-NEXT: [[ADDR3_DOM_CLONE:%.*]] = xor i64 [[BASE]], 32 +; CHECK-NEXT: [[ADDR1_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR3_DOM_CLONE]], 4096 +; CHECK-NEXT: [[ADDR2_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR3_DOM_CLONE]], 2048 +; CHECK-NEXT: [[ADDR3_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR3_DOM_CLONE]], 0 +; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1_OR_DISJOINT]] +; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR2_OR_DISJOINT]] +; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR3_OR_DISJOINT]] +; CHECK-NEXT: [[VAL1:%.*]] = load half, ptr [[GEP1]], align 2 +; CHECK-NEXT: [[VAL2:%.*]] = load half, ptr [[GEP2]], align 2 +; CHECK-NEXT: [[VAL3:%.*]] = load half, ptr [[GEP3]], align 2 +; CHECK-NEXT: [[VAL1_F:%.*]] = fpext half [[VAL1]] to float +; CHECK-NEXT: [[VAL2_F:%.*]] = fpext half [[VAL2]] to float +; CHECK-NEXT: [[VAL3_F:%.*]] = fpext half [[VAL3]] to float +; CHECK-NEXT: [[SUM1_F:%.*]] = fadd float [[VAL1_F]], [[VAL2_F]] +; CHECK-NEXT: [[SUM_TOTAL_F:%.*]] = fadd float [[SUM1_F]], [[VAL3_F]] +; CHECK-NEXT: [[RESULT_H:%.*]] = fptrunc float [[SUM_TOTAL_F]] to half +; CHECK-NEXT: ret half [[RESULT_H]] +; +entry: + %base = and i64 %input, -8192 ; Clear low bits + %addr1 = xor i64 %base, 4128 + %addr2 = xor i64 %base, 2080 + %addr3 = xor i64 %base, 32 + %gep1 = getelementptr i8, ptr %ptr, i64 %addr1 + %gep2 = getelementptr i8, ptr %ptr, i64 %addr2 + %gep3 = getelementptr i8, ptr %ptr, i64 %addr3 + %val1 = load half, ptr %gep1 + %val2 = load half, ptr %gep2 + %val3 = load half, ptr %gep3 + %val1.f = fpext half %val1 to float + %val2.f = fpext half %val2 to float + %val3.f = fpext half %val3 to float + %sum1.f = fadd float %val1.f, %val2.f + %sum_total.f = fadd float %sum1.f, %val3.f + %result.h = fptrunc float %sum_total.f to half + ret half %result.h +} + + +; Test that %addr2 is not transformed to or disjoint. +define half @test_no_transfomation(ptr %ptr, i64 %input) { +; CHECK-LABEL: define half @test_no_transfomation( +; CHECK-SAME: ptr [[PTR:%.*]], i64 [[INPUT:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[BASE:%.*]] = and i64 [[INPUT]], -8192 +; CHECK-NEXT: [[ADDR1:%.*]] = xor i64 [[BASE]], 32 +; CHECK-NEXT: [[ADDR2:%.*]] = xor i64 [[BASE]], 64 +; CHECK-NEXT: [[ADDR3_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 2048 +; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1]] +; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR2]] +; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR3_OR_DISJOINT]] +; CHECK-NEXT: [[VAL1:%.*]] = load half, ptr [[GEP1]], align 2 +; CHECK-NEXT: [[VAL2:%.*]] = load half, ptr [[GEP2]], align 2 +; CHECK-NEXT: [[VAL3:%.*]] = load half, ptr [[GEP3]], align 2 +; CHECK-NEXT: [[VAL1_F:%.*]] = fpext half [[VAL1]] to float +; CHECK-NEXT: [[VAL2_F:%.*]] = fpext half [[VAL2]] to float +; CHECK-NEXT: [[VAL3_F:%.*]] = fpext half [[VAL3]] to float +; CHECK-NEXT: [[SUM1_F:%.*]] = fadd float [[VAL1_F]], [[VAL2_F]] +; CHECK-NEXT: [[SUM_TOTAL_F:%.*]] = fadd float [[SUM1_F]], [[VAL3_F]] +; CHECK-NEXT: [[RESULT_H:%.*]] = fptrunc float [[SUM_TOTAL_F]] to half +; CHECK-NEXT: ret half [[RESULT_H]] +; +entry: + %base = and i64 %input, -8192 ; Clear low bits + %addr1 = xor i64 %base, 32 + %addr2 = xor i64 %base, 64 ; Should not be transformed + %addr3 = xor i64 %base, 2080 + %gep1 = getelementptr i8, ptr %ptr, i64 %addr1 + %gep2 = getelementptr i8, ptr %ptr, i64 %addr2 + %gep3 = getelementptr i8, ptr %ptr, i64 %addr3 + %val1 = load half, ptr %gep1 + %val2 = load half, ptr %gep2 + %val3 = load half, ptr %gep3 + %val1.f = fpext half %val1 to float + %val2.f = fpext half %val2 to float + %val3.f = fpext half %val3 to float + %sum1.f = fadd float %val1.f, %val2.f + %sum_total.f = fadd float %sum1.f, %val3.f + %result.h = fptrunc float %sum_total.f to half + ret half %result.h +} + + +; Test case with xor instructions in different basic blocks +define half @test_dom_tree(ptr %ptr, i64 %input, i1 %cond) { +; CHECK-LABEL: define half @test_dom_tree( +; CHECK-SAME: ptr [[PTR:%.*]], i64 [[INPUT:%.*]], i1 [[COND:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[BASE:%.*]] = and i64 [[INPUT]], -8192 +; CHECK-NEXT: [[ADDR1:%.*]] = xor i64 [[BASE]], 16 +; CHECK-NEXT: [[GEP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1]] +; CHECK-NEXT: [[VAL1:%.*]] = load half, ptr [[GEP1]], align 2 +; CHECK-NEXT: br i1 [[COND]], label %[[THEN:.*]], label %[[ELSE:.*]] +; CHECK: [[THEN]]: +; CHECK-NEXT: [[ADDR2_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 32 +; CHECK-NEXT: [[GEP2:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR2_OR_DISJOINT]] +; CHECK-NEXT: [[VAL2:%.*]] = load half, ptr [[GEP2]], align 2 +; CHECK-NEXT: br label %[[MERGE:.*]] +; CHECK: [[ELSE]]: +; CHECK-NEXT: [[ADDR3_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 96 +; CHECK-NEXT: [[GEP3:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR3_OR_DISJOINT]] +; CHECK-NEXT: [[VAL3:%.*]] = load half, ptr [[GEP3]], align 2 +; CHECK-NEXT: br label %[[MERGE]] +; CHECK: [[MERGE]]: +; CHECK-NEXT: [[VAL_FROM_BRANCH:%.*]] = phi half [ [[VAL2]], %[[THEN]] ], [ [[VAL3]], %[[ELSE]] ] +; CHECK-NEXT: [[ADDR4_OR_DISJOINT:%.*]] = or disjoint i64 [[ADDR1]], 224 +; CHECK-NEXT: [[GEP4:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR4_OR_DISJOINT]] +; CHECK-NEXT: [[VAL4:%.*]] = load half, ptr [[GEP4]], align 2 +; CHECK-NEXT: [[VAL1_F:%.*]] = fpext half [[VAL1]] to float +; CHECK-NEXT: [[VAL_FROM_BRANCH_F:%.*]] = fpext half [[VAL_FROM_BRANCH]] to float +; CHECK-NEXT: [[VAL4_F:%.*]] = fpext half [[VAL4]] to float +; CHECK-NEXT: [[SUM_INTERMEDIATE_F:%.*]] = fadd float [[VAL1_F]], [[VAL_FROM_BRANCH_F]] +; CHECK-NEXT: [[FINAL_SUM_F:%.*]] = fadd float [[SUM_INTERMEDIATE_F]], [[VAL4_F]] +; CHECK-NEXT: [[RESULT_H:%.*]] = fptrunc float [[FINAL_SUM_F]] to half +; CHECK-NEXT: ret half [[RESULT_H]] +; +entry: + %base = and i64 %input, -8192 ; Clear low bits + %addr1 = xor i64 %base,16 + %gep1 = getelementptr i8, ptr %ptr, i64 %addr1 + %val1 = load half, ptr %gep1 + br i1 %cond, label %then, label %else + +then: + %addr2 = xor i64 %base, 48 + %gep2 = getelementptr i8, ptr %ptr, i64 %addr2 + %val2 = load half, ptr %gep2 + br label %merge + +else: + %addr3 = xor i64 %base, 112 + %gep3 = getelementptr i8, ptr %ptr, i64 %addr3 + %val3 = load half, ptr %gep3 + br label %merge + +merge: + %val_from_branch = phi half [ %val2, %then ], [ %val3, %else ] + %addr4 = xor i64 %base, 240 + %gep4 = getelementptr i8, ptr %ptr, i64 %addr4 + %val4 = load half, ptr %gep4 + %val1.f = fpext half %val1 to float + %val_from_branch.f = fpext half %val_from_branch to float + %val4.f = fpext half %val4 to float + %sum_intermediate.f = fadd float %val1.f, %val_from_branch.f + %final_sum.f = fadd float %sum_intermediate.f, %val4.f + %result.h = fptrunc float %final_sum.f to half + ret half %result.h +} +