@@ -258,6 +258,13 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
258
258
259
259
numSemanticResults += getNumAutoDiffSemanticResultsParameters ();
260
260
261
+ // Check yields.
262
+ for (auto yieldAndIndex : enumerate(getYields ()))
263
+ if (!yieldAndIndex.value ().hasOption (
264
+ SILParameterInfo::NotDifferentiable))
265
+ resultIndices.push_back (numSemanticResults + yieldAndIndex.index ());
266
+
267
+ numSemanticResults += getNumIndirectFormalYields ();
261
268
return IndexSubset::get (getASTContext (), numSemanticResults, resultIndices);
262
269
}
263
270
@@ -555,10 +562,14 @@ static CanSILFunctionType getAutoDiffDifferentialType(
555
562
param.getConvention ());
556
563
differentialParams.push_back ({paramTanType, paramConv});
557
564
}
565
+
558
566
SmallVector<SILResultInfo, 1 > differentialResults;
567
+ unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults ();
568
+ unsigned firstYieldResultIndex = originalFnTy->getNumResults () +
569
+ originalFnTy->getNumAutoDiffSemanticResultsParameters ();
559
570
for (auto resultIndex : resultIndices->getIndices ()) {
560
571
// Handle formal original result.
561
- if (resultIndex < originalFnTy-> getNumResults () ) {
572
+ if (resultIndex < firstSemanticParamResultIdx ) {
562
573
auto &result = originalResults[resultIndex];
563
574
auto resultTanType = getAutoDiffTangentTypeForLinearMap (
564
575
result.getInterfaceType (), lookupConformance,
@@ -571,26 +582,38 @@ static CanSILFunctionType getAutoDiffDifferentialType(
571
582
result.getConvention ());
572
583
differentialResults.push_back ({resultTanType, resultConv});
573
584
continue ;
574
- }
575
- // Handle original semantic result parameters.
576
- auto resultParamIndex = resultIndex - originalFnTy->getNumResults ();
577
- auto resultParamIt = std::next (
585
+ } else if (resultIndex < firstYieldResultIndex) {
586
+ // Handle original semantic result parameters.
587
+ auto resultParamIndex = resultIndex - originalFnTy->getNumResults ();
588
+ auto resultParamIt = std::next (
578
589
originalFnTy->getAutoDiffSemanticResultsParameters ().begin (),
579
590
resultParamIndex);
580
- auto paramIndex =
581
- std::distance (originalFnTy->getParameters ().begin (), &*resultParamIt);
582
- // If the original semantic result parameter is a differentiability
583
- // parameter, then it already has a corresponding differential
584
- // parameter. Skip adding a corresponding differential result.
585
- if (parameterIndices->contains (paramIndex))
586
- continue ;
591
+ auto paramIndex =
592
+ std::distance (originalFnTy->getParameters ().begin (), &*resultParamIt);
593
+ // If the original semantic result parameter is a differentiability
594
+ // parameter, then it already has a corresponding differential
595
+ // parameter. Skip adding a corresponding differential result.
596
+ if (parameterIndices->contains (paramIndex))
597
+ continue ;
587
598
588
- auto resultParam = originalFnTy->getParameters ()[paramIndex];
589
- auto resultParamTanType = getAutoDiffTangentTypeForLinearMap (
590
- resultParam.getInterfaceType (), lookupConformance,
599
+ auto resultParam = originalFnTy->getParameters ()[paramIndex];
600
+ auto resultParamTanType = getAutoDiffTangentTypeForLinearMap (
601
+ resultParam.getInterfaceType (), lookupConformance,
591
602
substGenericParams, substReplacements, ctx);
592
- differentialResults.emplace_back (resultParamTanType,
593
- ResultConvention::Indirect);
603
+ differentialResults.emplace_back (resultParamTanType,
604
+ ResultConvention::Indirect);
605
+ } else {
606
+ assert (originalFnTy->isCoroutine ());
607
+ assert (originalFnTy->getCoroutineKind () == SILCoroutineKind::YieldOnce);
608
+ auto yieldResultIndex = resultIndex - firstYieldResultIndex;
609
+ auto yieldResult = originalFnTy->getYields ()[yieldResultIndex];
610
+ auto resultParamTanType = getAutoDiffTangentTypeForLinearMap (
611
+ yieldResult.getInterfaceType (), lookupConformance,
612
+ substGenericParams, substReplacements, ctx);
613
+ ParameterConvention paramTanConvention = yieldResult.getConvention ();
614
+ assert (yieldResult.getConvention () == ParameterConvention::Indirect_Inout);
615
+ differentialParams.emplace_back (resultParamTanType, paramTanConvention);
616
+ }
594
617
}
595
618
596
619
SubstitutionMap substitutions;
@@ -604,9 +627,9 @@ static CanSILFunctionType getAutoDiffDifferentialType(
604
627
}
605
628
return SILFunctionType::get (
606
629
GenericSignature (), SILFunctionType::ExtInfo (), SILCoroutineKind::None,
607
- ParameterConvention::Direct_Guaranteed, differentialParams, {},
608
- differentialResults, llvm::None, substitutions ,
609
- /* invocationSubstitutions*/ SubstitutionMap (), ctx);
630
+ ParameterConvention::Direct_Guaranteed,
631
+ differentialParams, {}, differentialResults, llvm::None,
632
+ substitutions, /* invocationSubstitutions*/ SubstitutionMap (), ctx);
610
633
}
611
634
612
635
// / Returns the pullback type for the given original function type, parameter
@@ -696,11 +719,15 @@ static CanSILFunctionType getAutoDiffPullbackType(
696
719
return conv;
697
720
};
698
721
699
- // Collect pullback parameters.
722
+ // Collect pullback parameters & yields
700
723
SmallVector<SILParameterInfo, 1 > pullbackParams;
724
+ SmallVector<SILYieldInfo, 1 > pullbackYields;
725
+ unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults ();
726
+ unsigned firstYieldResultIndex = originalFnTy->getNumResults () +
727
+ originalFnTy->getNumAutoDiffSemanticResultsParameters ();
701
728
for (auto resultIndex : resultIndices->getIndices ()) {
702
729
// Handle formal original result.
703
- if (resultIndex < originalFnTy-> getNumResults () ) {
730
+ if (resultIndex < firstSemanticParamResultIdx ) {
704
731
auto &origRes = originalResults[resultIndex];
705
732
auto resultTanType = getAutoDiffTangentTypeForLinearMap (
706
733
origRes.getInterfaceType (), lookupConformance,
@@ -712,28 +739,38 @@ static CanSILFunctionType getAutoDiffPullbackType(
712
739
->getCanonicalType (),
713
740
origRes.getConvention ());
714
741
pullbackParams.emplace_back (resultTanType, paramConv);
715
- continue ;
742
+ } else if (resultIndex < firstYieldResultIndex) {
743
+ // Handle original semantic result parameters.
744
+ auto resultParamIndex = resultIndex - firstSemanticParamResultIdx;
745
+ auto resultParamIt = std::next (
746
+ originalFnTy->getAutoDiffSemanticResultsParameters ().begin (),
747
+ resultParamIndex);
748
+ auto paramIndex =
749
+ std::distance (originalFnTy->getParameters ().begin (), &*resultParamIt);
750
+ auto resultParam = originalFnTy->getParameters ()[paramIndex];
751
+ // The pullback parameter convention depends on whether the original `inout`
752
+ // parameter is a differentiability parameter.
753
+ // - If yes, the pullback parameter convention is `@inout`.
754
+ // - If no, the pullback parameter convention is `@in_guaranteed`.
755
+ auto resultParamTanType = getAutoDiffTangentTypeForLinearMap (
756
+ resultParam.getInterfaceType (), lookupConformance,
757
+ substGenericParams, substReplacements, ctx);
758
+ ParameterConvention paramTanConvention = resultParam.getConvention ();
759
+ if (!parameterIndices->contains (paramIndex))
760
+ paramTanConvention = ParameterConvention::Indirect_In_Guaranteed;
761
+ pullbackParams.emplace_back (resultParamTanType, paramTanConvention);
762
+ } else {
763
+ assert (originalFnTy->isCoroutine ());
764
+ assert (originalFnTy->getCoroutineKind () == SILCoroutineKind::YieldOnce);
765
+ auto yieldResultIndex = resultIndex - firstYieldResultIndex;
766
+ auto yieldResult = originalFnTy->getYields ()[yieldResultIndex];
767
+ auto resultParamTanType = getAutoDiffTangentTypeForLinearMap (
768
+ yieldResult.getInterfaceType (), lookupConformance,
769
+ substGenericParams, substReplacements, ctx);
770
+ ParameterConvention paramTanConvention = yieldResult.getConvention ();
771
+ assert (yieldResult.getConvention () == ParameterConvention::Indirect_Inout);
772
+ pullbackYields.emplace_back (resultParamTanType, paramTanConvention);
716
773
}
717
- // Handle original semantic result parameters.
718
- auto resultParamIndex = resultIndex - originalFnTy->getNumResults ();
719
- auto resultParamIt = std::next (
720
- originalFnTy->getAutoDiffSemanticResultsParameters ().begin (),
721
- resultParamIndex);
722
- auto paramIndex =
723
- std::distance (originalFnTy->getParameters ().begin (), &*resultParamIt);
724
- auto resultParam = originalFnTy->getParameters ()[paramIndex];
725
- // The pullback parameter convention depends on whether the original `inout`
726
- // parameter is a differentiability parameter.
727
- // - If yes, the pullback parameter convention is `@inout`.
728
- // - If no, the pullback parameter convention is `@in_guaranteed`.
729
- auto resultParamTanType = getAutoDiffTangentTypeForLinearMap (
730
- resultParam.getInterfaceType (), lookupConformance,
731
- substGenericParams, substReplacements, ctx);
732
- ParameterConvention paramTanConvention = resultParam.getConvention ();
733
- if (!parameterIndices->contains (paramIndex))
734
- paramTanConvention = ParameterConvention::Indirect_In_Guaranteed;
735
-
736
- pullbackParams.emplace_back (resultParamTanType, paramTanConvention);
737
774
}
738
775
739
776
// Collect pullback results.
@@ -756,6 +793,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
756
793
param.getConvention ());
757
794
pullbackResults.push_back ({paramTanType, resultTanConvention});
758
795
}
796
+
759
797
SubstitutionMap substitutions;
760
798
if (!substGenericParams.empty ()) {
761
799
auto genericSig =
@@ -768,7 +806,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
768
806
return SILFunctionType::get (
769
807
GenericSignature (), SILFunctionType::ExtInfo (), originalFnTy->getCoroutineKind (),
770
808
ParameterConvention::Direct_Guaranteed,
771
- pullbackParams, {} , pullbackResults, llvm::None, substitutions,
809
+ pullbackParams, pullbackYields , pullbackResults, llvm::None, substitutions,
772
810
/* invocationSubstitutions*/ SubstitutionMap (), ctx);
773
811
}
774
812
@@ -920,10 +958,9 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
920
958
// Compute the derivative function results.
921
959
SmallVector<SILResultInfo, 4 > newResults;
922
960
newResults.reserve (getNumResults () + 1 );
923
- for (auto &result : constrainedOriginalFnTy->getResults ()) {
961
+ for (auto &result : constrainedOriginalFnTy->getResults ())
924
962
newResults.push_back (result);
925
- }
926
- newResults.push_back ({closureType, ResultConvention::Owned});
963
+ newResults.emplace_back (closureType, ResultConvention::Owned);
927
964
928
965
// Compute the derivative function ExtInfo.
929
966
// If original function is `@convention(c)`, the derivative function should
0 commit comments