@@ -109,54 +109,46 @@ Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
109
109
}
110
110
111
111
Type mlir::linalg::LinalgDialect::parseType (DialectAsmParser &parser) const {
112
- Location loc = parser.getEncodedSourceLoc (parser.getNameLoc ());
113
- StringRef spec = parser.getFullSymbolSpec ();
114
- StringRef origSpec = spec;
112
+ // Parse the main keyword for the type.
113
+ StringRef keyword;
114
+ if (parser.parseKeyword (&keyword))
115
+ return Type ();
115
116
MLIRContext *context = getContext ();
116
- if (spec == " range" )
117
- return RangeType::get (getContext ());
118
- else if (spec.consume_front (" buffer" )) {
119
- if (spec.consume_front (" <" ) && spec.consume_back (" >" )) {
120
- StringRef sizeSpec, typeSpec;
121
- std::tie (sizeSpec, typeSpec) = spec.split (' x' );
122
- if (typeSpec.empty ()) {
123
- emitError (loc, " expected 'x' followed by element type" );
124
- return Type ();
125
- }
126
- // Check for '?'
127
- int64_t bufferSize = -1 ;
128
- if (!sizeSpec.consume_front (" ?" )) {
129
- if (sizeSpec.consumeInteger (10 , bufferSize)) {
130
- emitError (loc, " expected buffer size to be an unsigned integer" );
131
- return Type ();
132
- }
133
- }
134
- if (!sizeSpec.empty ()) {
135
- emitError (loc, " unexpected token '" ) << sizeSpec << " '" ;
136
- }
137
-
138
- typeSpec = typeSpec.trim ();
139
- auto t = mlir::parseType (typeSpec, context);
140
- if (!t) {
141
- emitError (loc, " invalid type specification: '" ) << typeSpec << " '" ;
142
- return Type ();
143
- }
144
- return (bufferSize == -1 ? BufferType::get (getContext (), t)
145
- : BufferType::get (getContext (), t, bufferSize));
117
+
118
+ // Handle 'range' types.
119
+ if (keyword == " range" )
120
+ return RangeType::get (context);
121
+
122
+ // Handle 'buffer' types.
123
+ if (keyword == " buffer" ) {
124
+ llvm::SMLoc dimensionLoc;
125
+ SmallVector<int64_t , 1 > size;
126
+ Type type;
127
+ if (parser.parseLess () || parser.getCurrentLocation (&dimensionLoc) ||
128
+ parser.parseDimensionList (size) || parser.parseType (type) ||
129
+ parser.parseGreater ())
130
+ return Type ();
131
+
132
+ if (size.size () != 1 ) {
133
+ parser.emitError (dimensionLoc, " expected single element in size list" );
134
+ return Type ();
146
135
}
136
+
137
+ return (size.front () == -1 ? BufferType::get (context, type)
138
+ : BufferType::get (context, type, size.front ()));
147
139
}
148
- return (emitError (loc, " unknown Linalg type: " + origSpec), Type ());
140
+
141
+ parser.emitError (parser.getNameLoc (), " unknown Linalg type: " + keyword);
142
+ return Type ();
149
143
}
150
144
151
- // / BufferType prints as "buffer<element_type>".
145
+ // / BufferType prints as "buffer<size x element_type>".
152
146
static void print (BufferType bt, DialectAsmPrinter &os) {
153
147
os << " buffer<" ;
154
- auto bs = bt.getBufferSize ();
155
- if (bs) {
148
+ if (Optional<int64_t > bs = bt.getBufferSize ())
156
149
os << bs.getValue ();
157
- } else {
150
+ else
158
151
os << " ?" ;
159
- }
160
152
os << " x" << bt.getElementType () << " >" ;
161
153
}
162
154
0 commit comments