Skip to content

[MacroFusion] Support commutable instructions #82751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions llvm/include/llvm/Target/TargetSchedule.td
Original file line number Diff line number Diff line change
Expand Up @@ -622,11 +622,22 @@ class BothFusionPredicateWithMCInstPredicate<MCInstPredicate pred>
// Tie firstOpIdx and secondOpIdx. The operand of `FirstMI` at position
// `firstOpIdx` should be the same as the operand of `SecondMI` at position
// `secondOpIdx`.
// If the fusion has `IsCommutable` being true and the operand at `secondOpIdx`
// has commutable operand, then the commutable operand will be checked too.
class TieReg<int firstOpIdx, int secondOpIdx> : BothFusionPredicate {
int FirstOpIdx = firstOpIdx;
int SecondOpIdx = secondOpIdx;
}

// The operand of `SecondMI` at position `firstOpIdx` should be the same as the
// operand at position `secondOpIdx`.
// If the fusion has `IsCommutable` being true and the operand at `secondOpIdx`
// has commutable operand, then the commutable operand will be checked too.
class SameReg<int firstOpIdx, int secondOpIdx> : SecondFusionPredicate {
int FirstOpIdx = firstOpIdx;
int SecondOpIdx = secondOpIdx;
}

// A predicate for wildcard. The generated code will be like:
// ```
// if (!FirstMI)
Expand Down Expand Up @@ -655,9 +666,12 @@ def OneUse : OneUsePred;
// return true;
// }
// ```
//
// `IsCommutable` means whether we should handle commutable operands.
class Fusion<string name, string fieldName, string desc, list<FusionPredicate> predicates>
: SubtargetFeature<name, fieldName, "true", desc> {
list<FusionPredicate> Predicates = predicates;
bit IsCommutable = 0;
}

// The generated predicator will be like:
Expand All @@ -671,6 +685,7 @@ class Fusion<string name, string fieldName, string desc, list<FusionPredicate> p
// /* Predicate for `SecondMI` */
// /* Wildcard */
// /* Predicate for `FirstMI` */
// /* Check same registers */
// /* Check One Use */
// /* Tie registers */
// /* Epilog */
Expand All @@ -688,11 +703,7 @@ class SimpleFusion<string name, string fieldName, string desc,
SecondFusionPredicateWithMCInstPredicate<secondPred>,
WildcardTrue,
FirstFusionPredicateWithMCInstPredicate<firstPred>,
SecondFusionPredicateWithMCInstPredicate<
CheckAny<[
CheckIsVRegOperand<0>,
CheckSameRegOperand<0, 1>
]>>,
SameReg<0, 1>,
OneUse,
TieReg<0, 1>,
],
Expand Down
61 changes: 58 additions & 3 deletions llvm/test/TableGen/MacroFusion.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,21 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
CheckRegOperand<0, X0>
]>>;

let IsCommutable = 1 in
def TestCommutableFusion: SimpleFusion<"test-commutable-fusion", "HasTestCommutableFusion",
"Test Commutable Fusion",
CheckOpcode<[Inst0]>,
CheckAll<[
CheckOpcode<[Inst1]>,
CheckRegOperand<0, X0>
]>>;

