Skip to content

Commit f339d44

Browse files
[mlir][ODS] Verify type constraints in Types and Attributes
When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing. Example: ``` def TestTypeVerification : Test_Type<"TestTypeVerification"> { let parameters = (ins AnyTypeOf<[I16, I32]>:$param); // ... } ``` No verification code was generated to ensure that `$param` is I16 or I32. When type constraints a present, a new method will generated for types and attributes: `verifyInvariantsImpl`. (The naming is similar to op verifiers.) The user-provided verifier is called `verify` (no change). There is now a new entry point to type/attribute verification: `verifyInvariants`. This function calls both `verifyInvariantsImpl` and `verify`. If neither of those two verifications are present, the `verifyInvariants` function is not generated. When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the `verifyInvariants` function. (This function was previously called `verify`.)
1 parent 96d824d commit f339d44

23 files changed

+251
-100
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,10 @@ class MMAMatrixType
148148

149149
/// Verify that shape and elementType are actually allowed for the
150150
/// MMAMatrixType.
151-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
152-
ArrayRef<int64_t> shape, Type elementType,
153-
StringRef operand);
151+
static LogicalResult
152+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
153+
ArrayRef<int64_t> shape, Type elementType,
154+
StringRef operand);
154155

155156
/// Get number of dims.
156157
unsigned getNumDims() const;

mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h

+7-5
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,13 @@ class LLVMStructType
180180
ArrayRef<Type> getBody() const;
181181

182182
/// Verifies that the type about to be constructed is well-formed.
183-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
184-
StringRef, bool);
185-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
186-
ArrayRef<Type> types, bool);
187-
using Base::verify;
183+
static LogicalResult
184+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
185+
bool);
186+
static LogicalResult
187+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
188+
ArrayRef<Type> types, bool);
189+
using Base::verifyInvariants;
188190

189191
/// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
190192
/// DataLayout instance and query it instead.

mlir/include/mlir/Dialect/Quant/QuantTypes.h

+22-21
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@ class QuantizedType : public Type {
5454
/// The maximum number of bits supported for storage types.
5555
static constexpr unsigned MaxStorageBits = 32;
5656

57-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
58-
unsigned flags, Type storageType,
59-
Type expressedType, int64_t storageTypeMin,
60-
int64_t storageTypeMax);
57+
static LogicalResult
58+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
59+
Type storageType, Type expressedType, int64_t storageTypeMin,
60+
int64_t storageTypeMax);
6161

6262
/// Support method to enable LLVM-style type casting.
6363
static bool classof(Type type);
@@ -214,10 +214,10 @@ class AnyQuantizedType
214214
int64_t storageTypeMax);
215215

216216
/// Verifies construction invariants and issues errors/warnings.
217-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
218-
unsigned flags, Type storageType,
219-
Type expressedType, int64_t storageTypeMin,
220-
int64_t storageTypeMax);
217+
static LogicalResult
218+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
219+
Type storageType, Type expressedType, int64_t storageTypeMin,
220+
int64_t storageTypeMax);
221221
};
222222

223223
/// Represents a family of uniform, quantized types.
@@ -276,11 +276,11 @@ class UniformQuantizedType
276276
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
277277

278278
/// Verifies construction invariants and issues errors/warnings.
279-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
280-
unsigned flags, Type storageType,
281-
Type expressedType, double scale,
282-
int64_t zeroPoint, int64_t storageTypeMin,
283-
int64_t storageTypeMax);
279+
static LogicalResult
280+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
281+
Type storageType, Type expressedType, double scale,
282+
int64_t zeroPoint, int64_t storageTypeMin,
283+
int64_t storageTypeMax);
284284

285285
/// Gets the scale term. The scale designates the difference between the real
286286
/// values corresponding to consecutive quantized values differing by 1.
@@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
338338
int64_t storageTypeMin, int64_t storageTypeMax);
339339

340340
/// Verifies construction invariants and issues errors/warnings.
341-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
342-
unsigned flags, Type storageType,
343-
Type expressedType, ArrayRef<double> scales,
344-
ArrayRef<int64_t> zeroPoints,
345-
int32_t quantizedDimension,
346-
int64_t storageTypeMin, int64_t storageTypeMax);
341+
static LogicalResult
342+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
343+
Type storageType, Type expressedType,
344+
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
345+
int32_t quantizedDimension, int64_t storageTypeMin,
346+
int64_t storageTypeMax);
347347

