@@ -527,6 +527,69 @@ class VJPCloner::Implementation final
527
527
getOpValue (origCallee)->getDefiningInstruction ());
528
528
}
529
529
530
+ // Check and diagnose non-differentiable original function type.
531
+ bool diagnoseNondifferentiableOriginalFunctionType (CanSILFunctionType originalFnTy,
532
+ FullApplySite fai, SILValue origCallee,
533
+ const AutoDiffConfig &config) const {
534
+ // Check and diagnose non-differentiable arguments.
535
+ for (auto paramIndex : config.parameterIndices ->getIndices ()) {
536
+ if (!originalFnTy->getParameters ()[paramIndex]
537
+ .getSILStorageInterfaceType ()
538
+ .isDifferentiable (getModule ())) {
539
+ auto arg = fai.getArgumentsWithoutIndirectResults ()[paramIndex];
540
+ // FIXME: This shouldn't be necessary and might indicate a bug in
541
+ // the transformation.
542
+ RegularLocation nonAutoGenLoc (arg.getLoc ());
543
+ nonAutoGenLoc.markNonAutoGenerated ();
544
+ auto startLoc = nonAutoGenLoc.getStartSourceLoc ();
545
+ auto endLoc = nonAutoGenLoc.getEndSourceLoc ();
546
+ context.emitNondifferentiabilityError (
547
+ arg, invoker, diag::autodiff_nondifferentiable_argument)
548
+ .fixItInsert (startLoc, " withoutDerivative(at: " )
549
+ .fixItInsertAfter (endLoc, " )" );
550
+ return true ;
551
+ }
552
+ }
553
+
554
+ // Check and diagnose non-differentiable results.
555
+ unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults ();
556
+ unsigned firstYieldResultIndex = originalFnTy->getNumResults () +
557
+ originalFnTy->getNumAutoDiffSemanticResultsParameters ();
558
+
559
+ for (auto resultIndex : config.resultIndices ->getIndices ()) {
560
+ SILType remappedResultType;
561
+ if (resultIndex >= firstYieldResultIndex) {
562
+ auto yieldResultIdx = resultIndex - firstYieldResultIndex;
563
+ const auto & yield = originalFnTy->getYields ()[yieldResultIdx];
564
+ // We do not have a good way to differentiate direct yields
565
+ if (yield.isAutoDiffSemanticResult ())
566
+ remappedResultType = yield.getSILStorageInterfaceType ();
567
+ } else if (resultIndex >= firstSemanticParamResultIdx) {
568
+ auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx;
569
+ auto semanticResultArg =
570
+ *std::next (fai.getAutoDiffSemanticResultArguments ().begin (),
571
+ semanticResultArgIdx);
572
+ remappedResultType = semanticResultArg->getType ();
573
+ } else {
574
+ remappedResultType = originalFnTy->getResults ()[resultIndex]
575
+ .getSILStorageInterfaceType ();
576
+ }
577
+
578
+ if (!remappedResultType || !remappedResultType.isDifferentiable (getModule ())) {
579
+ auto startLoc = fai.getLoc ().getStartSourceLoc ();
580
+ auto endLoc = fai.getLoc ().getEndSourceLoc ();
581
+ context.emitNondifferentiabilityError (
582
+ origCallee, invoker,
583
+ diag::autodiff_nondifferentiable_result)
584
+ .fixItInsert (startLoc, " withoutDerivative(at: " )
585
+ .fixItInsertAfter (endLoc, " )" );
586
+ return true ;
587
+ }
588
+ }
589
+
590
+ return false ;
591
+ }
592
+
530
593
void visitBeginApplyInst (BeginApplyInst *bai) {
531
594
// If callee should not be differentiated, do standard cloning.
532
595
if (!pullbackInfo.shouldDifferentiateApplySite (bai)) {
@@ -569,6 +632,13 @@ class VJPCloner::Implementation final
569
632
IndexSubset::get (getASTContext (),
570
633
bai->getSubstCalleeType ()->getNumAutoDiffSemanticResults (),
571
634
activeResultIndices));
635
+
636
+ if (diagnoseNondifferentiableOriginalFunctionType (originalFnTy,
637
+ bai, origCallee, config)) {
638
+ errorOccurred = true ;
639
+ return ;
640
+ }
641
+
572
642
// Emit the VJP.
573
643
SILValue vjpValue;
574
644
@@ -586,9 +656,13 @@ class VJPCloner::Implementation final
586
656
ParameterConvention::Direct_Guaranteed);
587
657
origCallee = vjpPartialApply;
588
658
originalFnTy = origCallee->getType ().castTo <SILFunctionType>();
659
+
589
660
// Diagnose if new original function type is non-differentiable.
590
- // if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
591
- // return;
661
+ if (diagnoseNondifferentiableOriginalFunctionType (originalFnTy,
662
+ bai, origCallee, config)) {
663
+ errorOccurred = true ;
664
+ return ;
665
+ }
592
666
}
593
667
594
668
auto *diffFuncInst =
@@ -760,60 +834,11 @@ class VJPCloner::Implementation final
760
834
}
761
835
}
762
836
763
- // Check and diagnose non-differentiable original function type.
764
- auto diagnoseNondifferentiableOriginalFunctionType =
765
- [&](CanSILFunctionType origFnTy) {
766
- // Check and diagnose non-differentiable arguments.
767
- for (auto paramIndex : config.parameterIndices ->getIndices ()) {
768
- if (!originalFnTy->getParameters ()[paramIndex]
769
- .getSILStorageInterfaceType ()
770
- .isDifferentiable (getModule ())) {
771
- auto arg = ai->getArgumentsWithoutIndirectResults ()[paramIndex];
772
- // FIXME: This shouldn't be necessary and might indicate a bug in
773
- // the transformation.
774
- RegularLocation nonAutoGenLoc (arg.getLoc ());
775
- nonAutoGenLoc.markNonAutoGenerated ();
776
- auto startLoc = nonAutoGenLoc.getStartSourceLoc ();
777
- auto endLoc = nonAutoGenLoc.getEndSourceLoc ();
778
- context
779
- .emitNondifferentiabilityError (
780
- arg, invoker, diag::autodiff_nondifferentiable_argument)
781
- .fixItInsert (startLoc, " withoutDerivative(at: " )
782
- .fixItInsertAfter (endLoc, " )" );
783
- errorOccurred = true ;
784
- return true ;
785
- }
786
- }
787
- // Check and diagnose non-differentiable results.
788
- for (auto resultIndex : config.resultIndices ->getIndices ()) {
789
- SILType remappedResultType;
790
- if (resultIndex >= originalFnTy->getNumResults ()) {
791
- auto semanticResultArgIdx = resultIndex - originalFnTy->getNumResults ();
792
- auto semanticResultArg =
793
- *std::next (ai->getAutoDiffSemanticResultArguments ().begin (),
794
- semanticResultArgIdx);
795
- remappedResultType = semanticResultArg->getType ();
796
- } else {
797
- remappedResultType = originalFnTy->getResults ()[resultIndex]
798
- .getSILStorageInterfaceType ();
799
- }
800
- if (!remappedResultType.isDifferentiable (getModule ())) {
801
- auto startLoc = ai->getLoc ().getStartSourceLoc ();
802
- auto endLoc = ai->getLoc ().getEndSourceLoc ();
803
- context
804
- .emitNondifferentiabilityError (
805
- origCallee, invoker,
806
- diag::autodiff_nondifferentiable_result)
807
- .fixItInsert (startLoc, " withoutDerivative(at: " )
808
- .fixItInsertAfter (endLoc, " )" );
809
- errorOccurred = true ;
810
- return true ;
811
- }
812
- }
813
- return false ;
814
- };
815
- if (diagnoseNondifferentiableOriginalFunctionType (originalFnTy))
837
+ if (diagnoseNondifferentiableOriginalFunctionType (originalFnTy,
838
+ ai, origCallee, config)) {
839
+ errorOccurred = true ;
816
840
return ;
841
+ }
817
842
818
843
// If VJP has not yet been found, emit an `differentiable_function`
819
844
// instruction on the remapped original function operand and
@@ -846,9 +871,13 @@ class VJPCloner::Implementation final
846
871
ParameterConvention::Direct_Guaranteed);
847
872
origCallee = vjpPartialApply;
848
873
originalFnTy = origCallee->getType ().castTo <SILFunctionType>();
874
+
849
875
// Diagnose if new original function type is non-differentiable.
850
- if (diagnoseNondifferentiableOriginalFunctionType (originalFnTy))
876
+ if (diagnoseNondifferentiableOriginalFunctionType (originalFnTy,
877
+ ai, origCallee, config)) {
878
+ errorOccurred = true ;
851
879
return ;
880
+ }
852
881
}
853
882
854
883
auto *diffFuncInst = context.createDifferentiableFunction (
0 commit comments