
Commit 1bc9c0e

NFC: simplify code in Coalesce.cpp, also add DEBUG_ONLY support (triton-lang#2820)
1 parent c5f83b2 commit 1bc9c0e

1 file changed: 80 additions, 53 deletions

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
@@ -7,6 +7,10 @@
 #include <iterator>
 #include <numeric>
 
+#define DEBUG_TYPE "tritongpu-coalesce"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 using namespace mlir;
 using namespace mlir::triton;
 
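Note on the macros: this is the standard LLVM debug-output idiom. LLVM_DEBUG compiles away entirely in release builds; in assertion-enabled builds it emits output only when debug logging for this DEBUG_TYPE is enabled (e.g. with -debug-only=tritongpu-coalesce, or -debug for all types). A minimal self-contained sketch of the pattern (the runPass caller is hypothetical, for illustration only):

    #include "llvm/Support/Debug.h"
    #include "llvm/Support/raw_ostream.h"

    #define DEBUG_TYPE "tritongpu-coalesce"
    #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
    #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

    void runPass() {        // hypothetical caller, for illustration
      LDBG("visiting op");  // -> "[tritongpu-coalesce]: visiting op"
    }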

@@ -21,11 +25,16 @@ template <class T> SmallVector<unsigned, 4> argSort(const T &arr) {
   return ret;
 }
 
-unsigned getElementBitWidth(const Value &val) {
+// Type of val can be either Tensor Pointer or Tensor.
+static RankedTensorType getTensorType(const Value &val) {
   auto valType = val.getType();
   if (valType.isa<PointerType>())
     valType = valType.cast<PointerType>().getPointeeType();
-  auto tensorType = valType.cast<RankedTensorType>();
+  return valType.cast<RankedTensorType>();
+}
+
+unsigned getElementBitWidth(const Value &val) {
+  auto tensorType = getTensorType(val);
 
   auto typeForMem =
       tensorType.getElementType().isa<PointerType>()
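
For context, argSort (visible in the hunk header) produces the dimension order used throughout the pass. A sketch reconstructed from its usage, assuming it sorts indices by decreasing value so that the most contiguous dimension lands in order[0]; this is not copied from the file:

    #include <algorithm> // std::stable_sort
    #include <numeric>   // std::iota

    template <class T> SmallVector<unsigned, 4> argSort(const T &arr) {
      SmallVector<unsigned, 4> ret(arr.size());
      std::iota(ret.begin(), ret.end(), 0);
      std::stable_sort(ret.begin(), ret.end(),
                       [&](unsigned a, unsigned b) { return arr[a] > arr[b]; });
      return ret;
    }
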
@@ -48,70 +57,71 @@ static Value getMemAccessPtr(Operation *op) {
   return nullptr;
 }
 
+// TODO(Keren): integrate it into AxisInfoAnalysis
+static AxisInfo getAxisInfoForTensorPointer(const Value &val) {
+  auto valType = val.getType();
+  // TODO(Chenggang): encoding for tensor pointers is meaningless, remove
+  // these later while merging into the GitHub main
+  auto ptrType = valType.cast<PointerType>();
+  auto tensorTy = ptrType.getPointeeType().cast<RankedTensorType>();
+  auto makeTensorPtr = getMakeTensorPtrOp(val);
+  auto order = makeTensorPtr.getOrder();
+  auto tileShape = triton::gpu::getShapePerCTA(tensorTy);
+  size_t rank = order.size();
+  auto elemSizeInBytes = tensorTy.getElementType().getIntOrFloatBitWidth() / 8;
+  SmallVector<int64_t> contiguity(rank, 1);
+  SmallVector<int64_t> divisibility(rank, 1);
+  SmallVector<int64_t> constancy(rank, 1);
+  // The contiguity in `order[0]` is `tileShape[order[0]]`
+  // The divisibility in `order[0]` is 16
+  // TODO[goostavz]: confirm the legality of it
+  contiguity[order[0]] = tileShape[order[0]];
+  divisibility[order[0]] = 16 * 8 / elemSizeInBytes;
+  return AxisInfo(contiguity, divisibility, constancy);
+}
+
 struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
   void
   setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op,
                        int numWarps, int threadsPerWarp,
                        llvm::MapVector<Operation *, Attribute> &layoutMap) {
     Value ptr = getMemAccessPtr(op);
-    auto refType = ptr.getType();
-    if (refType.isa<PointerType>())
-      refType = refType.cast<PointerType>().getPointeeType();
-    auto refTensorType = refType.cast<RankedTensorType>();
-
-    // TODO(Keren): integrate it into AxisInfoAnalysis
-    // Get axis info
-    auto queryAxisInfo = [&](const Value &val) -> AxisInfo {
-      auto valType = val.getType();
-      // Tensor pointer
-      // TODO(Chenggang): encoding for tensor pointers is meaningless, remove
-      // these later while merging into the GitHub main
-      if (auto ptrType = valType.dyn_cast<PointerType>()) {
-        auto tensorTy = ptrType.getPointeeType().dyn_cast<RankedTensorType>();
-        assert(tensorTy);
-        auto makeTensorPtr = getMakeTensorPtrOp(val);
-        auto order = makeTensorPtr.getOrder();
-        auto tileShape = triton::gpu::getShapePerCTA(tensorTy);
-        size_t rank = order.size();
-        auto elemSizeInBytes =
-            tensorTy.getElementType().getIntOrFloatBitWidth() / 8;
-        SmallVector<int64_t> contiguity(rank, 1);
-        SmallVector<int64_t> divisibility(rank, 1);
-        SmallVector<int64_t> constancy(rank, 1);
-        // The contiguity in `order[0]` is `tileShape[order[0]]`
-        // The divisibility in `order[0]` is 16
-        // TODO[goostavz]: confirm the legality of it
-        contiguity[order[0]] = tileShape[order[0]];
-        divisibility[order[0]] = 16 * 8 / elemSizeInBytes;
-        return AxisInfo(contiguity, divisibility, constancy);
-      }
-      // Normal cases
-      assert(valType.isa<RankedTensorType>());
-      return *axisInfoAnalysis.getAxisInfo(val);
-    };
+    auto refTensorType = getTensorType(ptr);
 
     // Get the contiguity order of `ptr`
     SmallVector<unsigned> order;
-    if (auto ptrType = ptr.getType().dyn_cast<PointerType>()) {
-      // Tensor pointer
+    LDBG("op is: " << *op);
+    if (ptr.getType().isa<PointerType>()) {
       auto makeTensorPtr = getMakeTensorPtrOp(ptr);
       std::copy(makeTensorPtr.getOrder().begin(),
                 makeTensorPtr.getOrder().end(), std::back_inserter(order));
     } else {
       // Normal cases
-      order = argSort(queryAxisInfo(ptr).getContiguity());
+      auto contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
+      order = argSort(contiguity);
+      LLVM_DEBUG({
+        DBGS() << "contiguity is: ";
+        for (const auto &O : contiguity) {
+          llvm::dbgs() << O << " ";
+        }
+        llvm::dbgs() << "\n";
+      });
     }
+    LLVM_DEBUG({
+      DBGS() << "order is: ";
+      for (const auto &O : order) {
+        llvm::dbgs() << O << " ";
+      }
+      llvm::dbgs() << "\n";
+    });
 
     auto matchesShape = [&refTensorType](const Value &val) {
       if (val.getType() == refTensorType) {
         return true;
       }
 
       auto rttType = val.getType().dyn_cast<RankedTensorType>();
-      if (!rttType) {
-        return false;
-      }
-      return rttType.getShape() == refTensorType.getShape();
+      return rttType ? rttType.getShape() == refTensorType.getShape() : false;
     };
 
     // The desired divisibility is the maximum divisibility
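
A worked example of the tensor-pointer heuristic above (values assumed for illustration, not taken from the commit): for an fp16 tile with tileShape = {64, 32} and order = {1, 0}, dimension 1 is the fastest-varying one, so:

    // Illustrative arithmetic only, mirroring the formulas above.
    int64_t elemSizeInBytes = 16 / 8;         // fp16 -> 2 bytes
    int64_t contig = 32;                      // tileShape[order[0]]
    int64_t divis = 16 * 8 / elemSizeInBytes; // -> 64
    // Every other dimension keeps contiguity/divisibility/constancy of 1.
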
@@ -120,29 +130,35 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
     // We only do it for normal tensors of pointers, not tensor pointers.
     llvm::SmallSetVector<Operation *, 32> memAccessesSameOrder;
     memAccessesSameOrder.insert(op);
-    if (refType.isa<RankedTensorType>() && ptr.getDefiningOp()) {
+    if (ptr.getDefiningOp()) {
       for (Operation *use : mlir::multiRootGetSlice(op)) {
         Value val = getMemAccessPtr(use);
-        if (!val)
-          continue;
-        if (!matchesShape(val))
+        if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
           continue;
         auto currOrder =
             argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
         if (order == currOrder) {
+          LDBG("multi-root-slice: insert to memAccessesSameOrder " << *use);
           memAccessesSameOrder.insert(use);
         }
       }
     }
 
     auto shapePerCTA = triton::gpu::getShapePerCTA(refTensorType);
+    LLVM_DEBUG({
+      DBGS() << "shapePerCTA is ";
+      for (const auto &O : shapePerCTA) {
+        llvm::dbgs() << O << " ";
+      }
+      llvm::dbgs() << "\n";
+    });
     int numElems = product<int64_t>(shapePerCTA);
     int numThreads = numWarps * threadsPerWarp;
     int numElemsPerThread = std::max(numElems / numThreads, 1);
 
     // For tensor of pointers, the element to access is the pointee type;
-    // while for tensor pointer type (`refType` is directly the final shape),
-    // the element to access is itself.
+    // while for tensor pointer type (`refTensorType` is directly the final
+    // shape), the element to access is itself.
     auto typeForMem = refTensorType.getElementType().isa<PointerType>()
                           ? refTensorType.getElementType()
                                 .cast<PointerType>()
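
To make the sizing concrete (numbers assumed for illustration): with shapePerCTA = {64, 32}, numWarps = 4, and threadsPerWarp = 32, each thread's element budget works out as:

    #include <algorithm> // std::max

    int demoNumElemsPerThread() { // hypothetical helper, illustration only
      int numElems = 64 * 32;     // product(shapePerCTA) = 2048
      int numThreads = 4 * 32;    // numWarps * threadsPerWarp = 128
      return std::max(numElems / numThreads, 1); // = 16
    }
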
@@ -151,7 +167,13 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
 
     auto getNumElementPerThread = [&](Operation *op) {
       Value val = getMemAccessPtr(op);
-      auto valInfo = queryAxisInfo(val);
+      AxisInfo valInfo;
+      if (val.getType().isa<PointerType>()) {
+        valInfo = getAxisInfoForTensorPointer(val);
+      } else {
+        assert(val.getType().isa<RankedTensorType>());
+        valInfo = *axisInfoAnalysis.getAxisInfo(val);
+      }
       unsigned elemNumBits = getElementBitWidth(val);
       unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
       unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]);
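
A small worked example of the byte arithmetic above (element widths assumed for illustration): the std::max(..., 1u) clamp matters for sub-byte element types, where the integer division would otherwise yield zero:

    #include <algorithm> // std::max

    unsigned fp16Bytes = std::max(16u / 8, 1u); // = 2
    unsigned i1Bytes = std::max(1u / 8, 1u);    // 1 / 8 == 0, clamped to 1
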
@@ -163,12 +185,17 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
       return currPerThread;
     };
     unsigned perThread = getNumElementPerThread(op);
-    for (Operation *op : memAccessesSameOrder) {
-      unsigned currPerThread = getNumElementPerThread(op);
+    LDBG("perThread for op: " << perThread);
+    for (Operation *opSameOrder : memAccessesSameOrder) {
+      if (opSameOrder == op)
+        continue;
+      unsigned currPerThread = getNumElementPerThread(opSameOrder);
+      LDBG("perThread for opSameOrder: " << currPerThread);
       perThread = std::max(perThread, currPerThread);
     }
 
     perThread = std::min<int>(perThread, numElemsPerThread);
+    LDBG("perThread: " << perThread);
 
     if (!dyn_cast<triton::LoadOp>(op)) {
       // For ops that can result in a global memory write, we should enforce
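
Putting the loop and the clamp together (numbers assumed for illustration): if the anchor op computes 8 elements per thread and accesses with the same order compute 16 and 4, the loop takes the maximum and the final std::min caps the result at numElemsPerThread:

    #include <algorithm>        // std::max, std::min
    #include <initializer_list>

    unsigned demoPerThread() {  // hypothetical helper, illustration only
      unsigned perThread = 8;   // getNumElementPerThread(op)
      for (unsigned curr : {16u, 4u})          // same-order accesses
        perThread = std::max(perThread, curr); // -> 16
      return std::min(perThread, 16u);         // numElemsPerThread cap
    }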
