From b467dae24cdb6ba843e249d09a12cbbfb002dcde Mon Sep 17 00:00:00 2001
From: Jorn Tuyls
Date: Sat, 20 Apr 2024 05:55:35 -0700
Subject: [PATCH] [mlir] Extract forall_to_for logic into reusable function and add pass

---
 .../mlir/Dialect/SCF/Transforms/Passes.h      |  3 +
 .../mlir/Dialect/SCF/Transforms/Passes.td     |  5 ++
 .../mlir/Dialect/SCF/Transforms/Transforms.h  |  7 ++
 .../SCF/TransformOps/SCFTransformOps.cpp      | 33 ++------
 .../lib/Dialect/SCF/Transforms/CMakeLists.txt |  1 +
 .../Dialect/SCF/Transforms/ForallToFor.cpp    | 79 +++++++++++++++++++
 mlir/test/Dialect/SCF/forall-to-for.mlir      | 57 +++++++++++++
 7 files changed, 160 insertions(+), 25 deletions(-)
 create mode 100644 mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
 create mode 100644 mlir/test/Dialect/SCF/forall-to-for.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8cf..31c3d0eb629d2 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -59,6 +59,9 @@ createParallelLoopTilingPass(llvm::ArrayRef<int64_t> tileSize = {},
 /// loop range.
 std::unique_ptr<Pass> createForLoopRangeFoldingPass();
 
+/// Creates a pass that converts SCF forall loops to SCF for loops.
+std::unique_ptr<Pass> createForallToForLoopPass();
+
 // Creates a pass which lowers for loops into while loops.
 std::unique_ptr<Pass> createForToWhileLoopPass();
 
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad86873..a7aeb42d60c0e 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -120,6 +120,11 @@ def SCFForLoopRangeFolding : Pass<"scf-for-loop-range-folding"> {
   let constructor = "mlir::createForLoopRangeFoldingPass()";
 }
 
+def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
+  let summary = "Convert SCF forall loops to SCF for loops";
+  let constructor = "mlir::createForallToForLoopPass()";
+}
+
 def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
   let summary = "Convert SCF for loops to SCF while loops";
   let constructor = "mlir::createForToWhileLoopPass()";
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 220dcb35571d2..b063e6e775e63 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -28,10 +28,17 @@ class Value;
 namespace scf {
 
 class IfOp;
+class ForallOp;
 class ForOp;
 class ParallelOp;
 class WhileOp;
 
+/// Try converting scf.forall into a set of nested scf.for loops.
+/// The newly created scf.for ops will be returned through the `results`
+/// vector if provided.
+LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
+                              SmallVectorImpl<Operation *> *results = nullptr);
+
 /// Fuses all adjacent scf.parallel operations with identical bounds and step
 /// into one scf.parallel operations. Uses a naive aliasing and dependency
 /// analysis.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 7e4faf8b73afb..69f83d8bd70da 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -69,16 +69,12 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
     return diag;
   }
 
-  rewriter.setInsertionPoint(target);
-
   if (!target.getOutputs().empty()) {
     return emitSilenceableError()
            << "unsupported shared outputs (didn't bufferize?)";
   }
 
   SmallVector<OpFoldResult> lbs = target.getMixedLowerBound();
-  SmallVector<OpFoldResult> ubs = target.getMixedUpperBound();
-  SmallVector<OpFoldResult> steps = target.getMixedStep();
 
   if (getNumResults() != lbs.size()) {
     DiagnosedSilenceableFailure diag =
@@ -89,28 +85,15 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
     return diag;
   }
 
