Skip to content

Commit 36730cf

Browse files
[mlir][ODS] Verify type constraints in Types and Attributes
When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing. Example: ``` def TestTypeVerification : Test_Type<"TestTypeVerification"> { let parameters = (ins AnyTypeOf<[I16, I32]>:$param); // ... } ``` No verification code was generated to ensure that `$param` is I16 or I32. When type constraints a present, a new method will generated for types and attributes: `verifyInvariantsImpl`. (The naming is similar to op verifiers.) The user-provided verifier is called `verify` (no change). There is now a new entry point to type/attribute verification: `verifyInvariants`. This function calls both `verifyInvariantsImpl` and `verify`. If neither of those two verifications are present, the `verifyInvariants` function is not generated. When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the `verifyInvariants` function. (This function was previously called `verify`.)
1 parent 1919db9 commit 36730cf

24 files changed

+280
-90
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,10 @@ class MMAMatrixType
148148

149149
/// Verify that shape and elementType are actually allowed for the
150150
/// MMAMatrixType.
151-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
152-
ArrayRef<int64_t> shape, Type elementType,
153-
StringRef operand);
151+
static LogicalResult
152+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
153+
ArrayRef<int64_t> shape, Type elementType,
154+
StringRef operand);
154155

155156
/// Get number of dims.
156157
unsigned getNumDims() const;

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h

+7-5
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,13 @@ class LLVMStructType
180180
ArrayRef<Type> getBody() const;
181181

182182
/// Verifies that the type about to be constructed is well-formed.
183-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
184-
StringRef, bool);
185-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
186-
ArrayRef<Type> types, bool);
187-
using Base::verify;
183+
static LogicalResult
184+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
185+
bool);
186+
static LogicalResult
187+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
188+
ArrayRef<Type> types, bool);
189+
using Base::verifyInvariants;
188190

189191
/// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
190192
/// DataLayout instance and query it instead.

mlir/include/mlir/Dialect/Quant/QuantTypes.h

+22-21
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ class QuantizedType : public Type {
5454
/// The maximum number of bits supported for storage types.
5555
static constexpr unsigned MaxStorageBits = 32;
5656

57-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
58-
unsigned flags, Type storageType,
59-
Type expressedType, int64_t storageTypeMin,
60-
int64_t storageTypeMax);
57+
static LogicalResult
58+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
59+
Type storageType, Type expressedType, int64_t storageTypeMin,
60+
int64_t storageTypeMax);
6161

6262
/// Support method to enable LLVM-style type casting.
6363
static bool classof(Type type);
@@ -214,10 +214,10 @@ class AnyQuantizedType
214214
int64_t storageTypeMax);
215215

216216
/// Verifies construction invariants and issues errors/warnings.
217-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
218-
unsigned flags, Type storageType,
219-
Type expressedType, int64_t storageTypeMin,
220-
int64_t storageTypeMax);
217+
static LogicalResult
218+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
219+
Type storageType, Type expressedType, int64_t storageTypeMin,
220+
int64_t storageTypeMax);
221221
};
222222

223223
/// Represents a family of uniform, quantized types.
@@ -276,11 +276,11 @@ class UniformQuantizedType
276276
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
277277

278278
/// Verifies construction invariants and issues errors/warnings.
279-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
280-
unsigned flags, Type storageType,
281-
Type expressedType, double scale,
282-
int64_t zeroPoint, int64_t storageTypeMin,
283-
int64_t storageTypeMax);
279+
static LogicalResult
280+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
281+
Type storageType, Type expressedType, double scale,
282+
int64_t zeroPoint, int64_t storageTypeMin,
283+
int64_t storageTypeMax);
284284

285285
/// Gets the scale term. The scale designates the difference between the real
286286
/// values corresponding to consecutive quantized values differing by 1.
@@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
338338
int64_t storageTypeMin, int64_t storageTypeMax);
339339

340340
/// Verifies construction invariants and issues errors/warnings.
341-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
342-
unsigned flags, Type storageType,
343-
Type expressedType, ArrayRef<double> scales,
344-
ArrayRef<int64_t> zeroPoints,
345-
int32_t quantizedDimension,
346-
int64_t storageTypeMin, int64_t storageTypeMax);
341+
static LogicalResult
342+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
343+
Type storageType, Type expressedType,
344+
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
345+
int32_t quantizedDimension, int64_t storageTypeMin,
346+
int64_t storageTypeMax);
347347

348348
/// Gets the quantization scales. The scales designate the difference between
349349
/// the real values corresponding to consecutive quantized values differing
@@ -403,8 +403,9 @@ class CalibratedQuantizedType
403403
double min, double max);
404404