348348
/// Gets the quantization scales. The scales designate the difference between
349349
/// the real values corresponding to consecutive quantized values differing
@@ -403,8 +403,9 @@ class CalibratedQuantizedType
403403
double min, double max);
404404

405405
/// Verifies construction invariants and issues errors/warnings.
406-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
407-
Type expressedType, double min, double max);
406+
static LogicalResult
407+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
408+
Type expressedType, double min, double max);
408409
double getMin() const;
409410
double getMax() const;
410411
};

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h

+8-6
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,10 @@ class InterfaceVarABIAttr
7676
/// Returns `spirv::StorageClass`.
7777
std::optional<StorageClass> getStorageClass();
7878

79-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
80-
IntegerAttr descriptorSet, IntegerAttr binding,
81-
IntegerAttr storageClass);
79+
static LogicalResult
80+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
81+
IntegerAttr descriptorSet, IntegerAttr binding,
82+
IntegerAttr storageClass);
8283

8384
static constexpr StringLiteral name = "spirv.interface_var_abi";
8485
};
@@ -128,9 +129,10 @@ class VerCapExtAttr
128129
/// Returns the capabilities as an integer array attribute.
129130
ArrayAttr getCapabilitiesAttr();
130131

131-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
132-
IntegerAttr version, ArrayAttr capabilities,
133-
ArrayAttr extensions);
132+
static LogicalResult
133+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
134+
IntegerAttr version, ArrayAttr capabilities,
135+
ArrayAttr extensions);
134136

135137
static constexpr StringLiteral name = "spirv.ver_cap_ext";
136138
};

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,9 @@ class SampledImageType
258258
static SampledImageType
259259
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
260260

261-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
262-
Type imageType);
261+
static LogicalResult
262+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
263+
Type imageType);
263264

264265
Type getImageType() const;
265266

@@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
462463
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
463464
Type columnType, uint32_t columnCount);
464465

465-
static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
466-
Type columnType, uint32_t columnCount);
466+
static LogicalResult
467+
verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
468+
Type columnType, uint32_t columnCount);
467469

468470
/// Returns true if the matrix elements are vectors of float elements.
469471
static bool isValidColumnType(Type columnType);

mlir/include/mlir/IR/CommonTypeConstraints.td

+1
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
180180
summary),
181181
cppClassName> {
182182
list<Type> allowedTypes = allowedTypeList;
183+
string cppType = cppClassName;
183184
}
184185

185186
// A type that satisfies the constraints of all given types.

mlir/include/mlir/IR/Constraints.td

+3
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ class TypeConstraint<Pred predicate, string summary = "",
153153
Constraint<predicate, summary> {
154154
// The name of the C++ Type class if known, or Type if not.
155155
string cppClassName = cppClassNameParam;
156+
// TODO: This field is sometimes called `cppClassName` and sometimes
157+
// `cppType`. Use a single name consistently.
158+
string cppType = cppClassNameParam;
156159
}
157160

158161
// Subclass for constraints on an attribute.

