Skip to content

Commit 9efc651

Browse files
committed
Diagnose non-differentiable yields
1 parent 8f7c339 commit 9efc651

File tree

9 files changed

+191
-119
lines changed

9 files changed

+191
-119
lines changed

include/swift/AST/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5134,7 +5134,7 @@ class SILFunctionType final
51345134
unsigned getNumAutoDiffSemanticResults() const {
51355135
return getNumResults() +
51365136
getNumAutoDiffSemanticResultsParameters() +
5137-
getNumIndirectFormalYields();
5137+
getNumYields();
51385138
}
51395139

51405140
/// Get the generic signature that the component types are specified

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
264264
SILParameterInfo::NotDifferentiable))
265265
resultIndices.push_back(numSemanticResults + yieldAndIndex.index());
266266

267-
numSemanticResults += getNumIndirectFormalYields();
267+
numSemanticResults += getNumYields();
268268
return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices);
269269
}
270270

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,14 +244,11 @@ void collectMinimalIndicesForFunctionCall(
244244
resultIndices.push_back(semanticResultParamResultIndex++);
245245
}
246246

247-
// Record all indirect yields
247+
// Record all yields. While we do not have a way to represent direct yields
248+
// (_read accessors) we run activity analysis for them. These will be
249+
// diagnosed later.
248250
if (BeginApplyInst *bai = dyn_cast<BeginApplyInst>(*ai)) {
249251
for (const auto &yieldAndIdx : enumerate(calleeConvs.getYields())) {
250-
auto &yield = yieldAndIdx.value();
251-
// We do not have a way to represent non @inout yields
252-
if (!yield.isAutoDiffSemanticResult())
253-
continue;
254-
255252
results.push_back(bai->getYieldedValues()[yieldAndIdx.index()]);
256253
resultIndices.push_back(semanticResultParamResultIndex++);
257254
}

lib/SILOptimizer/Differentiation/LinearMapInfo.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,11 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, FullApplySite fai) {
242242
SILType remappedResultType;
243243
if (resultIndex >= firstYieldResultIndex) {
244244
auto yieldResultIdx = resultIndex - firstYieldResultIndex;
245-
remappedResultType =
246-
origFnTy->getYields()[yieldResultIdx].getSILStorageInterfaceType();
245+
const auto& yield = origFnTy->getYields()[yieldResultIdx];
246+
// We do not have a good way to differentiate direct yields
247+
if (!yield.isAutoDiffSemanticResult())
248+
return true;
249+
remappedResultType = yield.getSILStorageInterfaceType();
247250
} else if (resultIndex >= firstSemanticParamResultIdx) {
248251
auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx;
249252
auto semanticResultArg =

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1160,7 +1160,7 @@ class PullbackCloner::Implementation final
11601160
BeginApplyInst *bai = aai->getBeginApply();
11611161
assert(getPullbackInfo().shouldDifferentiateApplySite(bai));
11621162

1163-
// Coroutine differentiation is not yet supported.
1163+
// abort_apply differentiation is not yet supported.
11641164
getContext().emitNondifferentiabilityError(
11651165
bai, getInvoker(), diag::autodiff_coroutines_not_supported);
11661166
errorOccurred = true;

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 85 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,69 @@ class VJPCloner::Implementation final
527527
getOpValue(origCallee)->getDefiningInstruction());
528528
}
529529