405405
/// Verifies construction invariants and issues errors/warnings.
406-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
407-
Type expressedType, double min, double max);
406+
static LogicalResult
407+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
408+
Type expressedType, double min, double max);
408409
double getMin() const;
409410
double getMax() const;
410411
};

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h

+8-6
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@ class InterfaceVarABIAttr
7676
/// Returns `spirv::StorageClass`.
7777
std::optional<StorageClass> getStorageClass();
7878

79-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
80-
IntegerAttr descriptorSet, IntegerAttr binding,
81-
IntegerAttr storageClass);
79+
static LogicalResult
80+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
81+
IntegerAttr descriptorSet, IntegerAttr binding,
82+
IntegerAttr storageClass);
8283

8384
static constexpr StringLiteral name = "spirv.interface_var_abi";
8485
};
@@ -128,9 +129,10 @@ class VerCapExtAttr
128129
/// Returns the capabilities as an integer array attribute.
129130
ArrayAttr getCapabilitiesAttr();
130131

131-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
132-
IntegerAttr version, ArrayAttr capabilities,
133-
ArrayAttr extensions);
132+
static LogicalResult
133+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
134+
IntegerAttr version, ArrayAttr capabilities,
135+
ArrayAttr extensions);
134136

135137
static constexpr StringLiteral name = "spirv.ver_cap_ext";
136138
};

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,9 @@ class SampledImageType
258258
static SampledImageType
259259
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
260260

261-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
262-
Type imageType);
261+
static LogicalResult
262+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
263+
Type imageType);
263264

264265
Type getImageType() const;
265266

@@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
462463
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
463464
Type columnType, uint32_t columnCount);
464465

465-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
466-
Type columnType, uint32_t columnCount);
466+
static LogicalResult
467+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
468+
Type columnType, uint32_t columnCount);
467469

468470
/// Returns true if the matrix elements are vectors of float elements.
469471
static bool isValidColumnType(Type columnType);

mlir/include/mlir/IR/CommonTypeConstraints.td

+1
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
180180
summary),
181181
cppClassName> {
182182
list<Type> allowedTypes = allowedTypeList;
183+
string cppType = cppClassName;
183184
}
184185

185186
// A type that satisfies the constraints of all given types.

mlir/include/mlir/IR/Constraints.td

+3
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ class TypeConstraint<Pred predicate, string summary = "",
153153
Constraint<predicate, summary> {
154154
// The name of the C++ Type class if known, or Type if not.
155155
string cppClassName = cppClassNameParam;
156+
// TODO: This field is sometimes called `cppClassName` and sometimes
157+
// `cppType`. Use a single name consistently.
158+
string cppType = cppClassNameParam;
156159
}
157160

158161
// Subclass for constraints on an attribute.

