 #include <iterator>
 #include <numeric>
 
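+// Debug helpers: enable this pass's logging with -debug-only=tritongpu-coalesce.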
+#define DEBUG_TYPE "tritongpu-coalesce"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 using namespace mlir;
 using namespace mlir::triton;
 
@@ -21,11 +25,16 @@ template <class T> SmallVector<unsigned, 4> argSort(const T &arr) {
   return ret;
 }
 
-unsigned getElementBitWidth(const Value &val) {
+// Type of `val` can be either a tensor pointer or a tensor.
+static RankedTensorType getTensorType(const Value &val) {
   auto valType = val.getType();
   if (valType.isa<PointerType>())
     valType = valType.cast<PointerType>().getPointeeType();
-  auto tensorType = valType.cast<RankedTensorType>();
+  return valType.cast<RankedTensorType>();
+}
+
+unsigned getElementBitWidth(const Value &val) {
+  auto tensorType = getTensorType(val);
 
   auto typeForMem =
       tensorType.getElementType().isa<PointerType>()
@@ -48,70 +57,71 @@ static Value getMemAccessPtr(Operation *op) {
   return nullptr;
 }
 
+// TODO(Keren): integrate it into AxisInfoAnalysis
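+// Conservative default: every axis gets contiguity/divisibility/constancy of
+// 1; only the fastest-varying axis `order[0]` is refined below.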
+static AxisInfo getAxisInfoForTensorPointer(const Value &val) {
+  auto valType = val.getType();
+  // TODO(Chenggang): encoding for tensor pointers is meaningless, remove
+  // these later while merging into the GitHub main
+  auto ptrType = valType.cast<PointerType>();
+  auto tensorTy = ptrType.getPointeeType().cast<RankedTensorType>();
+  auto makeTensorPtr = getMakeTensorPtrOp(val);
+  auto order = makeTensorPtr.getOrder();
+  auto tileShape = triton::gpu::getShapePerCTA(tensorTy);
+  size_t rank = order.size();
+  auto elemSizeInBytes = tensorTy.getElementType().getIntOrFloatBitWidth() / 8;
+  SmallVector<int64_t> contiguity(rank, 1);
+  SmallVector<int64_t> divisibility(rank, 1);
+  SmallVector<int64_t> constancy(rank, 1);
+  // The contiguity in `order[0]` is `tileShape[order[0]]`
+  // The divisibility in `order[0]` is 16
+  // TODO[goostavz]: confirm the legality of it
+  contiguity[order[0]] = tileShape[order[0]];
+  divisibility[order[0]] = 16 * 8 / elemSizeInBytes;
+  return AxisInfo(contiguity, divisibility, constancy);
+}
+
 struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
   void
   setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op,
                        int numWarps, int threadsPerWarp,
                        llvm::MapVector<Operation *, Attribute> &layoutMap) {
     Value ptr = getMemAccessPtr(op);
-    auto refType = ptr.getType();
-    if (refType.isa<PointerType>())
-      refType = refType.cast<PointerType>().getPointeeType();
-    auto refTensorType = refType.cast<RankedTensorType>();
-
-    // TODO(Keren): integrate it into AxisInfoAnalysis
-    // Get axis info
-    auto queryAxisInfo = [&](const Value &val) -> AxisInfo {
-      auto valType = val.getType();
-      // Tensor pointer
-      // TODO(Chenggang): encoding for tensor pointers is meaningless, remove
-      // these later while merging into the GitHub main
-      if (auto ptrType = valType.dyn_cast<PointerType>()) {
-        auto tensorTy = ptrType.getPointeeType().dyn_cast<RankedTensorType>();
-        assert(tensorTy);
-        auto makeTensorPtr = getMakeTensorPtrOp(val);
-        auto order = makeTensorPtr.getOrder();
-        auto tileShape = triton::gpu::getShapePerCTA(tensorTy);
-        size_t rank = order.size();
-        auto elemSizeInBytes =
-            tensorTy.getElementType().getIntOrFloatBitWidth() / 8;
-        SmallVector<int64_t> contiguity(rank, 1);
-        SmallVector<int64_t> divisibility(rank, 1);
-        SmallVector<int64_t> constancy(rank, 1);
-        // The contiguity in `order[0]` is `tileShape[order[0]]`
-        // The divisibility in `order[0]` is 16
-        // TODO[goostavz]: confirm the legality of it
-        contiguity[order[0]] = tileShape[order[0]];
-        divisibility[order[0]] = 16 * 8 / elemSizeInBytes;
-        return AxisInfo(contiguity, divisibility, constancy);
-      }
-      // Normal cases
-      assert(valType.isa<RankedTensorType>());
-      return *axisInfoAnalysis.getAxisInfo(val);
-    };
+    auto refTensorType = getTensorType(ptr);
 
     // Get the contiguity order of `ptr`
     SmallVector<unsigned> order;
-    if (auto ptrType = ptr.getType().dyn_cast<PointerType>()) {
-      // Tensor pointer
+    LDBG("op is: " << *op);
+    if (ptr.getType().isa<PointerType>()) {
       auto makeTensorPtr = getMakeTensorPtrOp(ptr);
       std::copy(makeTensorPtr.getOrder().begin(),
                 makeTensorPtr.getOrder().end(), std::back_inserter(order));
     } else {
       // Normal cases
-      order = argSort(queryAxisInfo(ptr).getContiguity());
+      auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
+      order = argSort(contiguity);
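+      // After argSort, order[0] holds the most contiguous dimension.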
+      LLVM_DEBUG({
+        DBGS() << "contiguity is: ";
+        for (const auto &O : contiguity) {
+          llvm::dbgs() << O << " ";
+        }
+        llvm::dbgs() << "\n";
+      });
     }
+    LLVM_DEBUG({
+      DBGS() << "order is: ";
+      for (const auto &O : order) {
+        llvm::dbgs() << O << " ";
+      }
+      llvm::dbgs() << "\n";
+    });
 
     auto matchesShape = [&refTensorType](const Value &val) {
       if (val.getType() == refTensorType) {
         return true;
       }
 
       auto rttType = val.getType().dyn_cast<RankedTensorType>();
-      if (!rttType) {
-        return false;
-      }
-      return rttType.getShape() == refTensorType.getShape();
+      return rttType ? rttType.getShape() == refTensorType.getShape() : false;
     };
 
     // The desired divisibility is the maximum divisibility
@@ -120,29 +130,35 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
     // We only do it for normal tensors of pointers, not tensor pointers.
     llvm::SmallSetVector<Operation *, 32> memAccessesSameOrder;
     memAccessesSameOrder.insert(op);
-    if (refType.isa<RankedTensorType>() && ptr.getDefiningOp()) {
+    if (ptr.getDefiningOp()) {
       for (Operation *use : mlir::multiRootGetSlice(op)) {
         Value val = getMemAccessPtr(use);
-        if (!val)
-          continue;
-        if (!matchesShape(val))
+        if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
           continue;
         auto currOrder =
             argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
         if (order == currOrder) {
+          LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
           memAccessesSameOrder.insert(use);
         }
       }
     }
 
     auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType);
+    LLVM_DEBUG({
+      DBGS() << "shapePerCTA is ";
+      for (const auto &O : shapePerCTA) {
+        llvm::dbgs() << O << " ";
+      }
+      llvm::dbgs() << "\n";
+    });
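+    // Upper bound on per-thread elements: an even split of the CTA tile
+    // across all threads (at least 1).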
     int numElems = product<int64_t>(shapePerCTA);
     int numThreads = numWarps * threadsPerWarp;
     int numElemsPerThread = std::max(numElems / numThreads, 1);
 
     // For tensor of pointers, the element to access is the pointee type;
-    // while for tensor pointer type (`refType` is directly the final shape),
-    // the element to access is itself.
+    // while for tensor pointer type (`refTensorType` is directly the final
+    // shape), the element to access is itself.
     auto typeForMem = refTensorType.getElementType().isa<PointerType>()
                           ? refTensorType.getElementType()
                                 .cast<PointerType>()
@@ -151,7 +167,13 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
 
     auto getNumElementPerThread = [&](Operation *op) {
       Value val = getMemAccessPtr(op);
-      auto valInfo = queryAxisInfo(val);
+      AxisInfo valInfo;
+      if (val.getType().isa<PointerType>()) {
+        valInfo = getAxisInfoForTensorPointer(val);
+      } else {
+        assert(val.getType().isa<RankedTensorType>());
+        valInfo = *axisInfoAnalysis.getAxisInfo(val);
+      }
       unsigned elemNumBits = getElementBitWidth(val);
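+      // Round the bit width down to bytes, clamping to 1 for sub-byte types.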
       unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
       unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]);
@@ -163,12 +185,17 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
       return currPerThread;
     };
     unsigned perThread = getNumElementPerThread(op);
-    for (Operation *op : memAccessesSameOrder) {
-      unsigned currPerThread = getNumElementPerThread(op);
+    LDBG("perThread for op: " << perThread);
+    for (Operation *opSameOrder : memAccessesSameOrder) {
+      if (opSameOrder == op)
+        continue;
+      unsigned currPerThread = getNumElementPerThread(opSameOrder);
+      LDBG("perThread for opSameOrder: " << currPerThread);
       perThread = std::max(perThread, currPerThread);
     }
 
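+    // The final per-thread vector width may not exceed the per-thread
+    // element count.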
     perThread = std::min<int>(perThread, numElemsPerThread);
+    LDBG("perThread: " << perThread);
 
     if (!dyn_cast<triton::LoadOp>(op)) {
       // For ops that can result in a global memory write, we should enforce