530+
// Check and diagnose non-differentiable original function type.
531+
bool diagnoseNondifferentiableOriginalFunctionType(CanSILFunctionType originalFnTy,
532+
FullApplySite fai, SILValue origCallee,
533+
const AutoDiffConfig &config) const {
534+
// Check and diagnose non-differentiable arguments.
535+
for (auto paramIndex : config.parameterIndices->getIndices()) {
536+
if (!originalFnTy->getParameters()[paramIndex]
537+
.getSILStorageInterfaceType()
538+
.isDifferentiable(getModule())) {
539+
auto arg = fai.getArgumentsWithoutIndirectResults()[paramIndex];
540+
// FIXME: This shouldn't be necessary and might indicate a bug in
541+
// the transformation.
542+
RegularLocation nonAutoGenLoc(arg.getLoc());
543+
nonAutoGenLoc.markNonAutoGenerated();
544+
auto startLoc = nonAutoGenLoc.getStartSourceLoc();
545+
auto endLoc = nonAutoGenLoc.getEndSourceLoc();
546+
context.emitNondifferentiabilityError(
547+
arg, invoker, diag::autodiff_nondifferentiable_argument)
548+
.fixItInsert(startLoc, "withoutDerivative(at: ")
549+
.fixItInsertAfter(endLoc, ")");
550+
return true;
551+
}
552+
}
553+
554+
// Check and diagnose non-differentiable results.
555+
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
556+
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
557+
originalFnTy->getNumAutoDiffSemanticResultsParameters();
558+
559+
for (auto resultIndex : config.resultIndices->getIndices()) {
560+
SILType remappedResultType;
561+
if (resultIndex >= firstYieldResultIndex) {
562+
auto yieldResultIdx = resultIndex - firstYieldResultIndex;
563+
const auto& yield = originalFnTy->getYields()[yieldResultIdx];
564+
// We do not have a good way to differentiate direct yields
565+
if (yield.isAutoDiffSemanticResult())
566+
remappedResultType = yield.getSILStorageInterfaceType();
567+
} else if (resultIndex >= firstSemanticParamResultIdx) {
568+
auto semanticResultArgIdx = resultIndex - firstSemanticParamResultIdx;
569+
auto semanticResultArg =
570+
*std::next(fai.getAutoDiffSemanticResultArguments().begin(),
571+
semanticResultArgIdx);
572+
remappedResultType = semanticResultArg->getType();
573+
} else {
574+
remappedResultType = originalFnTy->getResults()[resultIndex]
575+
.getSILStorageInterfaceType();
576+
}
577+
578+
if (!remappedResultType || !remappedResultType.isDifferentiable(getModule())) {
579+
auto startLoc = fai.getLoc().getStartSourceLoc();
580+
auto endLoc = fai.getLoc().getEndSourceLoc();
581+
context.emitNondifferentiabilityError(
582+
origCallee, invoker,
583+
diag::autodiff_nondifferentiable_result)
584+
.fixItInsert(startLoc, "withoutDerivative(at: ")
585+
.fixItInsertAfter(endLoc, ")");
586+
return true;
587+
}
588+
}
589+
590+
return false;
591+
}
592+
530593
void visitBeginApplyInst(BeginApplyInst *bai) {
531594
// If callee should not be differentiated, do standard cloning.
532595
if (!pullbackInfo.shouldDifferentiateApplySite(bai)) {
@@ -569,6 +632,13 @@ class VJPCloner::Implementation final
569632
IndexSubset::get(getASTContext(),
570633
bai->getSubstCalleeType()->getNumAutoDiffSemanticResults(),
571634
activeResultIndices));
635+
636+
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy,
637+
bai, origCallee, config)) {
638+
errorOccurred = true;
639+
return;
640+
}
641+
572642
// Emit the VJP.
573643
SILValue vjpValue;
574644

@@ -586,9 +656,13 @@ class VJPCloner::Implementation final
586656
ParameterConvention::Direct_Guaranteed);
587657
origCallee = vjpPartialApply;
588658
originalFnTy = origCallee->getType().castTo<SILFunctionType>();
659+
589660
// Diagnose if new original function type is non-differentiable.
590-
//if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
591-
// return;
661+
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy,
662+
bai, origCallee, config)) {
663+
errorOccurred = true;
664+
return;
665+
}
592666
}
593667

594668
auto *diffFuncInst =
@@ -760,60 +834,11 @@ class VJPCloner::Implementation final
760834
}
761835
}
762836

763-
// Check and diagnose non-differentiable original function type.
764-
auto diagnoseNondifferentiableOriginalFunctionType =
765-
[&](CanSILFunctionType origFnTy) {
766-
// Check and diagnose non-differentiable arguments.
767-
for (auto paramIndex : config.parameterIndices->getIndices()) {
768-
if (!originalFnTy->getParameters()[paramIndex]
769-
.getSILStorageInterfaceType()
770-
.isDifferentiable(getModule())) {
771-
auto arg = ai->getArgumentsWithoutIndirectResults()[paramIndex];
772-
// FIXME: This shouldn't be necessary and might indicate a bug in
773-
// the transformation.
774-
RegularLocation nonAutoGenLoc(arg.getLoc());
775-
nonAutoGenLoc.markNonAutoGenerated();
776-
auto startLoc = nonAutoGenLoc.getStartSourceLoc();
777-
auto endLoc = nonAutoGenLoc.getEndSourceLoc();
778-
context
779-
.emitNondifferentiabilityError(
780-
arg, invoker, diag::autodiff_nondifferentiable_argument)
781-
.fixItInsert(startLoc, "withoutDerivative(at: ")
782-
.fixItInsertAfter(endLoc, ")");
783-
errorOccurred = true;
784-
return true;
785-
}
786-
}
787-
// Check and diagnose non-differentiable results.
788-
for (auto resultIndex : config.resultIndices->getIndices()) {
789-
SILType remappedResultType;
790-
if (resultIndex >= originalFnTy->getNumResults()) {
791-
auto semanticResultArgIdx = resultIndex - originalFnTy->getNumResults();
792-
auto semanticResultArg =
793-
*std::next(ai->getAutoDiffSemanticResultArguments().begin(),
794-
semanticResultArgIdx);
795-
remappedResultType = semanticResultArg->getType();
796-
} else {
797-
remappedResultType = originalFnTy->getResults()[resultIndex]
798-
.getSILStorageInterfaceType();
799-
}
800-
if (!remappedResultType.isDifferentiable(getModule())) {
801-
auto startLoc = ai->getLoc().getStartSourceLoc();
802-
auto endLoc = ai->getLoc().getEndSourceLoc();
803-
context
804-
.emitNondifferentiabilityError(
805-
origCallee, invoker,
806-
diag::autodiff_nondifferentiable_result)
807-
.fixItInsert(startLoc, "withoutDerivative(at: ")
808-
.fixItInsertAfter(endLoc, ")");
809-
errorOccurred = true;
810-
return true;
811-
}
812-
}
813-
return false;
814-
};
815-
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
837+
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy,
838+
ai, origCallee, config)) {
839+
errorOccurred = true;
816840
return;
841+
}
817842

818843
// If VJP has not yet been found, emit an `differentiable_function`
819844
// instruction on the remapped original function operand and
@@ -846,9 +871,13 @@ class VJPCloner::Implementation final
846871
ParameterConvention::Direct_Guaranteed);
847872
origCallee = vjpPartialApply;
848873
originalFnTy = origCallee->getType().castTo<SILFunctionType>();
874+
849875
// Diagnose if new original function type is non-differentiable.
850-
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
876+
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy,
877+
ai, origCallee, config)) {
878+
errorOccurred = true;
851879
return;
880+
}
852881
}
853882

854883
auto *diffFuncInst = context.createDifferentiableFunction(

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,16 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context,
156156
isa<CheckedCastBranchInst>(term) ||
157157
isa<CheckedCastAddrBranchInst>(term) || isa<TryApplyInst>(term))
158158
continue;
159-
160-
if (isa<YieldInst>(term)) {
161-
// FIXME: diagnose proper yields
159+
160+
// We can differentiate only indirect yields
161+
if (auto *yi = dyn_cast<YieldInst>(term)) {
162+
for (const auto &val : yi->getAllOperands())
163+
if (!yi->getYieldInfoForOperand(val).isAutoDiffSemanticResult()) {
164+
context.emitNondifferentiabilityError(
165+
term, invoker, diag::autodiff_control_flow_not_supported);
166+
return true;
167+
}
168+
162169
continue;
163170
}
164171

@@ -554,8 +561,10 @@ emitDerivativeFunctionReference(
554561
SILType resultType;
555562
if (resultIndex >= firstYieldResultIndex) {
556563
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
557-
resultType = originalFnTy->getYields()[yieldResultIndex]
558-
.getSILStorageInterfaceType();
564+
auto yield = originalFnTy->getYields()[yieldResultIndex];
565+
// We can only differentiate indirect yields
566+
if (yield.isAutoDiffSemanticResult())
567+
resultType = yield.getSILStorageInterfaceType();
559568
} else if (resultIndex >= firstSemanticParamResultIdx) {
560569
auto semanticResultParamIdx = resultIndex - firstSemanticParamResultIdx;
561570
auto semanticResultParam =
@@ -566,7 +575,7 @@ emitDerivativeFunctionReference(
566575
resultType = originalFnTy->getResults()[resultIndex]
567576
.getSILStorageInterfaceType();
568577
}
569-
if (!resultType.isDifferentiable(context.getModule())) {
578+
if (!resultType || !resultType.isDifferentiable(context.getModule())) {
570579
context.emitNondifferentiabilityError(
571580
original, invoker, diag::autodiff_nondifferentiable_result);
572581
return llvm::None;

0 commit comments

Comments
 (0)