@@ -612,10 +612,10 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
612
612
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT (SDNode *N) {
613
613
SDValue Vector = N->getOperand (0 );
614
614
615
- // We only care about 16x2 as it's the only real vector type we
615
+ // We only care about f16x2 as it's the only real vector type we
616
616
// need to deal with.
617
617
MVT VT = Vector.getSimpleValueType ();
618
- if (!Isv2x16VT (VT))
618
+ if (!(VT == MVT::v2f16 || VT == MVT::v2bf16 ))
619
619
return false ;
620
620
// Find and record all uses of this vector that extract element 0 or 1.
621
621
SmallVector<SDNode *, 4 > E0 , E1 ;
@@ -828,7 +828,6 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
828
828
return Opcode_i16;
829
829
case MVT::v2f16:
830
830
case MVT::v2bf16:
831
- case MVT::v2i16:
832
831
return Opcode_i32;
833
832
case MVT::f32:
834
833
return Opcode_f32;
@@ -910,8 +909,9 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
910
909
// Vector Setting
911
910
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
912
911
if (SimpleVT.isVector ()) {
913
- assert (Isv2x16VT (LoadedVT) && " Unexpected vector type" );
914
- // v2f16/v2bf16/v2i16 is loaded using ld.b32
912
+ assert ((LoadedVT == MVT::v2f16 || LoadedVT == MVT::v2bf16) &&
913
+ " Unexpected vector type" );
914
+ // v2f16/v2bf16 is loaded using ld.b32
915
915
fromTypeWidth = 32 ;
916
916
}
917
917
@@ -1061,10 +1061,10 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
1061
1061
1062
1062
EVT EltVT = N->getValueType (0 );
1063
1063
1064
- // v8x16 is a special case. PTX doesn't have ld.v8.16
1065
- // instruction. Instead, we split the vector into v2x16 chunks and
1064
+ // v8f16 is a special case. PTX doesn't have ld.v8.f16
1065
+ // instruction. Instead, we split the vector into v2f16 chunks and
1066
1066
// load them with ld.v4.b32.
1067
- if (Isv2x16VT ( EltVT) ) {
1067
+ if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
1068
1068
assert (N->getOpcode () == NVPTXISD::LoadV4 && " Unexpected load opcode." );
1069
1069
EltVT = MVT::i32;
1070
1070
FromType = NVPTX::PTXLdStInstCode::Untyped;
@@ -1260,13 +1260,12 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
1260
1260
if (EltVT.isVector ()) {
1261
1261
NumElts = EltVT.getVectorNumElements ();
1262
1262
EltVT = EltVT.getVectorElementType ();
1263
- // vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
1263
+ // vectors of f16 are loaded/stored as multiples of v2f16 elements.
1264
1264
if ((EltVT == MVT::f16 && N->getValueType (0 ) == MVT::v2f16) ||
1265
- (EltVT == MVT::bf16 && N->getValueType (0 ) == MVT::v2bf16) ||
1266
- (EltVT == MVT::i16 && N->getValueType (0 ) == MVT::v2i16)) {
1267
- assert (NumElts % 2 == 0 && " Vector must have even number of elements" );
1268
- EltVT = N->getValueType (0 );
1269
- NumElts /= 2 ;
1265
+ (EltVT == MVT::bf16 && N->getValueType (0 ) == MVT::v2bf16)) {
1266
+ assert (NumElts % 2 == 0 && " Vector must have even number of elements" );
1267
+ EltVT = N->getValueType (0 );
1268
+ NumElts /= 2 ;
1270
1269
}
1271
1270
}
1272
1271
@@ -1679,8 +1678,9 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
1679
1678
MVT ScalarVT = SimpleVT.getScalarType ();
1680
1679
unsigned toTypeWidth = ScalarVT.getSizeInBits ();
1681
1680
if (SimpleVT.isVector ()) {
1682
- assert (Isv2x16VT (StoreVT) && " Unexpected vector type" );
1683
- // v2x16 is stored using st.b32
1681
+ assert ((StoreVT == MVT::v2f16 || StoreVT == MVT::v2bf16) &&
1682
+ " Unexpected vector type" );
1683
+ // v2f16 is stored using st.b32
1684
1684
toTypeWidth = 32 ;
1685
1685
}
1686
1686
@@ -1844,10 +1844,10 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
1844
1844
return false ;
1845
1845
}
1846
1846
1847
- // v8x16 is a special case. PTX doesn't have st.v8.x16
1848
- // instruction. Instead, we split the vector into v2x16 chunks and
1847
+ // v8f16 is a special case. PTX doesn't have st.v8.f16
1848
+ // instruction. Instead, we split the vector into v2f16 chunks and
1849
1849
// store them with st.v4.b32.
1850
- if (Isv2x16VT ( EltVT) ) {
1850
+ if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16 ) {
1851
1851
assert (N->getOpcode () == NVPTXISD::StoreV4 && " Unexpected load opcode." );
1852
1852
EltVT = MVT::i32;
1853
1853
ToType = NVPTX::PTXLdStInstCode::Untyped;
0 commit comments