// CHECK-PREDICATOR: #ifdef GET_Test_MACRO_FUSION_PRED_DECL
// CHECK-PREDICATOR-NEXT: #undef GET_Test_MACRO_FUSION_PRED_DECL
// CHECK-PREDICATOR-EMPTY:
// CHECK-PREDICATOR-NEXT: namespace llvm {
// CHECK-PREDICATOR-NEXT: bool isTestBothFusionPredicate(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
// CHECK-PREDICATOR-NEXT: bool isTestCommutableFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
// CHECK-PREDICATOR-NEXT: bool isTestFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
// CHECK-PREDICATOR-NEXT: } // end namespace llvm
// CHECK-PREDICATOR-EMPTY:
Expand Down Expand Up @@ -78,7 +88,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: return true;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: bool isTestFusion(
// CHECK-PREDICATOR-NEXT: bool isTestCommutableFusion(
// CHECK-PREDICATOR-NEXT: const TargetInstrInfo &TII,
// CHECK-PREDICATOR-NEXT: const TargetSubtargetInfo &STI,
// CHECK-PREDICATOR-NEXT: const MachineInstr *FirstMI,
Expand All @@ -99,14 +109,58 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
// CHECK-PREDICATOR-NEXT: if (( MI->getOpcode() != Test::Inst0 ))
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: if (!SecondMI.getOperand(0).getReg().isVirtual()) {
// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg()) {
// CHECK-PREDICATOR-NEXT: if (!SecondMI.getDesc().isCommutable())
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
// CHECK-PREDICATOR-NEXT: if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: {
// CHECK-PREDICATOR-NEXT: Register FirstDest = FirstMI->getOperand(0).getReg();
// CHECK-PREDICATOR-NEXT: if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: if (!(FirstMI->getOperand(0).isReg() &&
// CHECK-PREDICATOR-NEXT: SecondMI.getOperand(1).isReg() &&
// CHECK-PREDICATOR-NEXT: FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg())) {
// CHECK-PREDICATOR-NEXT: if (!SecondMI.getDesc().isCommutable())
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
// CHECK-PREDICATOR-NEXT: if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
// CHECK-PREDICATOR-NEXT: if (FirstMI->getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: return true;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: bool isTestFusion(
// CHECK-PREDICATOR-NEXT: const TargetInstrInfo &TII,
// CHECK-PREDICATOR-NEXT: const TargetSubtargetInfo &STI,
// CHECK-PREDICATOR-NEXT: const MachineInstr *FirstMI,
// CHECK-PREDICATOR-NEXT: const MachineInstr &SecondMI) {
// CHECK-PREDICATOR-NEXT: auto &MRI = SecondMI.getMF()->getRegInfo();
// CHECK-PREDICATOR-NEXT: {
// CHECK-PREDICATOR-NEXT: const MachineInstr *MI = &SecondMI;
// CHECK-PREDICATOR-NEXT: if (!(
// CHECK-PREDICATOR-NEXT: MI->getOperand(0).getReg().isVirtual()
// CHECK-PREDICATOR-NEXT: || MI->getOperand(0).getReg() == MI->getOperand(1).getReg()
// CHECK-PREDICATOR-NEXT: ( MI->getOpcode() == Test::Inst1 )
// CHECK-PREDICATOR-NEXT: && MI->getOperand(0).getReg() == Test::X0
// CHECK-PREDICATOR-NEXT: ))
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: if (!FirstMI)
// CHECK-PREDICATOR-NEXT: return true;
// CHECK-PREDICATOR-NEXT: {
// CHECK-PREDICATOR-NEXT: const MachineInstr *MI = FirstMI;
// CHECK-PREDICATOR-NEXT: if (( MI->getOpcode() != Test::Inst0 ))
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: if (!SecondMI.getOperand(0).getReg().isVirtual()) {
// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg())
// CHECK-PREDICATOR-NEXT: return false;
// CHECK-PREDICATOR-NEXT: }
// CHECK-PREDICATOR-NEXT: {
// CHECK-PREDICATOR-NEXT: Register FirstDest = FirstMI->getOperand(0).getReg();
// CHECK-PREDICATOR-NEXT: if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))
Expand All @@ -131,6 +185,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
// CHECK-SUBTARGET: std::vector<MacroFusionPredTy> TestGenSubtargetInfo::getMacroFusions() const {
// CHECK-SUBTARGET-NEXT: std::vector<MacroFusionPredTy> Fusions;
// CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestBothFusionPredicate)) Fusions.push_back(llvm::isTestBothFusionPredicate);
// CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestCommutableFusion)) Fusions.push_back(llvm::isTestCommutableFusion);
// CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestFusion)) Fusions.push_back(llvm::isTestFusion);
// CHECK-SUBTARGET-NEXT: return Fusions;
// CHECK-SUBTARGET-NEXT: }
89 changes: 71 additions & 18 deletions llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@

#include "CodeGenTarget.h"
#include "PredicateExpander.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include <set>
#include <vector>

using namespace llvm;
Expand All @@ -61,14 +59,14 @@ class MacroFusionPredicatorEmitter {
raw_ostream &OS);
void emitMacroFusionImpl(std::vector<Record *> Fusions, PredicateExpander &PE,
raw_ostream &OS);
void emitPredicates(std::vector<Record *> &FirstPredicate,
void emitPredicates(std::vector<Record *> &FirstPredicate, bool IsCommutable,
PredicateExpander &PE, raw_ostream &OS);
void emitFirstPredicate(Record *SecondPredicate, PredicateExpander &PE,
raw_ostream &OS);
void emitSecondPredicate(Record *SecondPredicate, PredicateExpander &PE,
raw_ostream &OS);
void emitBothPredicate(Record *Predicates, PredicateExpander &PE,
raw_ostream &OS);
void emitFirstPredicate(Record *SecondPredicate, bool IsCommutable,
PredicateExpander &PE, raw_ostream &OS);
void emitSecondPredicate(Record *SecondPredicate, bool IsCommutable,
PredicateExpander &PE, raw_ostream &OS);
void emitBothPredicate(Record *Predicates, bool IsCommutable,
PredicateExpander &PE, raw_ostream &OS);

public:
MacroFusionPredicatorEmitter(RecordKeeper &R) : Records(R), Target(R) {}
Expand Down Expand Up @@ -103,6 +101,7 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
for (Record *Fusion : Fusions) {
std::vector<Record *> Predicates =
Fusion->getValueAsListOfDefs("Predicates");
bool IsCommutable = Fusion->getValueAsBit("IsCommutable");

OS << "bool is" << Fusion->getName() << "(\n";
OS.indent(4) << "const TargetInstrInfo &TII,\n";
Expand All @@ -111,7 +110,7 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
OS.indent(4) << "const MachineInstr &SecondMI) {\n";
OS.indent(2) << "auto &MRI = SecondMI.getMF()->getRegInfo();\n";

emitPredicates(Predicates, PE, OS);
emitPredicates(Predicates, IsCommutable, PE, OS);

OS.indent(2) << "return true;\n";
OS << "}\n";
Expand All @@ -122,22 +121,24 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
}

void MacroFusionPredicatorEmitter::emitPredicates(
std::vector<Record *> &Predicates, PredicateExpander &PE, raw_ostream &OS) {
std::vector<Record *> &Predicates, bool IsCommutable, PredicateExpander &PE,
raw_ostream &OS) {
for (Record *Predicate : Predicates) {
Record *Target = Predicate->getValueAsDef("Target");
if (Target->getName() == "first_fusion_target")
emitFirstPredicate(Predicate, PE, OS);
emitFirstPredicate(Predicate, IsCommutable, PE, OS);
else if (Target->getName() == "second_fusion_target")
emitSecondPredicate(Predicate, PE, OS);
emitSecondPredicate(Predicate, IsCommutable, PE, OS);
else if (Target->getName() == "both_fusion_target")
emitBothPredicate(Predicate, PE, OS);
emitBothPredicate(Predicate, IsCommutable, PE, OS);
else
PrintFatalError(Target->getLoc(),
"Unsupported 'FusionTarget': " + Target->getName());
}
}

void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
bool IsCommutable,
PredicateExpander &PE,
raw_ostream &OS) {
if (Predicate->isSubClassOf("WildcardPred")) {
Expand Down Expand Up @@ -170,6 +171,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
}

void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
bool IsCommutable,
PredicateExpander &PE,
raw_ostream &OS) {
if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
Expand All @@ -182,6 +184,36 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
OS << ")\n";
OS.indent(4) << " return false;\n";
OS.indent(2) << "}\n";
} else if (Predicate->isSubClassOf("SameReg")) {
int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");

OS.indent(2) << "if (!SecondMI.getOperand(" << FirstOpIdx
<< ").getReg().isVirtual()) {\n";
OS.indent(4) << "if (SecondMI.getOperand(" << FirstOpIdx
<< ").getReg() != SecondMI.getOperand(" << SecondOpIdx
<< ").getReg())";

if (IsCommutable) {
OS << " {\n";
OS.indent(6) << "if (!SecondMI.getDesc().isCommutable())\n";
OS.indent(6) << " return false;\n";

OS.indent(6)
<< "unsigned SrcOpIdx1 = " << SecondOpIdx
<< ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
OS.indent(6)
<< "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
OS.indent(6)
<< " if (SecondMI.getOperand(" << FirstOpIdx
<< ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
OS.indent(6) << " return false;\n";
OS.indent(4) << "}\n";
} else {
OS << "\n";
OS.indent(4) << " return false;\n";
}
OS.indent(2) << "}\n";
} else {
PrintFatalError(Predicate->getLoc(),
"Unsupported predicate for second instruction: " +
Expand All @@ -190,13 +222,14 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
}

void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
bool IsCommutable,
PredicateExpander &PE,
raw_ostream &OS) {
if (Predicate->isSubClassOf("FusionPredicateWithCode"))
OS << Predicate->getValueAsString("Predicate");
else if (Predicate->isSubClassOf("BothFusionPredicateWithMCInstPredicate")) {
emitFirstPredicate(Predicate, PE, OS);
emitSecondPredicate(Predicate, PE, OS);
emitFirstPredicate(Predicate, IsCommutable, PE, OS);
emitSecondPredicate(Predicate, IsCommutable, PE, OS);
} else if (Predicate->isSubClassOf("TieReg")) {
int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
Expand All @@ -206,8 +239,28 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
<< ").isReg() &&\n";
OS.indent(2) << " FirstMI->getOperand(" << FirstOpIdx
<< ").getReg() == SecondMI.getOperand(" << SecondOpIdx
<< ").getReg()))\n";
OS.indent(2) << " return false;\n";
<< ").getReg()))";

if (IsCommutable) {
OS << " {\n";
OS.indent(4) << "if (!SecondMI.getDesc().isCommutable())\n";
OS.indent(4) << " return false;\n";

OS.indent(4)
<< "unsigned SrcOpIdx1 = " << SecondOpIdx
<< ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
OS.indent(4)
<< "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
OS.indent(4)
<< " if (FirstMI->getOperand(" << FirstOpIdx
<< ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
OS.indent(4) << " return false;\n";
OS.indent(2) << "}";
} else {
OS << "\n";
OS.indent(2) << " return false;";
}
OS << "\n";
} else
PrintFatalError(Predicate->getLoc(),
"Unsupported predicate for both instruction: " +
Expand Down