Skip to content

Commit e712f1e

Browse files
committed
[NVPTX] Make i16x2 a native type and add supported vec instructions (#65432)
commit again b3a14ca.
1 parent 53b3be7 commit e712f1e

16 files changed

+798
-210
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

+19-19
Original file line numberDiff line numberDiff line change
@@ -612,10 +612,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
612612
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
613613
SDValue Vector = N->getOperand(0);
614614

615-
// We only care about f16x2 as it's the only real vector type we
615+
// We only care about 16x2 as it's the only real vector type we
616616
// need to deal with.
617617
MVT VT = Vector.getSimpleValueType();
618-
if (!(VT == MVT::v2f16 || VT == MVT::v2bf16))
618+
if (!Isv2x16VT(VT))
619619
return false;
620620
// Find and record all uses of this vector that extract element 0 or 1.
621621
SmallVector<SDNode *, 4> E0, E1;
@@ -828,6 +828,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
828828
return Opcode_i16;
829829
case MVT::v2f16:
830830
case MVT::v2bf16:
831+
case MVT::v2i16:
831832
return Opcode_i32;
832833
case MVT::f32:
833834
return Opcode_f32;
@@ -909,9 +910,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
909910
// Vector Setting
910911
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
911912
if (SimpleVT.isVector()) {
912-
assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
913-
"Unexpected vector type");
914-
// v2f16/v2bf16 is loaded using ld.b32
913+
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
914+
// v2f16/v2bf16/v2i16 is loaded using ld.b32
915915
fromTypeWidth = 32;
916916
}
917917

@@ -1061,10 +1061,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
10611061

10621062
EVT EltVT = N->getValueType(0);
10631063

1064-
// v8f16 is a special case. PTX doesn't have ld.v8.f16
1065-
// instruction. Instead, we split the vector into v2f16 chunks and
1064+
// v8x16 is a special case. PTX doesn't have ld.v8.16
1065+
// instruction. Instead, we split the vector into v2x16 chunks and
10661066
// load them with ld.v4.b32.
1067-
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
1067+
if (Isv2x16VT(EltVT)) {
10681068
assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
10691069
EltVT = MVT::i32;
10701070
FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1260,12 +1260,13 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12601260
if (EltVT.isVector()) {
12611261
NumElts = EltVT.getVectorNumElements();
12621262
EltVT = EltVT.getVectorElementType();
1263-
// vectors of f16 are loaded/stored as multiples of v2f16 elements.
1263+
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
12641264
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
1265-
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
1266-
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1267-
EltVT = N->getValueType(0);
1268-
NumElts /= 2;
1265+
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
1266+
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
1267+
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1268+
EltVT = N->getValueType(0);
1269+
NumElts /= 2;
12691270
}
12701271
}
12711272

@@ -1678,9 +1679,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16781679
MVT ScalarVT = SimpleVT.getScalarType();
16791680
unsigned toTypeWidth = ScalarVT.getSizeInBits();
16801681
if (SimpleVT.isVector()) {
1681-
assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
1682-
"Unexpected vector type");
1683-
// v2f16 is stored using st.b32
1682+
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
1683+
// v2x16 is stored using st.b32
16841684
toTypeWidth = 32;
16851685
}
16861686

@@ -1844,10 +1844,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
18441844
return false;
18451845
}
18461846

1847-
// v8f16 is a special case. PTX doesn't have st.v8.f16
1848-
// instruction. Instead, we split the vector into v2f16 chunks and
1847+
// v8x16 is a special case. PTX doesn't have st.v8.x16
1848+
// instruction. Instead, we split the vector into v2x16 chunks and
18491849
// store them with st.v4.b32.
1850-
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
1850+
if (Isv2x16VT(EltVT)) {
18511851
assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
18521852
EltVT = MVT::i32;
18531853
ToType = NVPTX::PTXLdStInstCode::Untyped;

0 commit comments

Comments
 (0)