Skip to content

[RISCV] Compute integers once in isSimpleVIDSequence. NFCI #82590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 22, 2024

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Feb 22, 2024

We need to iterate through the integers twice in isSimpleVIDSequence, so instead of computing them twice just compute them once at the start.

This also replaces the individual checks that each element is constant with a single call to BuildVectorSDNode::isConstant.

We need to iterate through the integers twice in isSimpleVIDSequence, so
instead of computing them twice just compute them once at the start.

This also replaces the individual checks that each element is constant with a
single call to BuildVectorSDNode::isConstant.
@llvmbot
Copy link
Member

llvmbot commented Feb 22, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

We need to iterate through the integers twice in isSimpleVIDSequence, so
instead of computing them twice just compute them once at the start.

This also replaces the individual checks that each element is constant with a
single call to BuildVectorSDNode::isConstant.


Full diff: https://github.com/llvm/llvm-project/pull/82590.diff

1 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+29-35)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 75be97ff32bbe5..cf0dc36a51b61b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -3242,44 +3242,47 @@ static std::optional<uint64_t> getExactInteger(const APFloat &APF,
 // determine whether this is worth generating code for.
 static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
                                                       unsigned EltSizeInBits) {
-  unsigned NumElts = Op.getNumOperands();
   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unexpected BUILD_VECTOR");
+  if (!cast<BuildVectorSDNode>(Op)->isConstant())
+    return std::nullopt;
   bool IsInteger = Op.getValueType().isInteger();
 
   std::optional<unsigned> SeqStepDenom;
   std::optional<int64_t> SeqStepNum, SeqAddend;
   std::optional<std::pair<uint64_t, unsigned>> PrevElt;
   assert(EltSizeInBits >= Op.getValueType().getScalarSizeInBits());
-  for (unsigned Idx = 0; Idx < NumElts; Idx++) {
-    // Assume undef elements match the sequence; we just have to be careful
-    // when interpolating across them.
-    if (Op.getOperand(Idx).isUndef())
-      continue;
 
-    uint64_t Val;
+  // First extract the ops into a list of constant integer values. This may not
+  // be possible for floats if they're not all representable as integers.
+  SmallVector<std::optional<uint64_t>> Elts(Op.getNumOperands());
+  const unsigned OpSize = Op.getScalarValueSizeInBits();
+  for (auto [Idx, Elt] : enumerate(Op->op_values())) {
+    if (Elt.isUndef()) {
+      Elts[Idx] = std::nullopt;
+      continue;
+    }
     if (IsInteger) {
-      // The BUILD_VECTOR must be all constants.
-      if (!isa<ConstantSDNode>(Op.getOperand(Idx)))
-        return std::nullopt;
-      Val = Op.getConstantOperandVal(Idx) &
-            maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
+      Elts[Idx] = Elt->getAsZExtVal() & maskTrailingOnes<uint64_t>(OpSize);
     } else {
-      // The BUILD_VECTOR must be all constants.
-      if (!isa<ConstantFPSDNode>(Op.getOperand(Idx)))
-        return std::nullopt;
-      if (auto ExactInteger = getExactInteger(
-              cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
-              Op.getScalarValueSizeInBits()))
-        Val = *ExactInteger;
-      else
+      auto ExactInteger =
+          getExactInteger(cast<ConstantFPSDNode>(Elt)->getValueAPF(), OpSize);
+      if (!ExactInteger)
         return std::nullopt;
+      Elts[Idx] = *ExactInteger;
     }
+  }
+
+  for (auto [Idx, Elt] : enumerate(Elts)) {
+    // Assume undef elements match the sequence; we just have to be careful
+    // when interpolating across them.
+    if (!Elt)
+      continue;
 
     if (PrevElt) {
       // Calculate the step since the last non-undef element, and ensure
       // it's consistent across the entire sequence.
       unsigned IdxDiff = Idx - PrevElt->second;
-      int64_t ValDiff = SignExtend64(Val - PrevElt->first, EltSizeInBits);
+      int64_t ValDiff = SignExtend64(*Elt - PrevElt->first, EltSizeInBits);
 
       // A zero-value value difference means that we're somewhere in the middle
       // of a fractional step, e.g. <0,0,0*,0,1,1,1,1>. Wait until we notice a
@@ -3309,8 +3312,8 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
     }
 
     // Record this non-undef element for later.
-    if (!PrevElt || PrevElt->first != Val)
-      PrevElt = std::make_pair(Val, Idx);
+    if (!PrevElt || PrevElt->first != *Elt)
+      PrevElt = std::make_pair(*Elt, Idx);
   }
 
   // We need to have logged a step for this to count as a legal index sequence.
@@ -3319,21 +3322,12 @@ static std::optional<VIDSequence> isSimpleVIDSequence(SDValue Op,
 
   // Loop back through the sequence and validate elements we might have skipped
   // while waiting for a valid step. While doing this, log any sequence addend.
-  for (unsigned Idx = 0; Idx < NumElts; Idx++) {
-    if (Op.getOperand(Idx).isUndef())
+  for (auto [Idx, Elt] : enumerate(Elts)) {
+    if (!Elt)
       continue;
-    uint64_t Val;
-    if (IsInteger) {
-      Val = Op.getConstantOperandVal(Idx) &
-            maskTrailingOnes<uint64_t>(Op.getScalarValueSizeInBits());
-    } else {
-      Val = *getExactInteger(
-          cast<ConstantFPSDNode>(Op.getOperand(Idx))->getValueAPF(),
-          Op.getScalarValueSizeInBits());
-    }
     uint64_t ExpectedVal =
         (int64_t)(Idx * (uint64_t)*SeqStepNum) / *SeqStepDenom;
-    int64_t Addend = SignExtend64(Val - ExpectedVal, EltSizeInBits);
+    int64_t Addend = SignExtend64(*Elt - ExpectedVal, EltSizeInBits);
     if (!SeqAddend)
       SeqAddend = Addend;
     else if (Addend != SeqAddend)

Copy link
Contributor

@wangpc-pp wangpc-pp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@lukel97 lukel97 merged commit edd4aee into llvm:main Feb 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants