Skip to content

Commit b9c971d

Browse files
committed
[mlir][spirv] Fix coop matrix store
- Fix operand/attribute order - Use ODS for parsing/printing - Allow for stride to be any integer type
1 parent 4d2536c commit b9c971d

File tree

3 files changed

+33
-60
lines changed

3 files changed

+33
-60
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td

+10-5
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,10 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
171171

172172
``` {.ebnf}
173173
coop-matrix-store-op ::= `spirv.KHR.CooperativeMatrixStore `
174-
ssa-use `, ` ssa-use `, `
175-
ssa-use `, ` cooperative-matrix-layout `, `
176-
(`[` memory-operand `]`)? `:`
177-
pointer-type `,` coop-matrix-type
174+
ssa-use `,` ssa-use `,`
175+
ssa-use `,` `<` cooperative-matrix-layout `>
176+
(`,` `<` memory-operand `>`)? `:`
177+
pointer-type `,` coop-matrix-type, stride-type
178178
```
179179

180180
#### Example:
@@ -185,6 +185,11 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
185185
```
186186
}];
187187

188+
let assemblyFormat = [{
189+
$pointer `,` $object `,` $stride `,` $matrix_layout ( `,` $memory_operand^ )? attr-dict `:`
190+
type($pointer) `,` type($object) `,` type($stride)
191+
}];
192+
188193
let availability = [
189194
MinVersion<SPIRV_V_1_6>,
190195
MaxVersion<SPIRV_V_1_6>,
@@ -195,8 +200,8 @@ def SPIRV_KHRCooperativeMatrixStoreOp : SPIRV_KhrVendorOp<"CooperativeMatrixStor
195200
let arguments = (ins
196201
SPIRV_AnyPtr:$pointer,
197202
SPIRV_AnyCooperativeMatrix:$object,
198-
SPIRV_Integer:$stride,
199203
SPIRV_KHR_CooperativeMatrixLayoutAttr:$matrix_layout,
204+
SPIRV_Integer:$stride,
200205
OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_operand
201206
);
202207

mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp

-43
Original file line numberDiff line numberDiff line change
@@ -106,49 +106,6 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
106106
// spirv.KHR.CooperativeMatrixStore
107107
//===----------------------------------------------------------------------===//
108108

109-
ParseResult KHRCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
110-
OperationState &result) {
111-
std::array<OpAsmParser::UnresolvedOperand, 3> operandInfo = {};
112-
for (auto &op : operandInfo) {
113-
if (parser.parseOperand(op) || parser.parseComma())
114-
return failure();
115-
}
116-
117-
CooperativeMatrixLayoutKHR layout;
118-
if (parseEnumKeywordAttr<CooperativeMatrixLayoutKHRAttr>(
119-
layout, parser, result, kKhrCooperativeMatrixLayoutAttrName)) {
120-
return failure();
121-
}
122-
123-
if (parseMemoryAccessAttributes(parser, result, kMemoryOperandAttrName))
124-
return failure();
125-
126-
Type ptrType;
127-
Type objectType;
128-
if (parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() ||
129-
parser.parseType(objectType)) {
130-
return failure();
131-
}
132-
133-
Type strideType = parser.getBuilder().getIntegerType(32);
134-
if (parser.resolveOperands(operandInfo, {ptrType, objectType, strideType},
135-
parser.getNameLoc(), result.operands)) {
136-
return failure();
137-
}
138-
139-
return success();
140-
}
141-
142-
void KHRCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
143-
printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
144-
<< ", " << getMatrixLayout();
145-
146-
// Print optional memory operand attribute.
147-
if (auto memOperand = getMemoryOperand())
148-
printer << " [\"" << *memOperand << "\"]";
149-
printer << " : " << getPointer().getType() << ", " << getObject().getType();
150-
}
151-
152109
LogicalResult KHRCooperativeMatrixStoreOp::verify() {
153110
return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
154111
getObject().getType());

mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir

+23-12
Original file line numberDiff line numberDiff line change
@@ -60,21 +60,32 @@ spirv.func @cooperative_matrix_load_function(%ptr : !spirv.ptr<i32, Function>, %
6060
// CHECK-LABEL: @cooperative_matrix_store
6161
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
6262
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
63-
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, RowMajor :
64-
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
65-
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, RowMajor :
66-
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
63+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <RowMajor> :
64+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
65+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
66+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
6767
spirv.Return
6868
}
6969

7070
// CHECK-LABEL: @cooperative_matrix_store_memoperand
7171
spirv.func @cooperative_matrix_store_memoperand(%ptr : !spirv.ptr<i32, StorageBuffer>,
7272
%m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
7373
%stride : i32) "None" {
74-
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, ColumnMajor ["Volatile"] :
75-
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
76-
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ColumnMajor ["Volatile"] :
77-
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>
74+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor>, <Volatile> :
75+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
76+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor>, <Volatile> :
77+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i32
78+
spirv.Return
79+
}
80+
81+
// CHECK-LABEL: @cooperative_matrix_store_stride_i16
82+
spirv.func @cooperative_matrix_store_stride_i16(%ptr : !spirv.ptr<i32, StorageBuffer>,
83+
%m : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>,
84+
%stride : i16) "None" {
85+
// CHECK: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, <ColumnMajor> :
86+
// CHECK-SAME: !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
87+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <ColumnMajor> :
88+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Subgroup, MatrixB>, i16
7889
spirv.Return
7990
}
8091

@@ -128,9 +139,9 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, Storage
128139

129140
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
130141
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
131-
// expected-error @+1 {{expected valid keyword}}
142+
// expected-error @+1 {{expected '<'}}
132143
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, :
133-
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>
144+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
134145
spirv.Return
135146
}
136147

@@ -139,8 +150,8 @@ spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, Storage
139150
spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, StorageBuffer>,
140151
%stride : i32) "None" {
141152
// expected-error @+1 {{op operand #1 must be any SPIR-V cooperative matrix type}}
142-
spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, RowMajor :
143-
!spirv.ptr<i32, StorageBuffer>, i32
153+
spirv.KHR.CooperativeMatrixStore %ptr, %stride, %stride, <RowMajor> :
154+
!spirv.ptr<i32, StorageBuffer>, i32, i32
144155
spirv.Return
145156
}
146157

0 commit comments

Comments
 (0)