Skip to content

Commit 5ee7dc0

Browse files
authored
[RISCV] Match gather(splat(ptr)) as zero strided load (#65769)
We were already handling the case where the broadcast was being done via a GEP, but we hadn't handled the case of a broadcast via a shuffle.
1 parent d55ac38 commit 5ee7dc0

File tree

2 files changed

+23
-54
lines changed

2 files changed

+23
-54
lines changed

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

+20-10
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class RISCVGatherScatterLowering : public FunctionPass {
6767
bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
6868
Value *AlignOp);
6969

70-
std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
70+
std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
7171
IRBuilderBase &Builder);
7272

7373
bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
@@ -321,9 +321,19 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
321321
}
322322

323323
std::pair<Value *, Value *>
324-
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
324+
RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
325325
IRBuilderBase &Builder) {
326326

327+
// A gather/scatter of a splat is a zero strided load/store.
328+
if (auto *BasePtr = getSplatValue(Ptr)) {
329+
Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
330+
return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
331+
}
332+
333+
auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
334+
if (!GEP)
335+
return std::make_pair(nullptr, nullptr);
336+
327337
auto I = StridedAddrs.find(GEP);
328338
if (I != StridedAddrs.end())
329339
return I->second;
@@ -452,17 +462,17 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
452462
if (!TLI->isTypeLegal(DataTypeVT))
453463
return false;
454464

455-
// Pointer should be a GEP.
456-
auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
457-
if (!GEP)
465+
// Pointer should be an instruction.
466+
auto *PtrI = dyn_cast<Instruction>(Ptr);
467+
if (!PtrI)
458468
return false;
459469

460-
LLVMContext &Ctx = GEP->getContext();
470+
LLVMContext &Ctx = PtrI->getContext();
461471
IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
462-
Builder.SetInsertPoint(GEP);
472+
Builder.SetInsertPoint(PtrI);
463473

464474
Value *BasePtr, *Stride;
465-
std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
475+
std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
466476
if (!BasePtr)
467477
return false;
468478
assert(Stride != nullptr);
@@ -485,8 +495,8 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
485495
II->replaceAllUsesWith(Call);
486496
II->eraseFromParent();
487497

488-
if (GEP->use_empty())
489-
RecursivelyDeleteTriviallyDeadInstructions(GEP);
498+
if (PtrI->use_empty())
499+
RecursivelyDeleteTriviallyDeadInstructions(PtrI);
490500

491501
return true;
492502
}

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll

+3-44
Original file line numberDiff line numberDiff line change
@@ -12918,60 +12918,19 @@ define <4 x i32> @mgather_broadcast_load_unmasked2(ptr %base) {
1291812918
; RV32-LABEL: mgather_broadcast_load_unmasked2:
1291912919
; RV32: # %bb.0:
1292012920
; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
12921-
; RV32-NEXT: vmv.v.x v8, a0
12922-
; RV32-NEXT: vluxei32.v v8, (zero), v8
12921+
; RV32-NEXT: vlse32.v v8, (a0), zero
1292312922
; RV32-NEXT: ret
1292412923
;
1292512924
; RV64V-LABEL: mgather_broadcast_load_unmasked2:
1292612925
; RV64V: # %bb.0:
12927-
; RV64V-NEXT: vsetivli zero, 4, e64, m2, ta, ma
12928-
; RV64V-NEXT: vmv.v.x v10, a0
12929-
; RV64V-NEXT: vsetvli zero, zero, e32, m1, ta, ma
12930-
; RV64V-NEXT: vluxei64.v v8, (zero), v10
12926+
; RV64V-NEXT: vsetivli zero, 4, e32, m1, ta, ma
12927+
; RV64V-NEXT: vlse32.v v8, (a0), zero
1293112928
; RV64V-NEXT: ret
1293212929
;
1293312930
; RV64ZVE32F-LABEL: mgather_broadcast_load_unmasked2:
1293412931
; RV64ZVE32F: # %bb.0:
12935-
; RV64ZVE32F-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
12936-
; RV64ZVE32F-NEXT: vmset.m v8
12937-
; RV64ZVE32F-NEXT: vmv.x.s a1, v8
12938-
; RV64ZVE32F-NEXT: # implicit-def: $v8
12939-
; RV64ZVE32F-NEXT: beqz zero, .LBB100_5
12940-
; RV64ZVE32F-NEXT: # %bb.1: # %else
12941-
; RV64ZVE32F-NEXT: andi a2, a1, 2
12942-
; RV64ZVE32F-NEXT: bnez a2, .LBB100_6
12943-
; RV64ZVE32F-NEXT: .LBB100_2: # %else2
12944-
; RV64ZVE32F-NEXT: andi a2, a1, 4
12945-
; RV64ZVE32F-NEXT: bnez a2, .LBB100_7
12946-
; RV64ZVE32F-NEXT: .LBB100_3: # %else5
12947-
; RV64ZVE32F-NEXT: andi a1, a1, 8
12948-
; RV64ZVE32F-NEXT: bnez a1, .LBB100_8
12949-
; RV64ZVE32F-NEXT: .LBB100_4: # %else8
12950-
; RV64ZVE32F-NEXT: ret
12951-
; RV64ZVE32F-NEXT: .LBB100_5: # %cond.load
1295212932
; RV64ZVE32F-NEXT: vsetivli zero, 4, e32, m1, ta, ma
1295312933
; RV64ZVE32F-NEXT: vlse32.v v8, (a0), zero
12954-
; RV64ZVE32F-NEXT: andi a2, a1, 2
12955-
; RV64ZVE32F-NEXT: beqz a2, .LBB100_2
12956-
; RV64ZVE32F-NEXT: .LBB100_6: # %cond.load1
12957-
; RV64ZVE32F-NEXT: lw a2, 0(a0)
12958-
; RV64ZVE32F-NEXT: vsetivli zero, 2, e32, m1, tu, ma
12959-
; RV64ZVE32F-NEXT: vmv.s.x v9, a2
12960-
; RV64ZVE32F-NEXT: vslideup.vi v8, v9, 1
12961-
; RV64ZVE32F-NEXT: andi a2, a1, 4
12962-
; RV64ZVE32F-NEXT: beqz a2, .LBB100_3
12963-
; RV64ZVE32F-NEXT: .LBB100_7: # %cond.load4
12964-
; RV64ZVE32F-NEXT: lw a2, 0(a0)
12965-
; RV64ZVE32F-NEXT: vsetivli zero, 3, e32, m1, tu, ma
12966-
; RV64ZVE32F-NEXT: vmv.s.x v9, a2
12967-
; RV64ZVE32F-NEXT: vslideup.vi v8, v9, 2
12968-
; RV64ZVE32F-NEXT: andi a1, a1, 8
12969-
; RV64ZVE32F-NEXT: beqz a1, .LBB100_4
12970-
; RV64ZVE32F-NEXT: .LBB100_8: # %cond.load7
12971-
; RV64ZVE32F-NEXT: lw a0, 0(a0)
12972-
; RV64ZVE32F-NEXT: vsetivli zero, 4, e32, m1, ta, ma
12973-
; RV64ZVE32F-NEXT: vmv.s.x v9, a0
12974-
; RV64ZVE32F-NEXT: vslideup.vi v8, v9, 3
1297512934
; RV64ZVE32F-NEXT: ret
1297612935
%head = insertelement <4 x i1> poison, i1 true, i32 0
1297712936
%allones = shufflevector <4 x i1> %head, <4 x i1> poison, <4 x i32> zeroinitializer

0 commit comments

Comments
 (0)