Skip to content

Commit 718af88

Browse files
authored
[mlir][vector] Extend mask calculation for vector.contract (#65733)
Make sure that when calculating the expected mask for `vector.contract`, scalable sizes are correctly taken into account. Depends on: #65724
1 parent 4b5fe9c commit 718af88

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -912,22 +912,27 @@ Type ContractionOp::getExpectedMaskType() {
912912

913913
unsigned numVecDims = lhsIdxMap.getNumDims();
914914
SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
915+
SmallVector<bool> maskShapeScalableDims(numVecDims, false);
915916

916917
// Using the information in the indexing maps, extract the size of each
917918
// dimension in the vector.contract operation from the two input operands.
918-
for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
919+
for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape())) {
919920
maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
920-
for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
921+
maskShapeScalableDims[lhsIdxMap.getDimPosition(dimIdx)] =
922+
lhsType.getScalableDims()[dimIdx];
923+
}
924+
for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape())) {
921925
maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
926+
maskShapeScalableDims[rhsIdxMap.getDimPosition(dimIdx)] =
927+
rhsType.getScalableDims()[dimIdx];
928+
}
922929

923930
assert(!ShapedType::isDynamicShape(maskShape) &&
924931
"Mask shape couldn't be computed");
925-
// TODO: Extend the scalable vector type representation with a bit map.
926-
assert(!lhsType.isScalable() && !rhsType.isScalable() &&
927-
"Scalable vectors are not supported yet");
928932

929933
return VectorType::get(maskShape,
930-
IntegerType::get(lhsType.getContext(), /*width=*/1));
934+
IntegerType::get(lhsType.getContext(), /*width=*/1),
935+
maskShapeScalableDims);
931936
}
932937

933938
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,3 +979,27 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
979979
%2 = vector.scalable.extract %sv[4] : vector<4xi32> from vector<[8]xi32>
980980
return
981981
}
982+
983+
#matmat_accesses = [
984+
affine_map<(i, j, k) -> (i, k)>,
985+
affine_map<(i, j, k) -> (k, j)>,
986+
affine_map<(i, j, k) -> (i, j)>
987+
]
988+
#matmat_trait = {
989+
indexing_maps = #matmat_accesses,
990+
iterator_types = ["parallel", "parallel", "reduction"]
991+
}
992+
// CHECK-LABEL: func.func @contraction_masked_scalable(
993+
// CHECK-SAME: %[[A:.*]]: vector<3x4xf32>,
994+
// CHECK-SAME: %[[B:.*]]: vector<4x[8]xf32>,
995+
// CHECK-SAME: %[[C:.*]]: vector<3x[8]xf32>,
996+
// CHECK-SAME: %[[M:.*]]: vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
997+
func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
998+
%B: vector<4x[8]xf32>,
999+
%C: vector<3x[8]xf32>,
1000+
%M : vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
1001+
// CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
1002+
%0 = vector.mask %M { vector.contract #matmat_trait %A, %B, %C : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> }
1003+
: vector<3x[8]x4xi1> -> vector<3x[8]xf32>
1004+
return %0 : vector<3x[8]xf32>
1005+
}

0 commit comments

Comments
 (0)