mlir/include/mlir/IR/StorageUniquerSupport.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
176176
template <typename... Args>
177177
static ConcreteT get(MLIRContext *ctx, Args &&...args) {
178178
// Ensure that the invariants are correct for construction.
179-
assert(
180-
succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
179+
assert(succeeded(
180+
ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
181181
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
182182
}
183183

@@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
198198
static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
199199
MLIRContext *ctx, Args... args) {
200200
// If the construction invariants fail then we return a null attribute.
201-
if (failed(ConcreteT::verify(emitErrorFn, args...)))
201+
if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
202202
return ConcreteT();
203203
return UniquerT::template get<ConcreteT>(ctx, args...);
204204
}
@@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
226226

227227
/// Default implementation that just returns success.
228228
template <typename... Args>
229-
static LogicalResult verify(Args... args) {
229+
static LogicalResult
230+
verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
231+
Args... args) {
230232
return success();
231233
}
232234

mlir/include/mlir/IR/Types.h

+8-11
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class AsmState;
3434
/// Derived type classes are expected to implement several required
3535
/// implementation hooks:
3636
/// * Optional:
37-
/// - static LogicalResult verify(
37+
/// - static LogicalResult verifyInvariants(
3838
/// function_ref<InFlightDiagnostic()> emitError,
3939
/// Args... args)
4040
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
@@ -97,20 +97,17 @@ class Type {
9797
bool operator!() const { return impl == nullptr; }
9898

9999
template <typename... Tys>
100-
[[deprecated("Use mlir::isa<U>() instead")]]
101-
bool isa() const;
100+
[[deprecated("Use mlir::isa<U>() instead")]] bool isa() const;
102101
template <typename... Tys>
103-
[[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
104-
bool isa_and_nonnull() const;
102+
[[deprecated("Use mlir::isa_and_nonnull<U>() instead")]] bool
103+
isa_and_nonnull() const;
105104
template <typename U>
106-
[[deprecated("Use mlir::dyn_cast<U>() instead")]]
107-
U dyn_cast() const;
105+
[[deprecated("Use mlir::dyn_cast<U>() instead")]] U dyn_cast() const;
108106
template <typename U>
109-
[[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
110-
U dyn_cast_or_null() const;
107+
[[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] U
108+
dyn_cast_or_null() const;
111109
template <typename U>
112-
[[deprecated("Use mlir::cast<U>() instead")]]
113-
U cast() const;
110+
[[deprecated("Use mlir::cast<U>() instead")]] U cast() const;
114111

115112
/// Return a unique identifier for the concrete type. This is used to support
116113
/// dynamic type casting.

mlir/include/mlir/TableGen/AttrOrTypeDef.h

+8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Support/LLVM.h"
1818
#include "mlir/TableGen/Builder.h"
1919
#include "mlir/TableGen/Trait.h"
20+
#include "mlir/TableGen/Type.h"
2021

2122
namespace llvm {
2223
class DagInit;
@@ -85,6 +86,9 @@ class AttrOrTypeParameter {
8586
/// Get an optional C++ parameter parser.
8687
std::optional<StringRef> getParser() const;
8788

89+
/// If this is a type constraint, return it.
90+
std::optional<TypeConstraint> getTypeConstraint() const;
91+
8892
/// Get an optional C++ parameter printer.
8993
std::optional<StringRef> getPrinter() const;
9094

@@ -198,6 +202,10 @@ class AttrOrTypeDef {
198202
/// method.
199203
bool genVerifyDecl() const;
200204

205+
/// Return true if we need to generate any type constraint verification and
206+
/// the getChecked method.
207+
bool genVerifyInvariantsImpl() const;
208+
201209
/// Returns the def's extra class declaration code.
202210
std::optional<StringRef> getExtraDecls() const;
203211

mlir/include/mlir/TableGen/Class.h

+2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ class MethodParameter {
6767

6868
/// Get the C++ type.
6969
StringRef getType() const { return type; }
70+
/// Get the C++ parameter name.
71+
StringRef getName() const { return name; }
7072
/// Returns true if the parameter has a default value.
7173
bool hasDefaultValue() const { return !defaultValue.empty(); }
7274

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
148148
}
149149

150150
LogicalResult
151-
MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
152-
ArrayRef<int64_t> shape, Type elementType,
153-
StringRef operand) {
151+
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
152+
ArrayRef<int64_t> shape, Type elementType,
153+
StringRef operand) {
154154
if (operand != "AOp" && operand != "BOp" && operand != "COp")
155155
return emitError() << "operand expected to be one of AOp, BOp or COp";
156156

mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
418418

419419
bool LLVMStructType::isValidElementType(Type type) {
420420
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
421-
LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
422-
type);
421+
LLVMFunctionType, LLVMTokenType>(type);
423422
}
424423

425424
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
492491
: getImpl()->getTypeList();
493492
}
494493

495-
LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
496-
StringRef, bool) {
494+
LogicalResult
495+
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
496+
bool) {
497497
return success();
498498
}
499499

500500
LogicalResult
501-
LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
502-
ArrayRef<Type> types, bool) {
501+
LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
502+
ArrayRef<Type> types, bool) {
503503
for (Type t : types)
504504
if (!isValidElementType(t))
505505
return emitError() << "invalid LLVM structure element type: " << t;

0 commit comments

Comments
 (0)