diff --git a/llvm/include/llvm/Target/TargetSchedule.td b/llvm/include/llvm/Target/TargetSchedule.td index 069eb2900bfe6..d8158eb01ad45 100644 --- a/llvm/include/llvm/Target/TargetSchedule.td +++ b/llvm/include/llvm/Target/TargetSchedule.td @@ -622,11 +622,22 @@ class BothFusionPredicateWithMCInstPredicate // Tie firstOpIdx and secondOpIdx. The operand of `FirstMI` at position // `firstOpIdx` should be the same as the operand of `SecondMI` at position // `secondOpIdx`. +// If the fusion has `IsCommutable` being true and the operand at `secondOpIdx` +// has commutable operand, then the commutable operand will be checked too. class TieReg : BothFusionPredicate { int FirstOpIdx = firstOpIdx; int SecondOpIdx = secondOpIdx; } +// The operand of `SecondMI` at position `firstOpIdx` should be the same as the +// operand at position `secondOpIdx`. +// If the fusion has `IsCommutable` being true and the operand at `secondOpIdx` +// has commutable operand, then the commutable operand will be checked too. +class SameReg : SecondFusionPredicate { + int FirstOpIdx = firstOpIdx; + int SecondOpIdx = secondOpIdx; +} + // A predicate for wildcard. The generated code will be like: // ``` // if (!FirstMI) @@ -655,9 +666,12 @@ def OneUse : OneUsePred; // return true; // } // ``` +// +// `IsCommutable` means whether we should handle commutable operands. class Fusion predicates> : SubtargetFeature { list Predicates = predicates; + bit IsCommutable = 0; } // The generated predicator will be like: @@ -671,6 +685,7 @@ class Fusion p // /* Predicate for `SecondMI` */ // /* Wildcard */ // /* Predicate for `FirstMI` */ +// /* Check same registers */ // /* Check One Use */ // /* Tie registers */ // /* Epilog */ @@ -688,11 +703,7 @@ class SimpleFusion, WildcardTrue, FirstFusionPredicateWithMCInstPredicate, - SecondFusionPredicateWithMCInstPredicate< - CheckAny<[ - CheckIsVRegOperand<0>, - CheckSameRegOperand<0, 1> - ]>>, + SameReg<0, 1>, OneUse, TieReg<0, 1>, ], diff --git a/llvm/test/TableGen/MacroFusion.td b/llvm/test/TableGen/MacroFusion.td index ce76e7f0f7fa6..05c970cbd2245 100644 --- a/llvm/test/TableGen/MacroFusion.td +++ b/llvm/test/TableGen/MacroFusion.td @@ -46,11 +46,21 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion", CheckRegOperand<0, X0> ]>>; +let IsCommutable = 1 in +def TestCommutableFusion: SimpleFusion<"test-commutable-fusion", "HasTestCommutableFusion", + "Test Commutable Fusion", + CheckOpcode<[Inst0]>, + CheckAll<[ + CheckOpcode<[Inst1]>, + CheckRegOperand<0, X0> + ]>>; + // CHECK-PREDICATOR: #ifdef GET_Test_MACRO_FUSION_PRED_DECL // CHECK-PREDICATOR-NEXT: #undef GET_Test_MACRO_FUSION_PRED_DECL // CHECK-PREDICATOR-EMPTY: // CHECK-PREDICATOR-NEXT: namespace llvm { // CHECK-PREDICATOR-NEXT: bool isTestBothFusionPredicate(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &); +// CHECK-PREDICATOR-NEXT: bool isTestCommutableFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &); // CHECK-PREDICATOR-NEXT: bool isTestFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &); // CHECK-PREDICATOR-NEXT: } // end namespace llvm // CHECK-PREDICATOR-EMPTY: @@ -78,7 +88,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion", // CHECK-PREDICATOR-NEXT: } // CHECK-PREDICATOR-NEXT: return true; // CHECK-PREDICATOR-NEXT: } -// CHECK-PREDICATOR-NEXT: bool isTestFusion( +// CHECK-PREDICATOR-NEXT: bool isTestCommutableFusion( // CHECK-PREDICATOR-NEXT: const TargetInstrInfo &TII, // CHECK-PREDICATOR-NEXT: const TargetSubtargetInfo &STI, // CHECK-PREDICATOR-NEXT: const MachineInstr *FirstMI, @@ -99,14 +109,58 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion", // CHECK-PREDICATOR-NEXT: if (( MI->getOpcode() != Test::Inst0 )) // CHECK-PREDICATOR-NEXT: return false; // CHECK-PREDICATOR-NEXT: } +// CHECK-PREDICATOR-NEXT: if (!SecondMI.getOperand(0).getReg().isVirtual()) { +// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg()) { +// CHECK-PREDICATOR-NEXT: if (!SecondMI.getDesc().isCommutable()) +// CHECK-PREDICATOR-NEXT: return false; +// CHECK-PREDICATOR-NEXT: unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex; +// CHECK-PREDICATOR-NEXT: if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2)) +// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg()) +// CHECK-PREDICATOR-NEXT: return false; +// CHECK-PREDICATOR-NEXT: } +// CHECK-PREDICATOR-NEXT: } +// CHECK-PREDICATOR-NEXT: { +// CHECK-PREDICATOR-NEXT: Register FirstDest = FirstMI->getOperand(0).getReg(); +// CHECK-PREDICATOR-NEXT: if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest)) +// CHECK-PREDICATOR-NEXT: return false; +// CHECK-PREDICATOR-NEXT: } +// CHECK-PREDICATOR-NEXT: if (!(FirstMI->getOperand(0).isReg() && +// CHECK-PREDICATOR-NEXT: SecondMI.getOperand(1).isReg() && +// CHECK-PREDICATOR-NEXT: FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg())) { +// CHECK-PREDICATOR-NEXT: if (!SecondMI.getDesc().isCommutable()) +// CHECK-PREDICATOR-NEXT: return false; +// CHECK-PREDICATOR-NEXT: unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex; +// CHECK-PREDICATOR-NEXT: if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2)) +// CHECK-PREDICATOR-NEXT: if (FirstMI->getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg()) +// CHECK-PREDICATOR-NEXT: return false; +// CHECK-PREDICATOR-NEXT: } +// CHECK-PREDICATOR-NEXT: return true; +// CHECK-PREDICATOR-NEXT: } +// CHECK-PREDICATOR-NEXT: bool isTestFusion( +// CHECK-PREDICATOR-NEXT: const TargetInstrInfo &TII, +// CHECK-PREDICATOR-NEXT: const TargetSubtargetInfo &STI, +// CHECK-PREDICATOR-NEXT: const MachineInstr *FirstMI, +// CHECK-PREDICATOR-NEXT: const MachineInstr &SecondMI) { +// CHECK-PREDICATOR-NEXT: auto &MRI = SecondMI.getMF()->getRegInfo(); // CHECK-PREDICATOR-NEXT: { // CHECK-PREDICATOR-NEXT: const MachineInstr *MI = &SecondMI; // CHECK-PREDICATOR-NEXT: if (!( -// CHECK-PREDICATOR-NEXT: MI->getOperand(0).getReg().isVirtual() -// CHECK-PREDICATOR-NEXT: || MI->getOperand(0).getReg() == MI->getOperand(1).getReg() +// CHECK-PREDICATOR-NEXT: ( MI->getOpcode() == Test::Inst1 ) +// CHECK-PREDICATOR-NEXT: && MI->getOperand(0).getReg() == Test::X0 // CHECK-PREDICATOR-NEXT: )) // CHECK-PREDICATOR-NEXT: return false; // CHECK-PREDICATOR-NEXT: } +// CHECK-PREDICATOR-NEXT: if (!FirstMI) +// CHECK-PREDICATOR-NEXT: return true; +// CHECK-PREDICATOR-NEXT: { +// CHECK-PREDICATOR-NEXT: const MachineInstr *MI = FirstMI; +// CHECK-PREDICATOR-NEXT: if (( MI->getOpcode() != Test::Inst0 )) +// CHECK-PREDICATOR-NEXT: return false; +// CHECK-PREDICATOR-NEXT: } +// CHECK-PREDICATOR-NEXT: if (!SecondMI.getOperand(0).getReg().isVirtual()) { +// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg()) +// CHECK-PREDICATOR-NEXT: return false; +// CHECK-PREDICATOR-NEXT: } // CHECK-PREDICATOR-NEXT: { // CHECK-PREDICATOR-NEXT: Register FirstDest = FirstMI->getOperand(0).getReg(); // CHECK-PREDICATOR-NEXT: if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest)) @@ -131,6 +185,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion", // CHECK-SUBTARGET: std::vector TestGenSubtargetInfo::getMacroFusions() const { // CHECK-SUBTARGET-NEXT: std::vector Fusions; // CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestBothFusionPredicate)) Fusions.push_back(llvm::isTestBothFusionPredicate); +// CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestCommutableFusion)) Fusions.push_back(llvm::isTestCommutableFusion); // CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestFusion)) Fusions.push_back(llvm::isTestFusion); // CHECK-SUBTARGET-NEXT: return Fusions; // CHECK-SUBTARGET-NEXT: } diff --git a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp index 7f494e532b1f4..91c3b0b4359cf 100644 --- a/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp +++ b/llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp @@ -40,12 +40,10 @@ #include "CodeGenTarget.h" #include "PredicateExpander.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" -#include #include using namespace llvm; @@ -61,14 +59,14 @@ class MacroFusionPredicatorEmitter { raw_ostream &OS); void emitMacroFusionImpl(std::vector Fusions, PredicateExpander &PE, raw_ostream &OS); - void emitPredicates(std::vector &FirstPredicate, + void emitPredicates(std::vector &FirstPredicate, bool IsCommutable, PredicateExpander &PE, raw_ostream &OS); - void emitFirstPredicate(Record *SecondPredicate, PredicateExpander &PE, - raw_ostream &OS); - void emitSecondPredicate(Record *SecondPredicate, PredicateExpander &PE, - raw_ostream &OS); - void emitBothPredicate(Record *Predicates, PredicateExpander &PE, - raw_ostream &OS); + void emitFirstPredicate(Record *SecondPredicate, bool IsCommutable, + PredicateExpander &PE, raw_ostream &OS); + void emitSecondPredicate(Record *SecondPredicate, bool IsCommutable, + PredicateExpander &PE, raw_ostream &OS); + void emitBothPredicate(Record *Predicates, bool IsCommutable, + PredicateExpander &PE, raw_ostream &OS); public: MacroFusionPredicatorEmitter(RecordKeeper &R) : Records(R), Target(R) {} @@ -103,6 +101,7 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl( for (Record *Fusion : Fusions) { std::vector Predicates = Fusion->getValueAsListOfDefs("Predicates"); + bool IsCommutable = Fusion->getValueAsBit("IsCommutable"); OS << "bool is" << Fusion->getName() << "(\n"; OS.indent(4) << "const TargetInstrInfo &TII,\n"; @@ -111,7 +110,7 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl( OS.indent(4) << "const MachineInstr &SecondMI) {\n"; OS.indent(2) << "auto &MRI = SecondMI.getMF()->getRegInfo();\n"; - emitPredicates(Predicates, PE, OS); + emitPredicates(Predicates, IsCommutable, PE, OS); OS.indent(2) << "return true;\n"; OS << "}\n"; @@ -122,15 +121,16 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl( } void MacroFusionPredicatorEmitter::emitPredicates( - std::vector &Predicates, PredicateExpander &PE, raw_ostream &OS) { + std::vector &Predicates, bool IsCommutable, PredicateExpander &PE, + raw_ostream &OS) { for (Record *Predicate : Predicates) { Record *Target = Predicate->getValueAsDef("Target"); if (Target->getName() == "first_fusion_target") - emitFirstPredicate(Predicate, PE, OS); + emitFirstPredicate(Predicate, IsCommutable, PE, OS); else if (Target->getName() == "second_fusion_target") - emitSecondPredicate(Predicate, PE, OS); + emitSecondPredicate(Predicate, IsCommutable, PE, OS); else if (Target->getName() == "both_fusion_target") - emitBothPredicate(Predicate, PE, OS); + emitBothPredicate(Predicate, IsCommutable, PE, OS); else PrintFatalError(Target->getLoc(), "Unsupported 'FusionTarget': " + Target->getName()); @@ -138,6 +138,7 @@ void MacroFusionPredicatorEmitter::emitPredicates( } void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate, + bool IsCommutable, PredicateExpander &PE, raw_ostream &OS) { if (Predicate->isSubClassOf("WildcardPred")) { @@ -170,6 +171,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate, } void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate, + bool IsCommutable, PredicateExpander &PE, raw_ostream &OS) { if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) { @@ -182,6 +184,36 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate, OS << ")\n"; OS.indent(4) << " return false;\n"; OS.indent(2) << "}\n"; + } else if (Predicate->isSubClassOf("SameReg")) { + int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx"); + int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx"); + + OS.indent(2) << "if (!SecondMI.getOperand(" << FirstOpIdx + << ").getReg().isVirtual()) {\n"; + OS.indent(4) << "if (SecondMI.getOperand(" << FirstOpIdx + << ").getReg() != SecondMI.getOperand(" << SecondOpIdx + << ").getReg())"; + + if (IsCommutable) { + OS << " {\n"; + OS.indent(6) << "if (!SecondMI.getDesc().isCommutable())\n"; + OS.indent(6) << " return false;\n"; + + OS.indent(6) + << "unsigned SrcOpIdx1 = " << SecondOpIdx + << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n"; + OS.indent(6) + << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n"; + OS.indent(6) + << " if (SecondMI.getOperand(" << FirstOpIdx + << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n"; + OS.indent(6) << " return false;\n"; + OS.indent(4) << "}\n"; + } else { + OS << "\n"; + OS.indent(4) << " return false;\n"; + } + OS.indent(2) << "}\n"; } else { PrintFatalError(Predicate->getLoc(), "Unsupported predicate for second instruction: " + @@ -190,13 +222,14 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate, } void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate, + bool IsCommutable, PredicateExpander &PE, raw_ostream &OS) { if (Predicate->isSubClassOf("FusionPredicateWithCode")) OS << Predicate->getValueAsString("Predicate"); else if (Predicate->isSubClassOf("BothFusionPredicateWithMCInstPredicate")) { - emitFirstPredicate(Predicate, PE, OS); - emitSecondPredicate(Predicate, PE, OS); + emitFirstPredicate(Predicate, IsCommutable, PE, OS); + emitSecondPredicate(Predicate, IsCommutable, PE, OS); } else if (Predicate->isSubClassOf("TieReg")) { int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx"); int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx"); @@ -206,8 +239,28 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate, << ").isReg() &&\n"; OS.indent(2) << " FirstMI->getOperand(" << FirstOpIdx << ").getReg() == SecondMI.getOperand(" << SecondOpIdx - << ").getReg()))\n"; - OS.indent(2) << " return false;\n"; + << ").getReg()))"; + + if (IsCommutable) { + OS << " {\n"; + OS.indent(4) << "if (!SecondMI.getDesc().isCommutable())\n"; + OS.indent(4) << " return false;\n"; + + OS.indent(4) + << "unsigned SrcOpIdx1 = " << SecondOpIdx + << ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n"; + OS.indent(4) + << "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n"; + OS.indent(4) + << " if (FirstMI->getOperand(" << FirstOpIdx + << ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n"; + OS.indent(4) << " return false;\n"; + OS.indent(2) << "}"; + } else { + OS << "\n"; + OS.indent(2) << " return false;"; + } + OS << "\n"; } else PrintFatalError(Predicate->getLoc(), "Unsupported predicate for both instruction: " +