diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 8a9ce949a750d..89c1a06412947 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -263,16 +263,6 @@ def ForOp : SCF_Op<"for", } /// Number of operands controlling the loop: lb, ub, step unsigned getNumControlOperands() { return 3; } - /// Get the iter arg number for an operand. If it isnt an iter arg - /// operand return std::nullopt. - std::optional getIterArgNumberForOpOperand(OpOperand &opOperand) { - if (opOperand.getOwner() != getOperation()) - return std::nullopt; - unsigned operandNumber = opOperand.getOperandNumber(); - if (operandNumber < getNumControlOperands()) - return std::nullopt; - return operandNumber - getNumControlOperands(); - } /// Get the region iter arg that corresponds to an OpOperand. /// This helper prevents internal op implementation detail leakage to diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 21bc0554e7176..a9debb7bbc489 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp()); - std::optional maybeOperandNumber = - forOp.getIterArgNumberForOpOperand(*pUse); - assert(maybeOperandNumber.has_value() && "expected a proper iter arg number"); - - int64_t operandNumber = maybeOperandNumber.value(); + unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber(); auto yieldOp = cast(forOp.getBody(0)->getTerminator()); - auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber) + auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber) .getDefiningOp(); if (!yieldingExtractSliceOp) return tensor::ExtractSliceOp(); @@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, return tensor::ExtractSliceOp(); SmallVector initArgs = forOp.getInitArgs(); - initArgs[operandNumber] = hoistedPackedTensor; + initArgs[iterArgNumber] = hoistedPackedTensor; SmallVector yieldOperands = yieldOp.getOperands(); - yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource(); + yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource(); int64_t numOriginalForOpResults = initArgs.size(); LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults @@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, hoistedPackedTensor.getLoc(), hoistedPackedTensor, outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(), outerSliceOp.getMixedStrides()); - rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted); + rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted); } scf::ForOp newForOp = replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands); @@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, << "\n"); LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n"); LLVM_DEBUG(DBGS() << "with result #" - << numOriginalForOpResults + operandNumber + << numOriginalForOpResults + iterArgNumber << " of forOp, giving us: " << extracted << "\n"); rewriter.startRootUpdate(extracted); extracted.getSourceMutable().assign( - newForOp.getResult(numOriginalForOpResults + operandNumber)); + newForOp.getResult(numOriginalForOpResults + iterArgNumber)); rewriter.finalizeRootUpdate(extracted); LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting << "\n"); LLVM_DEBUG(DBGS() << "with region iter arg #" - << numOriginalForOpResults + operandNumber << "\n"); + << numOriginalForOpResults + iterArgNumber << "\n"); rewriter.replaceAllUsesWith( paddedValueBeforeHoisting, - newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber)); + newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber)); return extracted; } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index f5586712a84fa..6cfba3fef15eb 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -569,8 +569,9 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, scf::ForOp outerMostLoop = loops.front(); if (destinationInitArg && (*destinationInitArg)->getOwner() == outerMostLoop) { - std::optional iterArgNumber = - outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg); + unsigned iterArgNumber = + outerMostLoop.getResultForOpOperand(**destinationInitArg) + .getResultNumber(); int64_t resultNumber = fusableProducer.getResultNumber(); if (auto dstOp = dyn_cast(fusableProducer.getOwner())) { @@ -584,7 +585,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, scf::ForOp innerMostLoop = loops.back(); updateDestinationOperandsForTiledOp( rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), - innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); + innerMostLoop.getRegionIterArgs()[iterArgNumber]); } } return scf::SCFFuseProducerOfSliceResult{fusableProducer,