Skip to content

Commit cb30169

Browse files
authored
[flang] Use LLVM dialect ops for stack save/restore in target-rewrite (#107879)
Mostly NFC, I was bothered by the declaration that were always made even if unsued, and I think using LLVM Ops is nicer anyway with regards to side effects here. ``` func.func private @llvm.stacksave.p0() -> !fir.ref<i8> func.func private @llvm.stackrestore.p0(!fir.ref<i8>) ``` There are other places in lowering that are using the calls instead of the LLVM intrinsics, but I will deal with them another time (the issue there is mostly to get the proper address space for the llvm.ptr type).
1 parent 306b08c commit cb30169

File tree

5 files changed

+25
-26
lines changed

5 files changed

+25
-26
lines changed

flang/include/flang/Optimizer/CodeGen/CGPasses.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def TargetRewritePass : Pass<"target-rewrite", "mlir::ModuleOp"> {
6868
representations that may differ based on the target machine.
6969
}];
7070
let dependentDialects = [ "fir::FIROpsDialect", "mlir::func::FuncDialect",
71-
"mlir::DLTIDialect" ];
71+
"mlir::DLTIDialect", "mlir::LLVM::LLVMDialect" ];
7272
let options = [
7373
Option<"forcedTargetTriple", "target", "std::string", /*default=*/"",
7474
"Override module's target triple.">,

flang/lib/Optimizer/CodeGen/TargetRewrite.cpp

+12-11
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
2828
#include "flang/Optimizer/Support/DataLayout.h"
2929
#include "mlir/Dialect/DLTI/DLTI.h"
30+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
3031
#include "mlir/Transforms/DialectConversion.h"
3132
#include "llvm/ADT/STLExtras.h"
3233
#include "llvm/ADT/TypeSwitch.h"
@@ -114,13 +115,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
114115

115116
setMembers(specifics.get(), &rewriter, &*dl);
116117

117-
// We may need to call stacksave/stackrestore later, so
118-
// create the FuncOps beforehand.
119-
fir::FirOpBuilder builder(rewriter, mod);
120-
builder.setInsertionPointToStart(mod.getBody());
121-
stackSaveFn = fir::factory::getLlvmStackSave(builder);
122-
stackRestoreFn = fir::factory::getLlvmStackRestore(builder);
123-
124118
// Perform type conversion on signatures and call sites.
125119
if (mlir::failed(convertTypes(mod))) {
126120
mlir::emitError(mlir::UnknownLoc::get(&context),
@@ -1242,22 +1236,29 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
12421236

12431237
inline void clearMembers() { setMembers(nullptr, nullptr, nullptr); }
12441238

1239+
uint64_t getAllocaAddressSpace() const {
1240+
if (dataLayout)
1241+
if (mlir::Attribute addrSpace = dataLayout->getAllocaMemorySpace())
1242+
return llvm::cast<mlir::IntegerAttr>(addrSpace).getUInt();
1243+
return 0;
1244+
}
1245+
12451246
// Inserts a call to llvm.stacksave at the current insertion
12461247
// point and the given location. Returns the call's result Value.
12471248
inline mlir::Value genStackSave(mlir::Location loc) {
1248-
return rewriter->create<fir::CallOp>(loc, stackSaveFn).getResult(0);
1249+
mlir::Type voidPtr = mlir::LLVM::LLVMPointerType::get(
1250+
rewriter->getContext(), getAllocaAddressSpace());
1251+
return rewriter->create<mlir::LLVM::StackSaveOp>(loc, voidPtr);
12491252
}
12501253

12511254
// Inserts a call to llvm.stackrestore at the current insertion
12521255
// point and the given location and argument.
12531256
inline void genStackRestore(mlir::Location loc, mlir::Value sp) {
1254-
rewriter->create<fir::CallOp>(loc, stackRestoreFn, mlir::ValueRange{sp});
1257+
rewriter->create<mlir::LLVM::StackRestoreOp>(loc, sp);
12551258
}
12561259

12571260
fir::CodeGenSpecifics *specifics = nullptr;
12581261
mlir::OpBuilder *rewriter = nullptr;
12591262
mlir::DataLayout *dataLayout = nullptr;
1260-
mlir::func::FuncOp stackSaveFn = nullptr;
1261-
mlir::func::FuncOp stackRestoreFn = nullptr;
12621263
};
12631264
} // namespace

flang/test/Fir/struct-passing-x86-64-one-field-inreg.fir

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ func.func @test_call_i16(%0 : !fir.ref<!fir.type<ti16{i:i16}>>) {
1313
// CHECK-LABEL: func.func @test_call_i16(
1414
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<ti16{i:i16}>>) {
1515
// CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti16{i:i16}>>
16-
// CHECK: %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
16+
// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
1717
// CHECK: %[[VAL_3:.*]] = fir.alloca i16
1818
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<i16>) -> !fir.ref<!fir.type<ti16{i:i16}>>
1919
// CHECK: fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti16{i:i16}>>
2020
// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_3]] : !fir.ref<i16>
2121
// CHECK: fir.call @test_func_i16(%[[VAL_5]]) : (i16) -> ()
22-
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
22+
// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
2323

2424
func.func private @test_func_i16(%0 : !fir.type<ti16{i:i16}>) -> () {
2525
return

flang/test/Fir/struct-passing-x86-64-several-fields-inreg.fir

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ func.func @test_call_i8_a16(%0 : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}
1414
// CHECK-LABEL: func.func @test_call_i8_a16(
1515
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>) {
1616
// CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
17-
// CHECK: %[[VAL_2:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
17+
// CHECK: %[[VAL_2:.*]] = llvm.intr.stacksave : !llvm.ptr
1818
// CHECK: %[[VAL_3:.*]] = fir.alloca tuple<i64, i64>
1919
// CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<tuple<i64, i64>>) -> !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
2020
// CHECK: fir.store %[[VAL_1]] to %[[VAL_4]] : !fir.ref<!fir.type<ti8_a16{a:!fir.array<16xi8>}>>
2121
// CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_3]] : !fir.ref<tuple<i64, i64>>
2222
// CHECK: %[[VAL_6:.*]] = fir.extract_value %[[VAL_5]], [0 : i32] : (tuple<i64, i64>) -> i64
2323
// CHECK: %[[VAL_7:.*]] = fir.extract_value %[[VAL_5]], [1 : i32] : (tuple<i64, i64>) -> i64
2424
// CHECK: fir.call @test_func_i8_a16(%[[VAL_6]], %[[VAL_7]]) : (i64, i64) -> ()
25-
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_2]]) : (!fir.ref<i8>) -> ()
25+
// CHECK: llvm.intr.stackrestore %[[VAL_2]] : !llvm.ptr
2626
// CHECK: return
2727

