Skip to content

[flang][hlfir] Better recognize non-overlapping array sections. #65707

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 66 additions & 7 deletions flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
des1.getComponentShape() != des2.getComponentShape() ||
des1.getSubstring() != des2.getSubstring() ||
des1.getComplexPart() != des2.getComplexPart() ||
des1.getShape() != des2.getShape() ||
des1.getTypeparams() != des2.getTypeparams()) {
LLVM_DEBUG(llvm::dbgs() << "Different designator specs for:\n"
<< des1 << "and:\n"
Expand All @@ -211,27 +210,87 @@ static bool areIdenticalOrDisjointSlices(mlir::Value ref1, mlir::Value ref2) {
// If all the triplets (section specifiers) are the same, then
// we do not care if %0 is equal to %1 - the slices are either
// identical or completely disjoint.
//
// TODO: if we can prove that all non-triplet subscripts are different
// (by value), then we may return true regardless of the triplet
// values - the sections must be completely disjoint.
auto des1It = des1.getIndices().begin();
auto des2It = des2.getIndices().begin();
bool identicalTriplets = true;
for (bool isTriplet : des1.getIsTriplet()) {
if (isTriplet) {
for (int i = 0; i < 3; ++i)
if (*des1It++ != *des2It++) {
LLVM_DEBUG(llvm::dbgs() << "Triplet mismatch for:\n"
<< des1 << "and:\n"
<< des2 << "\n");
return false;
identicalTriplets = false;
break;
}
} else {
++des1It;
++des2It;
}
}
return true;
if (identicalTriplets)
return true;

// See if we can prove that any of the triplets do not overlap.
// This is mostly a Polyhedron/nf performance hack that looks for
// particular relations between the lower and upper bounds
// of the array sections, e.g. for any positive constant C:
// X:Y does not overlap with (Y+C):Z
// X:Y does not overlap with Z:(X-C)
auto displacedByConstant = [](mlir::Value v1, mlir::Value v2) {
auto removeConvert = [](mlir::Value v) -> mlir::Operation * {
auto *op = v.getDefiningOp();
while (auto conv = mlir::dyn_cast_or_null<fir::ConvertOp>(op))
op = conv.getValue().getDefiningOp();
return op;
};

auto isPositiveConstant = [](mlir::Value v) -> bool {
if (auto conOp =
mlir::dyn_cast<mlir::arith::ConstantOp>(v.getDefiningOp()))
if (auto iattr = conOp.getValue().dyn_cast<mlir::IntegerAttr>())
return iattr.getInt() > 0;
return false;
};

auto *op1 = removeConvert(v1);
auto *op2 = removeConvert(v2);
if (!op1 || !op2)
return false;
if (auto addi = mlir::dyn_cast<mlir::arith::AddIOp>(op2))
if ((addi.getLhs().getDefiningOp() == op1 &&
isPositiveConstant(addi.getRhs())) ||
(addi.getRhs().getDefiningOp() == op1 &&
isPositiveConstant(addi.getLhs())))
return true;
if (auto subi = mlir::dyn_cast<mlir::arith::SubIOp>(op1))
if (subi.getLhs().getDefiningOp() == op2 &&
isPositiveConstant(subi.getRhs()))
return true;
return false;
};

des1It = des1.getIndices().begin();
des2It = des2.getIndices().begin();
for (bool isTriplet : des1.getIsTriplet()) {
if (isTriplet) {
mlir::Value des1Lb = *des1It++;
mlir::Value des1Ub = *des1It++;
mlir::Value des2Lb = *des2It++;
mlir::Value des2Ub = *des2It++;
// Ignore strides.
++des1It;
++des2It;
if (displacedByConstant(des1Ub, des2Lb) ||
displacedByConstant(des2Ub, des1Lb))
return true;
} else {
++des1It;
++des2It;
}
}

return false;
}

std::optional<ElementalAssignBufferization::MatchInfo>
Expand Down
254 changes: 254 additions & 0 deletions flang/test/HLFIR/opt-array-slice-assign.fir
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,257 @@ func.func @_QPtest3(%arg0: !fir.ref<!fir.array<10x!fir.type<_QMtypesTt{x:!fir.ar
// CHECK: hlfir.assign %[[VAL_28]] to %[[VAL_29]] : f32, !fir.ref<f32>
// CHECK: }
// CHECK: }

// ! ub == lb - 1
// subroutine test4(x, i1, i2, nx)
// real :: x(i2), f
// do i=i1,i2,nx
// x(i:i+nx-1) = (x(i-nx:i-1))
// end do
// end subroutine test4
// Test input: inside a fir.do_loop, an hlfir.elemental that reads the
// section (%25:%26:%c1) of x is assigned to the section (%36:%37:%c1)
// of the same array. The upper bound of the read section is (i - 1)
// and the lower bound of the written section is i.
func.func @_QPtest4(%arg0: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "i1"}, %arg2: !fir.ref<i32> {fir.bindc_name = "i2"}, %arg3: !fir.ref<i32> {fir.bindc_name = "nx"}) {
%c1 = arith.constant 1 : index
%c1_i32 = arith.constant 1 : i32
%c0 = arith.constant 0 : index
%0 = fir.alloca f32 {bindc_name = "f", uniq_name = "_QFtest4Ef"}
%1:2 = hlfir.declare %0 {uniq_name = "_QFtest4Ef"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
%2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest4Ei"}
%3:2 = hlfir.declare %2 {uniq_name = "_QFtest4Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%4:2 = hlfir.declare %arg1 {uniq_name = "_QFtest4Ei1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%5:2 = hlfir.declare %arg2 {uniq_name = "_QFtest4Ei2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%6:2 = hlfir.declare %arg3 {uniq_name = "_QFtest4Enx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
// The extent of x is max(i2, 0).
%7 = fir.load %5#0 : !fir.ref<i32>
%8 = fir.convert %7 : (i32) -> index
%9 = arith.cmpi sgt, %8, %c0 : index
%10 = arith.select %9, %8, %c0 : index
%11 = fir.shape %10 : (index) -> !fir.shape<1>
%12:2 = hlfir.declare %arg0(%11) {uniq_name = "_QFtest4Ex"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
// Loop bounds: from i1 (%14) to i2 (%16) with step nx (%18).
%13 = fir.load %4#0 : !fir.ref<i32>
%14 = fir.convert %13 : (i32) -> index
%15 = fir.load %5#0 : !fir.ref<i32>
%16 = fir.convert %15 : (i32) -> index
%17 = fir.load %6#0 : !fir.ref<i32>
%18 = fir.convert %17 : (i32) -> index
%19 = fir.convert %14 : (index) -> i32
%20:2 = fir.do_loop %arg4 = %14 to %16 step %18 iter_args(%arg5 = %19) -> (index, i32) {
fir.store %arg5 to %3#1 : !fir.ref<i32>
// %21 is the current value of i, %22 is nx.
%21 = fir.load %3#0 : !fir.ref<i32>
%22 = fir.load %6#0 : !fir.ref<i32>
// Bounds of the section being read: lb = i - nx (%25), ub = i - 1 (%26).
%23 = arith.subi %21, %22 : i32
%24 = arith.subi %21, %c1_i32 : i32
%25 = fir.convert %23 : (i32) -> index
%26 = fir.convert %24 : (i32) -> index
// Extent of the section being read: max(ub - lb + 1, 0).
%27 = arith.subi %26, %25 : index
%28 = arith.addi %27, %c1 : index
%29 = arith.cmpi sgt, %28, %c0 : index
%30 = arith.select %29, %28, %c0 : index
%31 = fir.shape %30 : (index) -> !fir.shape<1>
%32 = hlfir.designate %12#0 (%25:%26:%c1) shape %31 : (!fir.box<!fir.array<?xf32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
// Element-wise copy of the read section into an expression value.
%33 = hlfir.elemental %31 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
^bb0(%arg6: index):
%48 = hlfir.designate %32 (%arg6) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
%49 = fir.load %48 : !fir.ref<f32>
%50 = hlfir.no_reassoc %49 : f32
hlfir.yield_element %50 : f32
}
// Bounds of the section being written: lb = i (%36), ub = i + nx - 1 (%37).
%34 = arith.addi %21, %22 : i32
%35 = arith.subi %34, %c1_i32 : i32
%36 = fir.convert %21 : (i32) -> index
%37 = fir.convert %35 : (i32) -> index
// Extent of the section being written: max(ub - lb + 1, 0).
%38 = arith.subi %37, %36 : index
%39 = arith.addi %38, %c1 : index
%40 = arith.cmpi sgt, %39, %c0 : index
%41 = arith.select %40, %39, %c0 : index
%42 = fir.shape %41 : (index) -> !fir.shape<1>
%43 = hlfir.designate %12#0 (%36:%37:%c1) shape %42 : (!fir.box<!fir.array<?xf32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
hlfir.assign %33 to %43 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
hlfir.destroy %33 : !hlfir.expr<?xf32>
// Advance the loop index and the value of i by the step nx.
%44 = arith.addi %arg4, %18 : index
%45 = fir.convert %18 : (index) -> i32
%46 = fir.load %3#1 : !fir.ref<i32>
%47 = arith.addi %46, %45 : i32
fir.result %44, %47 : index, i32
}
// Store the final value of the loop variable back to i.
fir.store %20#1 to %3#1 : !fir.ref<i32>
return
}
// CHECK-LABEL: func.func @_QPtest4(
// CHECK-NOT: hlfir.elemental

// ! lb == ub + 1
// subroutine test5(x, i1, i2, nx)
// real :: x(i2), f
// do i=i1,i2,nx
// x(i+1:i+nx-1) = (x(i-nx:i))
// end do
// end subroutine test5
// Test input: inside a fir.do_loop, an hlfir.elemental that reads the
// section (%24:%25:%c1) of x is assigned to the section (%36:%37:%c1)
// of the same array. The upper bound of the read section is i
// and the lower bound of the written section is (i + 1).
func.func @_QPtest5(%arg0: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "i1"}, %arg2: !fir.ref<i32> {fir.bindc_name = "i2"}, %arg3: !fir.ref<i32> {fir.bindc_name = "nx"}) {
%c1_i32 = arith.constant 1 : i32
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = fir.alloca f32 {bindc_name = "f", uniq_name = "_QFtest5Ef"}
%1:2 = hlfir.declare %0 {uniq_name = "_QFtest5Ef"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
%2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest5Ei"}
%3:2 = hlfir.declare %2 {uniq_name = "_QFtest5Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%4:2 = hlfir.declare %arg1 {uniq_name = "_QFtest5Ei1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%5:2 = hlfir.declare %arg2 {uniq_name = "_QFtest5Ei2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%6:2 = hlfir.declare %arg3 {uniq_name = "_QFtest5Enx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
// The extent of x is max(i2, 0).
%7 = fir.load %5#0 : !fir.ref<i32>
%8 = fir.convert %7 : (i32) -> index
%9 = arith.cmpi sgt, %8, %c0 : index
%10 = arith.select %9, %8, %c0 : index
%11 = fir.shape %10 : (index) -> !fir.shape<1>
%12:2 = hlfir.declare %arg0(%11) {uniq_name = "_QFtest5Ex"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
// Loop bounds: from i1 (%14) to i2 (%16) with step nx (%18).
%13 = fir.load %4#0 : !fir.ref<i32>
%14 = fir.convert %13 : (i32) -> index
%15 = fir.load %5#0 : !fir.ref<i32>
%16 = fir.convert %15 : (i32) -> index
%17 = fir.load %6#0 : !fir.ref<i32>
%18 = fir.convert %17 : (i32) -> index
%19 = fir.convert %14 : (index) -> i32
%20:2 = fir.do_loop %arg4 = %14 to %16 step %18 iter_args(%arg5 = %19) -> (index, i32) {
fir.store %arg5 to %3#1 : !fir.ref<i32>
// %21 is the current value of i, %22 is nx.
%21 = fir.load %3#0 : !fir.ref<i32>
%22 = fir.load %6#0 : !fir.ref<i32>
// Bounds of the section being read: lb = i - nx (%24), ub = i (%25).
%23 = arith.subi %21, %22 : i32
%24 = fir.convert %23 : (i32) -> index
%25 = fir.convert %21 : (i32) -> index
// Extent of the section being read: max(ub - lb + 1, 0).
%26 = arith.subi %25, %24 : index
%27 = arith.addi %26, %c1 : index
%28 = arith.cmpi sgt, %27, %c0 : index
%29 = arith.select %28, %27, %c0 : index
%30 = fir.shape %29 : (index) -> !fir.shape<1>
%31 = hlfir.designate %12#0 (%24:%25:%c1) shape %30 : (!fir.box<!fir.array<?xf32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
// Element-wise copy of the read section into an expression value.
%32 = hlfir.elemental %30 unordered : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
^bb0(%arg6: index):
%48 = hlfir.designate %31 (%arg6) : (!fir.box<!fir.array<?xf32>>, index) -> !fir.ref<f32>
%49 = fir.load %48 : !fir.ref<f32>
%50 = hlfir.no_reassoc %49 : f32
hlfir.yield_element %50 : f32
}
// Bounds of the section being written: lb = i + 1 (%36),
// ub = i + nx - 1 (%37).
%33 = arith.addi %21, %c1_i32 : i32
%34 = arith.addi %21, %22 : i32
%35 = arith.subi %34, %c1_i32 : i32
%36 = fir.convert %33 : (i32) -> index
%37 = fir.convert %35 : (i32) -> index
// Extent of the section being written: max(ub - lb + 1, 0).
%38 = arith.subi %37, %36 : index
%39 = arith.addi %38, %c1 : index
%40 = arith.cmpi sgt, %39, %c0 : index
%41 = arith.select %40, %39, %c0 : index
%42 = fir.shape %41 : (index) -> !fir.shape<1>
%43 = hlfir.designate %12#0 (%36:%37:%c1) shape %42 : (!fir.box<!fir.array<?xf32>>, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<?xf32>>
hlfir.assign %32 to %43 : !hlfir.expr<?xf32>, !fir.box<!fir.array<?xf32>>
hlfir.destroy %32 : !hlfir.expr<?xf32>
// Advance the loop index and the value of i by the step nx.
%44 = arith.addi %arg4, %18 : index
%45 = fir.convert %18 : (index) -> i32
%46 = fir.load %3#1 : !fir.ref<i32>
%47 = arith.addi %46, %45 : i32
fir.result %44, %47 : index, i32
}
// Store the final value of the loop variable back to i.
fir.store %20#1 to %3#1 : !fir.ref<i32>
return
}
// CHECK-LABEL: func.func @_QPtest5(
// CHECK-NOT: hlfir.elemental

// ! ub = lb - 1 and dim1 is unknown
// ! FIR lowering produces a temp.
// subroutine test6(x, i1, i2, nx)
// real :: x(i2,i2), f
// integer n1, n2, n3, n4
// do i=i1,i2,nx
// x(i:i+nx-1,n1:n2) = (x(i-nx:i-1,n3:n4))
// end do
// end subroutine test6
// Test input: two-dimensional variant. Inside a fir.do_loop, an
// hlfir.elemental that reads the section (%33:%34:%c1, %41:%42:%c1) of x
// is assigned to the section (%52:%53:%c1, %60:%61:%c1) of the same array.
// In the first dimension the upper bound of the read section is (i - 1)
// and the lower bound of the written section is i; the bounds of the
// second dimension are loaded from the unrelated variables n1..n4.
func.func @_QPtest6(%arg0: !fir.ref<!fir.array<?x?xf32>> {fir.bindc_name = "x"}, %arg1: !fir.ref<i32> {fir.bindc_name = "i1"}, %arg2: !fir.ref<i32> {fir.bindc_name = "i2"}, %arg3: !fir.ref<i32> {fir.bindc_name = "nx"}) {
%c1 = arith.constant 1 : index
%c1_i32 = arith.constant 1 : i32
%c0 = arith.constant 0 : index
%0 = fir.alloca f32 {bindc_name = "f", uniq_name = "_QFtest6Ef"}
%1:2 = hlfir.declare %0 {uniq_name = "_QFtest6Ef"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
%2 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFtest6Ei"}
%3:2 = hlfir.declare %2 {uniq_name = "_QFtest6Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%4:2 = hlfir.declare %arg1 {uniq_name = "_QFtest6Ei1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%5:2 = hlfir.declare %arg2 {uniq_name = "_QFtest6Ei2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
// Local variables n1, n2, n3 and n4 that provide the second-dimension bounds.
%6 = fir.alloca i32 {bindc_name = "n1", uniq_name = "_QFtest6En1"}
%7:2 = hlfir.declare %6 {uniq_name = "_QFtest6En1"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%8 = fir.alloca i32 {bindc_name = "n2", uniq_name = "_QFtest6En2"}
%9:2 = hlfir.declare %8 {uniq_name = "_QFtest6En2"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%10 = fir.alloca i32 {bindc_name = "n3", uniq_name = "_QFtest6En3"}
%11:2 = hlfir.declare %10 {uniq_name = "_QFtest6En3"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%12 = fir.alloca i32 {bindc_name = "n4", uniq_name = "_QFtest6En4"}
%13:2 = hlfir.declare %12 {uniq_name = "_QFtest6En4"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
%14:2 = hlfir.declare %arg3 {uniq_name = "_QFtest6Enx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
// Both extents of x are max(i2, 0).
%15 = fir.load %5#0 : !fir.ref<i32>
%16 = fir.convert %15 : (i32) -> index
%17 = arith.cmpi sgt, %16, %c0 : index
%18 = arith.select %17, %16, %c0 : index
%19 = fir.shape %18, %18 : (index, index) -> !fir.shape<2>
%20:2 = hlfir.declare %arg0(%19) {uniq_name = "_QFtest6Ex"} : (!fir.ref<!fir.array<?x?xf32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xf32>>, !fir.ref<!fir.array<?x?xf32>>)
// Loop bounds: from i1 (%22) to i2 (%24) with step nx (%26).
%21 = fir.load %4#0 : !fir.ref<i32>
%22 = fir.convert %21 : (i32) -> index
%23 = fir.load %5#0 : !fir.ref<i32>
%24 = fir.convert %23 : (i32) -> index
%25 = fir.load %14#0 : !fir.ref<i32>
%26 = fir.convert %25 : (i32) -> index
%27 = fir.convert %22 : (index) -> i32
%28:2 = fir.do_loop %arg4 = %22 to %24 step %26 iter_args(%arg5 = %27) -> (index, i32) {
fir.store %arg5 to %3#1 : !fir.ref<i32>
// %29 is the current value of i, %30 is nx.
%29 = fir.load %3#0 : !fir.ref<i32>
%30 = fir.load %14#0 : !fir.ref<i32>
// First dimension of the section being read: lb = i - nx (%33),
// ub = i - 1 (%34), extent = max(ub - lb + 1, 0) (%38).
%31 = arith.subi %29, %30 : i32
%32 = arith.subi %29, %c1_i32 : i32
%33 = fir.convert %31 : (i32) -> index
%34 = fir.convert %32 : (i32) -> index
%35 = arith.subi %34, %33 : index
%36 = arith.addi %35, %c1 : index
%37 = arith.cmpi sgt, %36, %c0 : index
%38 = arith.select %37, %36, %c0 : index
// Second dimension of the section being read: lb = n3 (%41),
// ub = n4 (%42), extent = max(ub - lb + 1, 0) (%46).
%39 = fir.load %11#0 : !fir.ref<i32>
%40 = fir.load %13#0 : !fir.ref<i32>
%41 = fir.convert %39 : (i32) -> index
%42 = fir.convert %40 : (i32) -> index
%43 = arith.subi %42, %41 : index
%44 = arith.addi %43, %c1 : index
%45 = arith.cmpi sgt, %44, %c0 : index
%46 = arith.select %45, %44, %c0 : index
%47 = fir.shape %38, %46 : (index, index) -> !fir.shape<2>
%48 = hlfir.designate %20#0 (%33:%34:%c1, %41:%42:%c1) shape %47 : (!fir.box<!fir.array<?x?xf32>>, index, index, index, index, index, index, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
// Element-wise copy of the read section into an expression value.
%49 = hlfir.elemental %47 unordered : (!fir.shape<2>) -> !hlfir.expr<?x?xf32> {
^bb0(%arg6: index, %arg7: index):
%72 = hlfir.designate %48 (%arg6, %arg7) : (!fir.box<!fir.array<?x?xf32>>, index, index) -> !fir.ref<f32>
%73 = fir.load %72 : !fir.ref<f32>
%74 = hlfir.no_reassoc %73 : f32
hlfir.yield_element %74 : f32
}
// First dimension of the section being written: lb = i (%52),
// ub = i + nx - 1 (%53), extent = max(ub - lb + 1, 0) (%57).
%50 = arith.addi %29, %30 : i32
%51 = arith.subi %50, %c1_i32 : i32
%52 = fir.convert %29 : (i32) -> index
%53 = fir.convert %51 : (i32) -> index
%54 = arith.subi %53, %52 : index
%55 = arith.addi %54, %c1 : index
%56 = arith.cmpi sgt, %55, %c0 : index
%57 = arith.select %56, %55, %c0 : index
// Second dimension of the section being written: lb = n1 (%60),
// ub = n2 (%61), extent = max(ub - lb + 1, 0) (%65).
%58 = fir.load %7#0 : !fir.ref<i32>
%59 = fir.load %9#0 : !fir.ref<i32>
%60 = fir.convert %58 : (i32) -> index
%61 = fir.convert %59 : (i32) -> index
%62 = arith.subi %61, %60 : index
%63 = arith.addi %62, %c1 : index
%64 = arith.cmpi sgt, %63, %c0 : index
%65 = arith.select %64, %63, %c0 : index
%66 = fir.shape %57, %65 : (index, index) -> !fir.shape<2>
%67 = hlfir.designate %20#0 (%52:%53:%c1, %60:%61:%c1) shape %66 : (!fir.box<!fir.array<?x?xf32>>, index, index, index, index, index, index, !fir.shape<2>) -> !fir.box<!fir.array<?x?xf32>>
hlfir.assign %49 to %67 : !hlfir.expr<?x?xf32>, !fir.box<!fir.array<?x?xf32>>
hlfir.destroy %49 : !hlfir.expr<?x?xf32>
// Advance the loop index and the value of i by the step nx.
%68 = arith.addi %arg4, %26 : index
%69 = fir.convert %26 : (index) -> i32
%70 = fir.load %3#1 : !fir.ref<i32>
%71 = arith.addi %70, %69 : i32
fir.result %68, %71 : index, i32
}
// Store the final value of the loop variable back to i.
fir.store %28#1 to %3#1 : !fir.ref<i32>
return
}
// CHECK-LABEL: func.func @_QPtest6(
// CHECK-NOT: hlfir.elemental