Skip to content

Commit f73e544

Browse files
River707tensorflower-gardener
authored andcommitted
Refactor LinalgDialect::parseType to use the DialectAsmParser methods directly.
This simplifies the implementation, and removes the need to do explicit string manipulation. A utility method 'parseDimensionList' is added to the DialectAsmParser to simplify defining types and attributes that contain shapes. PiperOrigin-RevId: 278020604 Change-Id: Id77c02a415f8976c2f36f2a9dea8ae92e867c392
1 parent 9f03d28 commit f73e544

File tree

3 files changed

+49
-39
lines changed

3 files changed

+49
-39
lines changed

third_party/mlir/include/mlir/IR/DialectImplementation.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,19 @@ class DialectAsmParser {
295295
return emitError(loc, "invalid kind of type specified");
296296
return success();
297297
}
298+
299+
/// Parse a 'x' separated dimension list. This populates the dimension list,
300+
/// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
301+
/// `?` otherwise.
302+
///
303+
/// dimension-list ::= (dimension `x`)*
304+
/// dimension ::= `?` | integer
305+
///
306+
/// When `allowDynamic` is not set, this is used to parse:
307+
///
308+
/// static-dimension-list ::= (integer `x`)*
309+
virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
310+
bool allowDynamic = true) = 0;
298311
};
299312

300313
} // end namespace mlir

third_party/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -109,54 +109,46 @@ Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
109109
}
110110

111111
Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser) const {
112-
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
113-
StringRef spec = parser.getFullSymbolSpec();
114-
StringRef origSpec = spec;
112+
// Parse the main keyword for the type.
113+
StringRef keyword;
114+
if (parser.parseKeyword(&keyword))
115+
return Type();
115116
MLIRContext *context = getContext();
116-
if (spec == "range")
117-
return RangeType::get(getContext());
118-
else if (spec.consume_front("buffer")) {
119-
if (spec.consume_front("<") && spec.consume_back(">")) {
120-
StringRef sizeSpec, typeSpec;
121-
std::tie(sizeSpec, typeSpec) = spec.split('x');
122-
if (typeSpec.empty()) {
123-
emitError(loc, "expected 'x' followed by element type");
124-
return Type();
125-
}
126-
// Check for '?'
127-
int64_t bufferSize = -1;
128-
if (!sizeSpec.consume_front("?")) {
129-
if (sizeSpec.consumeInteger(10, bufferSize)) {
130-
emitError(loc, "expected buffer size to be an unsigned integer");
131-
return Type();
132-
}
133-
}
134-
if (!sizeSpec.empty()) {
135-
emitError(loc, "unexpected token '") << sizeSpec << "'";
136-
}
137-
138-
typeSpec = typeSpec.trim();
139-
auto t = mlir::parseType(typeSpec, context);
140-
if (!t) {
141-
emitError(loc, "invalid type specification: '") << typeSpec << "'";
142-
return Type();
143-
}
144-
return (bufferSize == -1 ? BufferType::get(getContext(), t)
145-
: BufferType::get(getContext(), t, bufferSize));
117+
118+
// Handle 'range' types.
119+
if (keyword == "range")
120+
return RangeType::get(context);
121+
122+
// Handle 'buffer' types.
123+
if (keyword == "buffer") {
124+
llvm::SMLoc dimensionLoc;
125+
SmallVector<int64_t, 1> size;
126+
Type type;
127+
if (parser.parseLess() || parser.getCurrentLocation(&dimensionLoc) ||
128+
parser.parseDimensionList(size) || parser.parseType(type) ||
129+
parser.parseGreater())
130+
return Type();
131+
132+
if (size.size() != 1) {
133+
parser.emitError(dimensionLoc, "expected single element in size list");
134+
return Type();
146135
}
136+
137+
return (size.front() == -1 ? BufferType::get(context, type)
138+
: BufferType::get(context, type, size.front()));
147139
}
148-
return (emitError(loc, "unknown Linalg type: " + origSpec), Type());
140+
141+
parser.emitError(parser.getNameLoc(), "unknown Linalg type: " + keyword);
142+
return Type();
149143
}
150144

151-
/// BufferType prints as "buffer<element_type>".
145+
/// BufferType prints as "buffer<size x element_type>".
152146
static void print(BufferType bt, DialectAsmPrinter &os) {
153147
os << "buffer<";
154-
auto bs = bt.getBufferSize();
155-
if (bs) {
148+
if (Optional<int64_t> bs = bt.getBufferSize())
156149
os << bs.getValue();
157-
} else {
150+
else
158151
os << "?";
159-
}
160152
os << "x" << bt.getElementType() << ">";
161153
}
162154

third_party/mlir/lib/Parser/Parser.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,11 @@ class CustomDialectAsmParser : public DialectAsmParser {
633633
return success(static_cast<bool>(result));
634634
}
635635

636+
ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
637+
bool allowDynamic) override {
638+
return parser.parseDimensionListRanked(dimensions, allowDynamic);
639+
}
640+
636641
private:
637642
/// The full symbol specification.
638643
StringRef fullSpec;

0 commit comments

Comments
 (0)