2828
func.func private @test_func_i8_a16(%0 : !fir.type<ti8_a16{a:!fir.array<16xi8>}>) -> () {

flang/test/Fir/target-rewrite-complex16.fir

+8-10
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,18 @@ func.func @addrof() {
6363
// CHECK: func.func private @paramcomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
6464

6565
// CHECK-LABEL: func.func @callcomplex16() {
66-
// CHECK: %[[VAL_0:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
66+
// CHECK: %[[VAL_0:.*]] = llvm.intr.stacksave : !llvm.ptr
6767
// CHECK: %[[VAL_1:.*]] = fir.alloca tuple<!fir.real<16>, !fir.real<16>>
6868
// CHECK: fir.call @returncomplex16(%[[VAL_1]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
6969
// CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
7070
// CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.complex<16>>
71-
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_0]]) : (!fir.ref<i8>) -> ()
72-
// CHECK: %[[VAL_4:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
71+
// CHECK: llvm.intr.stackrestore %[[VAL_0]] : !llvm.ptr
72+
// CHECK: %[[VAL_4:.*]] = llvm.intr.stacksave : !llvm.ptr
7373
// CHECK: %[[VAL_5:.*]] = fir.alloca !fir.complex<16>
7474
// CHECK: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<!fir.complex<16>>
7575
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_5]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
7676
// CHECK: fir.call @paramcomplex16(%[[VAL_6]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
77-
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_4]]) : (!fir.ref<i8>) -> ()
77+
// CHECK: llvm.intr.stackrestore %[[VAL_4]] : !llvm.ptr
7878
// CHECK: return
7979
// CHECK: }
8080
// CHECK: func.func private @calleemultipleparamscomplex16(!fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>}, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>> {llvm.align = 16 : i32, llvm.byval = tuple<!fir.real<16>, !fir.real<16>>})
@@ -87,7 +87,7 @@ func.func @addrof() {
8787
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<!fir.complex<16>>
8888
// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> !fir.ref<!fir.complex<16>>
8989
// CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_7]] : !fir.ref<!fir.complex<16>>
90-
// CHECK: %[[VAL_9:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
90+
// CHECK: %[[VAL_9:.*]] = llvm.intr.stacksave : !llvm.ptr
9191
// CHECK: %[[VAL_10:.*]] = fir.alloca !fir.complex<16>
9292
// CHECK: fir.store %[[VAL_8]] to %[[VAL_10]] : !fir.ref<!fir.complex<16>>
9393
// CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
@@ -98,7 +98,7 @@ func.func @addrof() {
9898
// CHECK: fir.store %[[VAL_4]] to %[[VAL_14]] : !fir.ref<!fir.complex<16>>
9999
// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_14]] : (!fir.ref<!fir.complex<16>>) -> !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>
100100
// CHECK: fir.call @calleemultipleparamscomplex16(%[[VAL_11]], %[[VAL_13]], %[[VAL_15]]) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>, !fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
101-
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_9]]) : (!fir.ref<i8>) -> ()
101+
// CHECK: llvm.intr.stackrestore %[[VAL_9]] : !llvm.ptr
102102
// CHECK: return
103103
// CHECK: }
104104