mlir/include/mlir/IR/StorageUniquerSupport.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
176176
template <typename... Args>
177177
static ConcreteT get(MLIRContext *ctx, Args &&...args) {
178178
// Ensure that the invariants are correct for construction.
179-
assert(
180-
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
179+
assert(succeeded(
180+
ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
181181
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
182182
}
183183

@@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
198198
static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
199199
MLIRContext *ctx, Args... args) {
200200
// If the construction invariants fail then we return a null attribute.
201-
if (failed(ConcreteT::verify(emitErrorFn, args...)))
201+
if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
202202
return ConcreteT();
203203
return UniquerT::template get<ConcreteT>(ctx, args...);
204204
}
@@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
226226

227227
/// Default implementation that just returns success.
228228
template <typename... Args>
229-
static LogicalResult verify(Args... args) {
229+
static LogicalResult
230+
verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
231+
Args... args) {
230232
return success();
231233
}
232234

mlir/include/mlir/IR/Types.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class AsmState;
3434
/// Derived type classes are expected to implement several required
3535
/// implementation hooks:
3636
/// * Optional:
37-
/// - static LogicalResult verify(
37+
/// - static LogicalResult verifyInvariants(
3838
/// function_ref<InFlightDiagnostic()> emitError,
3939
/// Args... args)
4040
/// * This method is invoked when calling the 'TypeBase::get/getChecked'

mlir/include/mlir/TableGen/AttrOrTypeDef.h

+8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Support/LLVM.h"
1818
#include "mlir/TableGen/Builder.h"
19+
#include "mlir/TableGen/Constraint.h"
1920
#include "mlir/TableGen/Trait.h"
2021

2122
namespace llvm {
@@ -85,6 +86,9 @@ class AttrOrTypeParameter {
8586
/// Get an optional C++ parameter parser.
8687
std::optional<StringRef> getParser() const;
8788

89+
/// If this is a type constraint, return it.
90+
std::optional<Constraint> getConstraint() const;
91+
8892
/// Get an optional C++ parameter printer.
8993
std::optional<StringRef> getPrinter() const;
9094

@@ -198,6 +202,10 @@ class AttrOrTypeDef {
198202
/// method.
199203
bool genVerifyDecl() const;
200204

205+
/// Return true if we need to generate any type constraint verification and
206+
/// the getChecked method.
207+
bool genVerifyInvariantsImpl() const;
208+
201209
/// Returns the def's extra class declaration code.
202210
std::optional<StringRef> getExtraDecls() const;
203211

mlir/include/mlir/TableGen/Class.h

+2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class MethodParameter {
6767

6868
/// Get the C++ type.
6969
StringRef getType() const { return type; }
70+
/// Get the C++ parameter name.
71+
StringRef getName() const { return name; }
7072
/// Returns true if the parameter has a default value.
7173
bool hasDefaultValue() const { return !defaultValue.empty(); }
7274

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
148148
}
149149

150150
LogicalResult
151-
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
152-
ArrayRef<int64_t> shape, Type elementType,
153-
StringRef operand) {
151+
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
152+
ArrayRef<int64_t> shape, Type elementType,
153+
StringRef operand) {
154154
if (operand != "AOp" && operand != "BOp" && operand != "COp")
155155
return emitError() << "operand expected to be one of AOp, BOp or COp";
156156

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
418418

419419
bool LLVMStructType::isValidElementType(Type type) {
420420
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
421-
LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
422-
type);
421+
LLVMFunctionType, LLVMTokenType>(type);
423422
}
424423

425424
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
492491
: getImpl()->getTypeList();
493492
}
494493

495-
LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
496-
StringRef, bool) {
494+
LogicalResult
495+
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
496+
bool) {
497497
return success();
498498
}
499499

500500
LogicalResult
501-
LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
502-
ArrayRef<Type> types, bool) {
501+
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
502+
ArrayRef<Type> types, bool) {
503503
for (Type t : types)
504504
if (!isValidElementType(t))
505505
return emitError() << "invalid LLVM structure element type: " << t;

mlir/lib/Dialect/Quant/IR/QuantTypes.cpp

+22-17
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
2929
}
3030

3131
LogicalResult
32-
QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
33-
unsigned flags, Type storageType, Type expressedType,
34-
int64_t storageTypeMin, int64_t storageTypeMax) {
32+
QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
33+
unsigned flags, Type storageType,
34+
Type expressedType, int64_t storageTypeMin,
35+
int64_t storageTypeMax) {
3536
// Verify that the storage type is integral.
3637
// This restriction may be lifted at some point in favor of using bf16
3738
// or f16 as exact representations on hardware where that is advantageous.
@@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
233234
}
234235

235236
LogicalResult
236-
AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
237-
unsigned flags, Type storageType, Type expressedType,
238-
int64_t storageTypeMin, int64_t storageTypeMax) {
239-
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
240-
storageTypeMin, storageTypeMax))) {
237+
AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
238+
unsigned flags, Type storageType,
239+
Type expressedType, int64_t storageTypeMin,
240+
int64_t storageTypeMax) {
241+
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
242+
expressedType, storageTypeMin,
243+
storageTypeMax))) {
241244
return failure();
242245
}
243246

@@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
268271
storageTypeMin, storageTypeMax);
269272
}
270273

271-
LogicalResult UniformQuantizedType::verify(
274+
LogicalResult UniformQuantizedType::verifyInvariants(
272275
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
273276
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
274277
int64_t storageTypeMin, int64_t storageTypeMax) {
275-
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
276-
storageTypeMin, storageTypeMax))) {
278+
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
279+
expressedType, storageTypeMin,
280+
storageTypeMax))) {
277281
return failure();
278282
}
279283

@@ -321,13 +325,14 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
321325
quantizedDimension, storageTypeMin, storageTypeMax);
322326
}
323327

324-
LogicalResult UniformQuantizedPerAxisType::verify(
328+
LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
325329
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
326330
Type storageType, Type expressedType, ArrayRef<double> scales,
327331
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
328332
int64_t storageTypeMin, int64_t storageTypeMax) {
329-
if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
330-
storageTypeMin, storageTypeMax))) {
333+
if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
334+
expressedType, storageTypeMin,
335+
storageTypeMax))) {
331336
return failure();
332337
}
333338

@@ -380,9 +385,9 @@ CalibratedQuantizedType CalibratedQuantizedType::getChecked(
380385
min, max);
381386
}
382387

383-
LogicalResult
384-
CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
385-
Type expressedType, double min, double max) {
388+
LogicalResult CalibratedQuantizedType::verifyInvariants(
389+
function_ref<InFlightDiagnostic()> emitError, Type expressedType,
390+
double min, double max) {
386391
// Verify that the expressed type is floating point.
387392
// If this restriction is ever eliminated, the parser/printer must be
388393
// extended.

0 commit comments

Comments
 (0)