-  auto loc = target.getLoc();
-  SmallVector<Value> ivs;
-  for (auto &&[lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
-    Value lbValue = getValueOrCreateConstantIndexOp(rewriter, loc, lb);
-    Value ubValue = getValueOrCreateConstantIndexOp(rewriter, loc, ub);
-    Value stepValue = getValueOrCreateConstantIndexOp(rewriter, loc, step);
-    auto loop = rewriter.create<scf::ForOp>(
-        loc, lbValue, ubValue, stepValue, ValueRange(),
-        [](OpBuilder &, Location, Value, ValueRange) {});
-    ivs.push_back(loop.getInductionVar());
-    rewriter.setInsertionPointToStart(loop.getBody());
-    rewriter.create<scf::YieldOp>(loc);
-    rewriter.setInsertionPointToStart(loop.getBody());
+  SmallVector<Operation *> opResults;
+  if (failed(scf::forallToForLoop(rewriter, target, &opResults))) {
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to convert forall into for";
+    return diag;
   }
-  rewriter.eraseOp(target.getBody()->getTerminator());
-  rewriter.inlineBlockBefore(target.getBody(), &*rewriter.getInsertionPoint(),
-                             ivs);
-  rewriter.eraseOp(target);
-
-  for (auto &&[i, iv] : llvm::enumerate(ivs)) {
-    results.set(cast<OpResult>(getTransformed()[i]),
-                {iv.getParentBlock()->getParentOp()});
+
+  for (auto &&[i, res] : llvm::enumerate(opResults)) {
+    results.set(cast<OpResult>(getTransformed()[i]), {res});
   }
   return DiagnosedSilenceableFailure::success();
 }
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index a2925aef17ca7..e7671c9cc28f8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ForallToFor.cpp
   ForToWhile.cpp
   LoopCanonicalization.cpp
   LoopPipelining.cpp
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
new file mode 100644
index 0000000000000..198cb2e6cc69e
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -0,0 +1,99 @@
+//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.ForallOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
+#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace llvm;
+using namespace mlir;
+using scf::ForallOp;
+using scf::ForOp;
+using scf::LoopNest;
+
+/// Rewrites `forallOp` into a nest of scf.for loops built from its lower
+/// bounds, upper bounds and steps, and moves the forall body into the innermost
+/// loop. Fails when `forallOp` has shared outputs or when no loop is created.
+/// If `results` is non-null, the created loops are appended to it in nest order
+/// with the innermost loop last.
+LogicalResult
+mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
+                           SmallVectorImpl<Operation *> *results) {
+  // Shared outputs are not supported by this lowering (the transform op rejects
+  // them as well); bail out before touching the IR.
+  if (!forallOp.getOutputs().empty())
+    return failure();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+
+  Location loc = forallOp.getLoc();
+  SmallVector<Value> lbs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedLowerBound());
+  SmallVector<Value> ubs = getValueOrCreateConstantIndexOp(
+      rewriter, loc, forallOp.getMixedUpperBound());
+  SmallVector<Value> steps =
+      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());
+  LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
+
+  // Guard the `back()` call below: an empty nest has no innermost body to
+  // inline the forall body into.
+  if (loopNest.loops.empty())
+    return failure();
+
+  SmallVector<Value> ivs = llvm::map_to_vector(
+      loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
+
+  // Drop the forall terminator, then splice the remaining body right before
+  // the innermost loop's terminator, substituting the new induction variables.
+  Block *innermostBlock = loopNest.loops.back().getBody();
+  rewriter.eraseOp(forallOp.getBody()->getTerminator());
+  rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
+                             innermostBlock->getTerminator()->getIterator(),
+                             ivs);
+  rewriter.eraseOp(forallOp);
+
+  if (results)
+    llvm::move(loopNest.loops, std::back_inserter(*results));
+
+  return success();
+}
+
+namespace {
+/// Pass that converts every scf.forall nested under the root operation.
+struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
+  void runOnOperation() override {
+    Operation *parentOp = getOperation();
+    IRRewriter rewriter(parentOp->getContext());
+
+    // Stop at the first loop that cannot be converted instead of continuing
+    // to rewrite IR the pass is about to report as failed.
+    WalkResult status = parentOp->walk([&](scf::ForallOp forallOp) {
+      if (succeeded(scf::forallToForLoop(rewriter, forallOp)))
+        return WalkResult::advance();
+      forallOp.emitOpError("could not be converted to scf.for loops");
+      return WalkResult::interrupt();
+    });
+    if (status.wasInterrupted())
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
+  return std::make_unique<ForallToForLoop>();
+}
diff --git a/mlir/test/Dialect/SCF/forall-to-for.mlir b/mlir/test/Dialect/SCF/forall-to-for.mlir
new file mode 100644
index 0000000000000..e7d183fb9d2b5
--- /dev/null
+++ b/mlir/test/Dialect/SCF/forall-to-for.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-for))' -split-input-file | FileCheck %s
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @two_iters
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @two_iters(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK: func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index)
+
+// CHECK-LABEL: @repeated
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
+func.func @repeated(%ub1: index, %ub2: index) {
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK: func.call @callee(%[[IV1]], %[[IV2]])
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    func.call @callee(%i, %j) : (index, index) -> ()
+  }
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK: func.call @callee(%[[IV1]], %[[IV2]])
+  return
+}
+
+// -----
+
+func.func private @callee(%i: index, %j: index, %k: index, %l: index)
+
+// CHECK-LABEL: @nested
+// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
+func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
+  // CHECK: scf.for %[[IV1:.+]] = %{{.*}} to %[[UB1]]
+  // CHECK: scf.for %[[IV2:.+]] = %{{.*}} to %[[UB2]]
+  // CHECK: scf.for %[[IV3:.+]] = %{{.*}} to %[[UB3]]
+  // CHECK: scf.for %[[IV4:.+]] = %{{.*}} to %[[UB4]]
+  // CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
+  scf.forall (%i, %j) in (%ub1, %ub2) {
+    scf.forall (%k, %l) in (%ub3, %ub4) {
+      func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
+    }
+  }
+  return
+}