@@ -108,7 +108,7 @@ func.func @addrof() {
108108
// CHECK: %[[VAL_4:.*]] = fir.load %[[VAL_3]] : !fir.ref<complex<f128>>
109109
// CHECK: %[[VAL_5:.*]] = fir.convert %[[VAL_1]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
110110
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<complex<f128>>
111-
// CHECK: %[[VAL_7:.*]] = fir.call @llvm.stacksave.p0() : () -> !fir.ref<i8>
111+
// CHECK: %[[VAL_7:.*]] = llvm.intr.stacksave : !llvm.ptr
112112
// CHECK: %[[VAL_8:.*]] = fir.alloca tuple<f128, f128>
113113
// CHECK: %[[VAL_9:.*]] = fir.alloca complex<f128>
114114
// CHECK: fir.store %[[VAL_6]] to %[[VAL_9]] : !fir.ref<complex<f128>>
@@ -119,7 +119,7 @@ func.func @addrof() {
119119
// CHECK: fir.call @mlircomplexf128(%[[VAL_8]], %[[VAL_10]], %[[VAL_12]]) : (!fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>, !fir.ref<tuple<f128, f128>>) -> ()
120120
// CHECK: %[[VAL_13:.*]] = fir.convert %[[VAL_8]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
121121
// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_13]] : !fir.ref<complex<f128>>
122-
// CHECK: fir.call @llvm.stackrestore.p0(%[[VAL_7]]) : (!fir.ref<i8>) -> ()
122+
// CHECK: llvm.intr.stackrestore %[[VAL_7]] : !llvm.ptr
123123
// CHECK: %[[VAL_15:.*]] = fir.convert %[[VAL_0]] : (!fir.ref<tuple<f128, f128>>) -> !fir.ref<complex<f128>>
124124
// CHECK: fir.store %[[VAL_14]] to %[[VAL_15]] : !fir.ref<complex<f128>>
125125
// CHECK: return
@@ -130,5 +130,3 @@ func.func @addrof() {
130130
// CHECK: %[[VAL_1:.*]] = fir.address_of(@paramcomplex16) : (!fir.ref<tuple<!fir.real<16>, !fir.real<16>>>) -> ()
131131
// CHECK: return
132132
// CHECK: }
133-
// CHECK: func.func private @llvm.stacksave.p0() -> !fir.ref<i8>
134-
// CHECK: func.func private @llvm.stackrestore.p0(!fir.ref<i8>)

0 commit comments

Comments
 (0)