Skip to content

Commit b3a14ca

Browse files
committed
Revert "[NVPTX] Make i16x2 a native type and add supported vec instructions (#65432)"
This reverts commit db5d845. As noted in the PR discussion: "Looks like we've missed lowering of bitcasts between v2f16 and v2i16 and it breaks XLA."
1 parent dc9a1f0 commit b3a14ca

16 files changed

+210
-798
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

+19-19
Original file line number | Diff line number | Diff line change
@@ -612,10 +612,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
612612
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
613613
SDValue Vector = N->getOperand(0);
614614

615-
// We only care about 16x2 as it's the only real vector type we
615+
// We only care about f16x2 as it's the only real vector type we
616616
// need to deal with.
617617
MVT VT = Vector.getSimpleValueType();
618-
if (!Isv2x16VT(VT))
618+
if (!(VT == MVT::v2f16 || VT == MVT::v2bf16))
619619
return false;
620620
// Find and record all uses of this vector that extract element 0 or 1.
621621
SmallVector<SDNode *, 4> E0, E1;
@@ -828,7 +828,6 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
828828
return Opcode_i16;
829829
case MVT::v2f16:
830830
case MVT::v2bf16:
831-
case MVT::v2i16:
832831
return Opcode_i32;
833832
case MVT::f32:
834833
return Opcode_f32;
@@ -910,8 +909,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
910909
// Vector Setting
911910
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
912911
if (SimpleVT.isVector()) {
913-
assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
914-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
912+
assert((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
913+
"Unexpected vector type");
914+
// v2f16/v2bf16 is loaded using ld.b32
915915
fromTypeWidth = 32;
916916
}
917917

@@ -1061,10 +1061,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
10611061

10621062
EVT EltVT = N->getValueType(0);
10631063

1064-
// v8x16 is a special case. PTX doesn't have ld.v8.16
1065-
// instruction. Instead, we split the vector into v2x16 chunks and
1064+
// v8f16 is a special case. PTX doesn't have ld.v8.f16
1065+
// instruction. Instead, we split the vector into v2f16 chunks and
10661066
// load them with ld.v4.b32.
1067-
if (Isv2x16VT(EltVT)) {
1067+
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
10681068
assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
10691069
EltVT = MVT::i32;
10701070
FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1260,13 +1260,12 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12601260
if (EltVT.isVector()) {
12611261
NumElts = EltVT.getVectorNumElements();
12621262
EltVT = EltVT.getVectorElementType();
1263-
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
1263+
// vectors of f16 are loaded/stored as multiples of v2f16 elements.
12641264
if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
1265-
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
1266-
(EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
1267-
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1268-
EltVT = N->getValueType(0);
1269-
NumElts /= 2;
1265+
(EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
1266+
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1267+
EltVT = N->getValueType(0);
1268+
NumElts /= 2;
12701269
}
12711270
}
12721271

@@ -1679,8 +1678,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
16791678
MVT ScalarVT = SimpleVT.getScalarType();
16801679
unsigned toTypeWidth = ScalarVT.getSizeInBits();
16811680
if (SimpleVT.isVector()) {
1682-
assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
1683-
// v2x16 is stored using st.b32
1681+
assert((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
1682+
"Unexpected vector type");
1683+
// v2f16 is stored using st.b32
16841684
toTypeWidth = 32;
16851685
}
16861686

@@ -1844,10 +1844,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
18441844
return false;
18451845
}
18461846

1847-
// v8x16 is a special case. PTX doesn't have st.v8.x16
1848-
// instruction. Instead, we split the vector into v2x16 chunks and
1847+
// v8f16 is a special case. PTX doesn't have st.v8.f16
1848+
// instruction. Instead, we split the vector into v2f16 chunks and
18491849
// store them with st.v4.b32.
1850-
if (Isv2x16VT(EltVT)) {
1850+
if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16) {
18511851
assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
18521852
EltVT = MVT::i32;
18531853
ToType = NVPTX::PTXLdStInstCode::Untyped;

0 commit comments

Comments
 (0)