Skip to content

[RISCV] Compute integers once in isSimpleVIDSequence. NFCI #82590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 29 additions & 35 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3242,44 +3242,47 @@ static std::optional<uint64_t> getExactInteger(const APFloat &APF,
// determine whether this is worth generating code for.
static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
unsigned EltSizeInBits) {
unsigned NumElts = Op.getNumOperands();
assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unexpected BUILD_VECTOR");
if (!cast<BuildVectorSDNode>(Op)->isConstant())
return std::nullopt;
bool IsInteger = Op.getValueType().isInteger();

std::optional<unsigned> SeqStepDenom;
std::optional<int64_t> SeqStepNum, SeqAddend;
std::optional<std::pair<uint64_t, unsigned>> PrevElt;
assert(EltSizeInBits >= Op.getValueType().getScalarSizeInBits());
for (unsigned Idx = 0; Idx < NumElts; Idx++) {
// Assume undef elements match the sequence; we just have to be careful
// when interpolating across them.
if (Op.getOperand(Idx).isUndef())
continue;

uint64_t Val;
// First extract the ops into a list of constant integer values. This may not
// be possible for floats if they're not all representable as integers.
SmallVector<std::optional<uint64_t>> Elts(Op.getNumOperands());
const unsigned OpSize = Op.getScalarValueSizeInBits();
for (auto [Idx, Elt] : enumerate(Op->op_values())) {
if (Elt.isUndef()) {
Elts[Idx] = std::nullopt;
continue;
}
if (IsInteger) {
// The BUILD_VECTOR must be all constants.
if (!isa<ConstantSDNode>(Op.getOperand(Idx)))
return std::nullopt;
Val = Op.getConstantOperandVal(Idx) &
maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
Elts[Idx] = Elt->getAsZExtVal() & maskTrailingOnes<uint64_t>(OpSize);
} else {
// The BUILD_VECTOR must be all constants.
if (!isa<ConstantFPSDNode>(Op.getOperand(Idx)))
return std::nullopt;
if (auto ExactInteger = getExactInteger(
cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
Op.getScalarValueSizeInBits()))
Val = *ExactInteger;
else
auto ExactInteger =
getExactInteger(cast<ConstantFPSDNode>(Elt)->getValueAPF(), OpSize);
if (!ExactInteger)
return std::nullopt;
Elts[Idx] = *ExactInteger;
}
}

for (auto [Idx, Elt] : enumerate(Elts)) {
// Assume undef elements match the sequence; we just have to be careful
// when interpolating across them.
if (!Elt)
continue;

if (PrevElt) {
// Calculate the step since the last non-undef element, and ensure
// it's consistent across the entire sequence.
unsigned IdxDiff = Idx - PrevElt->second;
int64_t ValDiff = SignExtend64(Val - PrevElt->first, EltSizeInBits);
int64_t ValDiff = SignExtend64(*Elt - PrevElt->first, EltSizeInBits);

// A zero-value value difference means that we're somewhere in the middle
// of a fractional step, e.g. <0,0,0*,0,1,1,1,1>. Wait until we notice a
Expand Down Expand Up @@ -3309,8 +3312,8 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
}

// Record this non-undef element for later.
if (!PrevElt || PrevElt->first != Val)
PrevElt = std::make_pair(Val, Idx);
if (!PrevElt || PrevElt->first != *Elt)
PrevElt = std::make_pair(*Elt, Idx);
}

// We need to have logged a step for this to count as a legal index sequence.
Expand All @@ -3319,21 +3322,12 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,

// Loop back through the sequence and validate elements we might have skipped
// while waiting for a valid step. While doing this, log any sequence addend.
for (unsigned Idx = 0; Idx < NumElts; Idx++) {
if (Op.getOperand(Idx).isUndef())
for (auto [Idx, Elt] : enumerate(Elts)) {
if (!Elt)
continue;
uint64_t Val;
if (IsInteger) {
Val = Op.getConstantOperandVal(Idx) &
maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
} else {
Val = *getExactInteger(
cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
Op.getScalarValueSizeInBits());
}
uint64_t ExpectedVal =
(int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom;
int64_t Addend = SignExtend64(Val - ExpectedVal, EltSizeInBits);
int64_t Addend = SignExtend64(*Elt - ExpectedVal, EltSizeInBits);
if (!SeqAddend)
SeqAddend = Addend;
else if (Addend != SeqAddend)
Expand Down