-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[NVPTX] support switch statement with brx.idx #102400
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
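@llvm/pr-subscribers-llvm-selectiondag @llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean). Changes: Add custom lowering for `BR_JT` DAG nodes to the `brx.idx` PTX instruction. Depending on the heuristics in DAG selection, `switch` statements may now be lowered using `brx.idx`. Full diff: https://github.com/llvm/llvm-project/pull/102400.diff (6 files affected):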
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9ccdbab008aec8..5b2214fa66c40b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase {
/// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
virtual unsigned getJumpTableEncoding() const;
+ /// Return the value type of the register that carries the jump table index.
+ /// SelectionDAGBuilder uses this type for the index register it creates in
+ /// visitJumpTableHeader and for the BR_JT operands in visitJumpTable.
+ /// Defaults to the pointer type; targets whose indirect-branch instruction
+ /// needs a different index width can override it.
+ virtual MVT getJumpTableRegTy(const DataLayout &DL) const {
+ return getPointerTy(DL);
+ }
+
virtual const MCExpr *
LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/,
const MachineBasicBlock * /*MBB*/, unsigned /*uid*/,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9d617c7acd13c2..192fbf74b02dc0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
// Emit the code for the jump table
assert(JT.SL && "Should set SDLoc for SelectionDAG!");
assert(JT.Reg != -1U && "Should lower JT Header first!");
- EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
+ EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout());
SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy);
SDValue Table = DAG.getJumpTable(JT.JTI, PTy);
SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other,
@@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
// This value may be smaller or larger than the target's pointer type, and
// therefore require extension or truncating.
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout()));
+ SwitchOp =
+ DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout()));
unsigned JumpTableReg =
- FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout()));
- SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl,
- JumpTableReg, SwitchOp);
+ FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout()));
+ SDValue CopyTo =
+ DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp);
JT.Reg = JumpTableReg;
if (!JTH.FallthroughUnreachable) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 516fc7339a4bf3..bf647c88f00e28 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -25,6 +25,7 @@
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::ROTR, MVT::i8, Expand);
setOperationAction(ISD::BSWAP, MVT::i16, Expand);
- // Indirect branch is not supported.
- // This also disables Jump Table creation.
- setOperationAction(ISD::BR_JT, MVT::Other, Expand);
+ setOperationAction(ISD::BR_JT, MVT::Other, Custom);
setOperationAction(ISD::BRIND, MVT::Other, Expand);
setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
@@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::Dummy)
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
+ MAKE_CASE(NVPTXISD::BrxEnd)
+ MAKE_CASE(NVPTXISD::BrxItem)
+ MAKE_CASE(NVPTXISD::BrxStart)
MAKE_CASE(NVPTXISD::Tex1DFloatS32)
MAKE_CASE(NVPTXISD::Tex1DFloatFloat)
MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel)
@@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerFP_ROUND(Op, DAG);
case ISD::FP_EXTEND:
return LowerFP_EXTEND(Op, DAG);
+ case ISD::BR_JT:
+ return LowerBR_JT(Op, DAG);
case ISD::VAARG:
return LowerVAARG(Op, DAG);
case ISD::VASTART:
@@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
}
}
+// Custom lowering for ISD::BR_JT, whose operands are (chain, jump table,
+// index). The branch is rewritten into a chained and glued node sequence
+//   BrxStart(id) -> BrxItem(bb)... -> BrxEnd(last bb, index, id)
+// that instruction selection turns into a `.branchtargets` label list
+// followed by a `brx.idx` branch through that list.
+SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc Loc(Op);
+  SDValue Link = Op.getOperand(0);
+  const auto *TableNode = cast<JumpTableSDNode>(Op.getOperand(1));
+  SDValue Selector = Op.getOperand(2);
+
+  // Fetch the destination blocks recorded for this jump table.
+  unsigned TableId = TableNode->getIndex();
+  MachineJumpTableInfo *TableInfo =
+      DAG.getMachineFunction().getJumpTableInfo();
+  ArrayRef<MachineBasicBlock *> Targets =
+      TableInfo->getJumpTables()[TableId].MBBs;
+
+  // The table id is an operand of both the opening and the closing node.
+  SDValue TableIdOp = DAG.getConstant(TableId, Loc, MVT::i32);
+
+  // Every node in the sequence produces a chain (result 0) and glue (result 1).
+  SDVTList ChainAndGlue = DAG.getVTList(MVT::Other, MVT::Glue);
+
+  // Open the target list.
+  SDValue StartOps[] = {Link, TableIdOp};
+  Link = DAG.getNode(NVPTXISD::BrxStart, Loc, ChainAndGlue, StartOps);
+
+  // Emit one BrxItem for each target except the last; the last target is
+  // carried by BrxEnd instead.
+  assert(!Targets.empty());
+  for (MachineBasicBlock *Target : Targets.drop_back()) {
+    SDValue ItemOps[] = {Link.getValue(0), DAG.getBasicBlock(Target),
+                         Link.getValue(1)};
+    Link = DAG.getNode(NVPTXISD::BrxItem, Loc, ChainAndGlue, ItemOps);
+  }
+
+  // Close the list with the final target and perform the indexed branch.
+  SDValue EndOps[] = {Link.getValue(0), DAG.getBasicBlock(Targets.back()),
+                      Selector, TableIdOp, Link.getValue(1)};
+  return DAG.getNode(NVPTXISD::BrxEnd, Loc, ChainAndGlue, EndOps);
+}
+
+// Report every jump table as EK_Inline. This will prevent AsmPrinter from
+// trying to print the jump tables itself; the targets are instead printed as
+// the `.branchtargets` list that LowerBR_JT builds out of
+// BrxStart/BrxItem/BrxEnd nodes.
+unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
+ return MachineJumpTableInfo::EK_Inline;
+}
+
// This function is almost a copy of SelectionDAG::expandVAArg().
// The only diff is that this one produces loads from local address space.
SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 63262961b363ed..32e6b044b0de1f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -62,6 +62,9 @@ enum NodeType : unsigned {
BFI,
PRMT,
DYNAMIC_STACKALLOC,
+ BrxStart,
+ BrxItem,
+ BrxEnd,
Dummy,
LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
@@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering {
return true;
}
+ // Type of the register that holds the jump table index. The default is the
+ // same as the pointer type, but brx.idx only accepts i32, so always use a
+ // 32-bit index here.
+ MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }
+
+ unsigned getJumpTableEncoding() const override;
+
bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
// The default is to transform llvm.ctlz(x, false) (where false indicates that
@@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
+
SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 6a096fa5acea7c..cec7f20255d352 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3880,6 +3880,44 @@ def DYNAMIC_STACKALLOC64 :
[(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>,
Requires<[hasPTX<73>, hasSM<52>]>;
+
+//
+// BRX
+//
+
+// Operand profiles for the three BRX nodes (none of them produces a value):
+//   BrxStart: (table id)
+//   BrxItem:  (target basic block)
+//   BrxEnd:   (last target basic block, index value, table id)
+def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>;
+def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>, SDTCisInt<1>, SDTCisInt<2>]>;
+
+// The nodes are chained and linked by glue in the order
+// brx_start -> brx_item... -> brx_end: brx_start only produces glue,
+// brx_item consumes and produces it, and brx_end only consumes it.
+def brx_start :
+ SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile,
+ [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
+def brx_item :
+ SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile,
+ [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def brx_end :
+ SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile,
+ [SDNPHasChain, SDNPInGlue, SDNPSideEffect]>;
+
+// BRX_START, zero or more BRX_ITEMs and BRX_END together print a single
+// `.branchtargets` label list followed by the `brx.idx` instruction that
+// branches through it.
+let isTerminator = 1, isBranch = 1, isIndirectBranch = 1 in {
+
+  // Opens the list: prints the `$L_brx_<id>` label and the `.branchtargets`
+  // directive.
+  def BRX_START :
+    NVPTXInst<(outs), (ins i32imm:$id),
+              "$$L_brx_$id: .branchtargets",
+              [(brx_start (i32 imm:$id))]>;
+
+  // Prints one non-final target label followed by a comma.
+  def BRX_ITEM :
+    NVPTXInst<(outs), (ins brtarget:$target),
+              "$target,",
+              [(brx_item bb:$target)]>;
+
+  // Prints the final target label, terminates the list, and emits the
+  // `brx.idx` that jumps to entry $val of list `$L_brx_<id>`.
+  def BRX_END :
+    NVPTXInst<(outs), (ins brtarget:$target, Int32Regs:$val, i32imm:$id),
+              "$target;\n\tbrx.idx \t$val, $$L_brx_$id;",
+              [(brx_end bb:$target, (i32 Int32Regs:$val), (i32 imm:$id))]> {
+    // brx.idx always transfers control, so execution never continues with
+    // the instruction that follows it. Mark it as a barrier so that later
+    // machine passes do not assume a fall-through edge out of this block;
+    // omitting the flag made LLVM crash on some tests.
+    let isBarrier = 1;
+  }
+}
+
+
include "NVPTXIntrinsics.td"
//-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/jump-table.ll b/llvm/test/CodeGen/NVPTX/jump-table.ll
new file mode 100644
index 00000000000000..8dd4115e2feb63
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/jump-table.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+@out = addrspace(1) global i32 0, align 4
+
+define void @foo(i32 %i) {
+; CHECK-LABEL: foo(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.u32 %r2, [foo_param_0];
+; CHECK-NEXT: setp.gt.u32 %p1, %r2, 3;
+; CHECK-NEXT: @%p1 bra $L__BB0_6;
+; CHECK-NEXT: // %bb.1: // %entry
+; CHECK-NEXT: $L_brx_0: .branchtargets
+; CHECK-NEXT: $L__BB0_2,
+; CHECK-NEXT: $L__BB0_3,
+; CHECK-NEXT: $L__BB0_4,
+; CHECK-NEXT: $L__BB0_5;
+; CHECK-NEXT: brx.idx %r2, $L_brx_0;
+; CHECK-NEXT: $L__BB0_2: // %case0
+; CHECK-NEXT: mov.b32 %r6, 0;
+; CHECK-NEXT: st.global.u32 [out], %r6;
+; CHECK-NEXT: bra.uni $L__BB0_6;
+; CHECK-NEXT: $L__BB0_4: // %case2
+; CHECK-NEXT: mov.b32 %r4, 2;
+; CHECK-NEXT: st.global.u32 [out], %r4;
+; CHECK-NEXT: bra.uni $L__BB0_6;
+; CHECK-NEXT: $L__BB0_5: // %case3
+; CHECK-NEXT: mov.b32 %r3, 3;
+; CHECK-NEXT: st.global.u32 [out], %r3;
+; CHECK-NEXT: bra.uni $L__BB0_6;
+; CHECK-NEXT: $L__BB0_3: // %case1
+; CHECK-NEXT: mov.b32 %r5, 1;
+; CHECK-NEXT: st.global.u32 [out], %r5;
+; CHECK-NEXT: $L__BB0_6: // %end
+; CHECK-NEXT: ret;
+entry:
+ switch i32 %i, label %end [
+ i32 0, label %case0
+ i32 1, label %case1
+ i32 2, label %case2
+ i32 3, label %case3
+ ]
+
+case0:
+ store i32 0, ptr addrspace(1) @out, align 4
+ br label %end
+
+case1:
+ store i32 1, ptr addrspace(1) @out, align 4
+ br label %end
+
+case2:
+ store i32 2, ptr addrspace(1) @out, align 4
+ br label %end
+
+case3:
+ store i32 3, ptr addrspace(1) @out, align 4
+ br label %end
+
+end:
+ ret void
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Does brx.idx buy us a better code on SASS level? or is this mostly cosmetic sugar on PTX level?
|
||
def BRX_ITEM : | ||
NVPTXInst<(outs), (ins brtarget:$target), | ||
"$target,", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: We may want to indent the labels in the list. Right now they end up aligned with the instructions, even though they are actually arguments of the `.branchtargets` directive above. This does not impact functionality, but it looks somewhat odd.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
This does lead to improvements on the SASS level in some cases, though this is of course very dependent on |
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/180/builds/2997. Here is the relevant piece of the build log for reference:
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/160/builds/2999. Here is the relevant piece of the build log for reference:
|
Reverts #102400 because it causes LLVM to crash on some tests.
Add custom lowering for `BR_JT` DAG nodes to the `brx.idx` PTX instruction ([PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx](https://docs.nvidia.com/cuda/parallel-thread-execution/#control-flow-instructions-brx-idx)). Depending on the heuristics in DAG selection, `switch` statements may now be lowered using `brx.idx`.
Add custom lowering for `BR_JT` DAG nodes to the `brx.idx` PTX instruction ([PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx](https://docs.nvidia.com/cuda/parallel-thread-execution/#control-flow-instructions-brx-idx)). Depending on the heuristics in DAG selection, `switch` statements may now be lowered using `brx.idx`. Note: this fixes the previous issue in #102400 by adding the `isBarrier` attribute to `BRX_END`.
…af56bc521 Local branch amd-gfx 77eaf56 Merged main:4fe33d067c5d0894d0059418f09edc531f16ac9f into amd-gfx:5fa38fbc60f8 Remote branch main ba97697 [NVPTX] support switch statement with brx.idx (llvm#102400)
Add custom lowering for `BR_JT` DAG nodes to the `brx.idx` PTX instruction (PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx). Depending on the heuristics in DAG selection, `switch` statements may now be lowered using `brx.idx`.