Skip to content

Commit 3aba19b

Browse files
committed
First cut of co-routine differentiation
1 parent debe0dd commit 3aba19b

File tree

16 files changed

+711
-175
lines changed

16 files changed

+711
-175
lines changed

include/swift/AST/IndexSubset.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ class IndexSubset : public llvm::FoldingSetNode {
108108
static IndexSubset *get(ASTContext &ctx, unsigned capacity,
109109
ArrayRef<unsigned> indices) {
110110
SmallBitVector indicesBitVec(capacity, false);
111-
for (auto index : indices)
111+
for (auto index : indices) {
112+
assert(index < capacity);
112113
indicesBitVec.set(index);
114+
}
113115
return IndexSubset::get(ctx, indicesBitVec);
114116
}
115117

include/swift/AST/Types.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -5130,8 +5130,11 @@ class SILFunctionType final
51305130
/// Returns the number of function potential semantic results:
51315131
/// * Usual results
51325132
/// * Inout parameters
5133+
/// * yields
51335134
unsigned getNumAutoDiffSemanticResults() const {
5134-
return getNumResults() + getNumAutoDiffSemanticResultsParameters();
5135+
return getNumResults() +
5136+
getNumAutoDiffSemanticResultsParameters() +
5137+
getNumIndirectFormalYields();
51355138
}
51365139

51375140
/// Get the generic signature that the component types are specified

include/swift/SIL/SILFunctionConventions.h

+8
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,14 @@ class SILFunctionConventions {
248248
idx < indirectResults + getNumIndirectSILErrorResults();
249249
}
250250

251+
unsigned getNumAutoDiffSemanticResults() const {
252+
return funcTy->getNumAutoDiffSemanticResults();
253+
}
254+
255+
unsigned getNumAutoDiffSemanticResultParameters() const {
256+
return funcTy->getNumAutoDiffSemanticResultsParameters();
257+
}
258+
251259
/// Are any SIL results passed as address-typed arguments?
252260
bool hasIndirectSILResults() const { return getNumIndirectSILResults() != 0; }
253261
bool hasIndirectSILErrorResults() const { return getNumIndirectSILErrorResults() != 0; }

include/swift/SILOptimizer/Differentiation/ADContext.h

+9-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
1818
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
1919

20+
#include "swift/SIL/ApplySite.h"
2021
#include "swift/SILOptimizer/Differentiation/Common.h"
2122
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2223

@@ -51,6 +52,12 @@ struct NestedApplyInfo {
5152
/// The original pullback type before reabstraction. `None` if the pullback
5253
/// type is not reabstracted.
5354
llvm::Optional<CanSILFunctionType> originalPullbackType;
55+
/// Index of `apply` pullback in nested pullback call
56+
unsigned pullbackIdx = -1U;
57+
/// Pullback value itself that is memoized in some cases (e.g. pullback is
58+
/// called by `begin_apply`, but should be destroyed after `end_apply`).
59+
SILValue pullback = SILValue();
60+
SILValue beginApplyToken = SILValue();
5461
};
5562

5663
/// Per-module contextual information for the Differentiation pass.
@@ -97,7 +104,7 @@ class ADContext {
97104

98105
/// Mapping from original `apply` instructions to their corresponding
99106
/// `NestedApplyInfo`s.
100-
llvm::DenseMap<ApplyInst *, NestedApplyInfo> nestedApplyInfo;
107+
llvm::DenseMap<FullApplySite, NestedApplyInfo> nestedApplyInfo;
101108

102109
/// List of generated functions (JVPs, VJPs, pullbacks, and thunks).
103110
/// Saved for deletion during cleanup.
@@ -185,7 +192,7 @@ class ADContext {
185192
invokers.insert({witness, DifferentiationInvoker(witness)});
186193
}
187194

188-
llvm::DenseMap<ApplyInst *, NestedApplyInfo> &getNestedApplyInfo() {
195+
llvm::DenseMap<FullApplySite, NestedApplyInfo> &getNestedApplyInfo() {
189196
return nestedApplyInfo;
190197
}
191198

include/swift/SILOptimizer/Differentiation/Common.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "swift/AST/DiagnosticsSIL.h"
2121
#include "swift/AST/Expr.h"
2222
#include "swift/AST/SemanticAttrs.h"
23+
#include "swift/SIL/ApplySite.h"
2324
#include "swift/SIL/SILDifferentiabilityWitness.h"
2425
#include "swift/SIL/SILFunction.h"
2526
#include "swift/SIL/Projection.h"
@@ -112,15 +113,15 @@ void collectAllDirectResultsInTypeOrder(SILFunction &function,
112113
/// Given a function call site, gathers all of its actual results (both direct
113114
/// and indirect) in an order defined by its result type.
114115
void collectAllActualResultsInTypeOrder(
115-
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
116+
FullApplySite fai, ArrayRef<SILValue> extractedDirectResults,
116117
SmallVectorImpl<SILValue> &results);
117118

118119
/// For an `apply` instruction with active results, compute:
119120
/// - The results of the `apply` instruction, in type order.
120121
/// - The set of minimal parameter and result indices for differentiating the
121122
/// `apply` instruction.
122123
void collectMinimalIndicesForFunctionCall(
123-
ApplyInst *ai, const AutoDiffConfig &parentConfig,
124+
FullApplySite fai, const AutoDiffConfig &parentConfig,
124125
const DifferentiableActivityInfo &activityInfo,
125126
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
126127
SmallVectorImpl<unsigned> &resultIndices);

include/swift/SILOptimizer/Differentiation/LinearMapInfo.h

+12-12
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,9 @@ class LinearMapInfo {
7777
/// For differentials: these are successor enums.
7878
llvm::DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;
7979

80-
/// Mapping from `apply` instructions in the original function to the
80+
/// Mapping from `apply` / `begin_apply` instructions in the original function to the
8181
/// corresponding linear map tuple type index.
82-
llvm::DenseMap<ApplyInst *, unsigned> linearMapIndexMap;
82+
llvm::DenseMap<FullApplySite, unsigned> linearMapIndexMap;
8383

8484
/// Mapping from predecessor-successor basic block pairs in the original
8585
/// function to the corresponding branching trace enum case.
@@ -112,9 +112,9 @@ class LinearMapInfo {
112112
void populateBranchingTraceDecl(SILBasicBlock *originalBB,
113113
SILLoopInfo *loopInfo);
114114

115-
/// Given an `apply` instruction, conditionally gets a linear map tuple field
116-
/// AST type for its linear map function if it is active.
117-
Type getLinearMapType(ADContext &context, ApplyInst *ai);
115+
/// Given an `apply` / `begin_apply` instruction, conditionally gets a linear
116+
/// map tuple field AST type for its linear map function if it is active.
117+
Type getLinearMapType(ADContext &context, FullApplySite fai);
118118

119119
/// Generates linear map struct and branching enum declarations for the given
120120
/// function. Linear map structs are populated with linear map fields and a
@@ -180,18 +180,18 @@ class LinearMapInfo {
180180
}
181181

182182
/// Finds the linear map index in the pullback tuple for the given
183-
/// `apply` instruction in the original function.
184-
unsigned lookUpLinearMapIndex(ApplyInst *ai) const {
185-
assert(ai->getFunction() == original);
186-
auto lookup = linearMapIndexMap.find(ai);
183+
/// `apply` / `begin_apply` instruction in the original function.
184+
unsigned lookUpLinearMapIndex(FullApplySite fas) const {
185+
assert(fas->getFunction() == original);
186+
auto lookup = linearMapIndexMap.find(fas);
187187
assert(lookup != linearMapIndexMap.end() &&
188188
"No linear map field corresponding to the given `apply`");
189189
return lookup->getSecond();
190190
}
191191

192-
Type lookUpLinearMapType(ApplyInst *ai) const {
193-
unsigned idx = lookUpLinearMapIndex(ai);
194-
return getLinearMapTupleType(ai->getParentBlock())->getElement(idx).getType();
192+
Type lookUpLinearMapType(FullApplySite fas) const {
193+
unsigned idx = lookUpLinearMapIndex(fas);
194+
return getLinearMapTupleType(fas->getParent())->getElement(idx).getType();
195195
}
196196

197197
bool hasHeapAllocatedContext() const {

include/swift/SILOptimizer/Differentiation/Thunk.h

+5
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
5656
CanSILFunctionType fromType,
5757
CanSILFunctionType toType);
5858

59+
SILValue reabstractCoroutine(
60+
SILBuilder &builder, SILOptFunctionBuilder &fb, SILLocation loc,
61+
SILValue fn, CanSILFunctionType toType,
62+
std::function<SubstitutionMap(SubstitutionMap)> remapSubstitutions);
63+
5964
/// Reabstracts the given function-typed value `fn` to the target type `toType`.
6065
/// Remaps substitutions using `remapSubstitutions`.
6166
SILValue reabstractFunction(

lib/AST/AutoDiff.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,10 @@ DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
101101
llvm_unreachable("invalid derivative kind");
102102
}
103103

104+
void AutoDiffConfig::dump() const {
105+
print(llvm::errs());
106+
}
107+
104108
void AutoDiffConfig::print(llvm::raw_ostream &s) const {
105109
s << "(parameters=";
106110
parameterIndices->print(s);

lib/SIL/IR/SILFunctionType.cpp

+84-47
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,13 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
258258

259259
numSemanticResults += getNumAutoDiffSemanticResultsParameters();
260260

261+
// Check yields.
262+
for (auto yieldAndIndex : enumerate(getYields()))
263+
if (!yieldAndIndex.value().hasOption(
264+
SILParameterInfo::NotDifferentiable))
265+
resultIndices.push_back(numSemanticResults + yieldAndIndex.index());
266+
267+
numSemanticResults += getNumIndirectFormalYields();
261268
return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices);
262269
}
263270

@@ -555,10 +562,14 @@ static CanSILFunctionType getAutoDiffDifferentialType(
555562
param.getConvention());
556563
differentialParams.push_back({paramTanType, paramConv});
557564
}
565+
558566
SmallVector<SILResultInfo, 1> differentialResults;
567+
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
568+
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
569+
originalFnTy->getNumAutoDiffSemanticResultsParameters();
559570
for (auto resultIndex : resultIndices->getIndices()) {
560571
// Handle formal original result.
561-
if (resultIndex < originalFnTy->getNumResults()) {
572+
if (resultIndex < firstSemanticParamResultIdx) {
562573
auto &result = originalResults[resultIndex];
563574
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
564575
result.getInterfaceType(), lookupConformance,
@@ -571,26 +582,38 @@ static CanSILFunctionType getAutoDiffDifferentialType(
571582
result.getConvention());
572583
differentialResults.push_back({resultTanType, resultConv});
573584
continue;
574-
}
575-
// Handle original semantic result parameters.
576-
auto resultParamIndex = resultIndex - originalFnTy->getNumResults();
577-
auto resultParamIt = std::next(
585+
} else if (resultIndex < firstYieldResultIndex) {
586+
// Handle original semantic result parameters.
587+
auto resultParamIndex = resultIndex - originalFnTy->getNumResults();
588+
auto resultParamIt = std::next(
578589
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
579590
resultParamIndex);
580-
auto paramIndex =
581-
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
582-
// If the original semantic result parameter is a differentiability
583-
// parameter, then it already has a corresponding differential
584-
// parameter. Skip adding a corresponding differential result.
585-
if (parameterIndices->contains(paramIndex))
586-
continue;
591+
auto paramIndex =
592+
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
593+
// If the original semantic result parameter is a differentiability
594+
// parameter, then it already has a corresponding differential
595+
// parameter. Skip adding a corresponding differential result.
596+
if (parameterIndices->contains(paramIndex))
597+
continue;
587598

588-
auto resultParam = originalFnTy->getParameters()[paramIndex];
589-
auto resultParamTanType = getAutoDiffTangentTypeForLinearMap(
590-
resultParam.getInterfaceType(), lookupConformance,
599+
auto resultParam = originalFnTy->getParameters()[paramIndex];
600+
auto resultParamTanType = getAutoDiffTangentTypeForLinearMap(
601+
resultParam.getInterfaceType(), lookupConformance,
591602
substGenericParams, substReplacements, ctx);
592-
differentialResults.emplace_back(resultParamTanType,
593-
ResultConvention::Indirect);
603+
differentialResults.emplace_back(resultParamTanType,
604+
ResultConvention::Indirect);
605+
} else {
606+
assert(originalFnTy->isCoroutine());
607+
assert(originalFnTy->getCoroutineKind() == SILCoroutineKind::YieldOnce);
608+
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
609+
auto yieldResult = originalFnTy->getYields()[yieldResultIndex];
610+
auto resultParamTanType = getAutoDiffTangentTypeForLinearMap(
611+
yieldResult.getInterfaceType(), lookupConformance,
612+
substGenericParams, substReplacements, ctx);
613+
ParameterConvention paramTanConvention = yieldResult.getConvention();
614+
assert(yieldResult.getConvention() == ParameterConvention::Indirect_Inout);
615+
differentialParams.emplace_back(resultParamTanType, paramTanConvention);
616+
}
594617
}
595618

596619
SubstitutionMap substitutions;
@@ -604,9 +627,9 @@ static CanSILFunctionType getAutoDiffDifferentialType(
604627
}
605628
return SILFunctionType::get(
606629
GenericSignature(), SILFunctionType::ExtInfo(), SILCoroutineKind::None,
607-
ParameterConvention::Direct_Guaranteed, differentialParams, {},
608-
differentialResults, llvm::None, substitutions,
609-
/*invocationSubstitutions*/ SubstitutionMap(), ctx);
630+
ParameterConvention::Direct_Guaranteed,
631+
differentialParams, {}, differentialResults, llvm::None,
632+
substitutions, /*invocationSubstitutions*/ SubstitutionMap(), ctx);
610633
}
611634

612635
/// Returns the pullback type for the given original function type, parameter
@@ -696,11 +719,15 @@ static CanSILFunctionType getAutoDiffPullbackType(
696719
return conv;
697720
};
698721

699-
// Collect pullback parameters.
722+
// Collect pullback parameters & yields
700723
SmallVector<SILParameterInfo, 1> pullbackParams;
724+
SmallVector<SILYieldInfo, 1> pullbackYields;
725+
unsigned firstSemanticParamResultIdx = originalFnTy->getNumResults();
726+
unsigned firstYieldResultIndex = originalFnTy->getNumResults() +
727+
originalFnTy->getNumAutoDiffSemanticResultsParameters();
701728
for (auto resultIndex : resultIndices->getIndices()) {
702729
// Handle formal original result.
703-
if (resultIndex < originalFnTy->getNumResults()) {
730+
if (resultIndex < firstSemanticParamResultIdx) {
704731
auto &origRes = originalResults[resultIndex];
705732
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
706733
origRes.getInterfaceType(), lookupConformance,
@@ -712,28 +739,38 @@ static CanSILFunctionType getAutoDiffPullbackType(
712739
->getCanonicalType(),
713740
origRes.getConvention());
714741
pullbackParams.emplace_back(resultTanType, paramConv);
715-
continue;
742+
} else if (resultIndex < firstYieldResultIndex) {
743+
// Handle original semantic result parameters.
744+
auto resultParamIndex = resultIndex - firstSemanticParamResultIdx;
745+
auto resultParamIt = std::next(
746+
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
747+
resultParamIndex);
748+
auto paramIndex =
749+
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
750+
auto resultParam = originalFnTy->getParameters()[paramIndex];
751+
// The pullback parameter convention depends on whether the original `inout`
752+
// parameter is a differentiability parameter.
753+
// - If yes, the pullback parameter convention is `@inout`.
754+
// - If no, the pullback parameter convention is `@in_guaranteed`.
755+
auto resultParamTanType = getAutoDiffTangentTypeForLinearMap(
756+
resultParam.getInterfaceType(), lookupConformance,
757+
substGenericParams, substReplacements, ctx);
758+
ParameterConvention paramTanConvention = resultParam.getConvention();
759+
if (!parameterIndices->contains(paramIndex))
760+
paramTanConvention = ParameterConvention::Indirect_In_Guaranteed;
761+
pullbackParams.emplace_back(resultParamTanType, paramTanConvention);
762+
} else {
763+
assert(originalFnTy->isCoroutine());
764+
assert(originalFnTy->getCoroutineKind() == SILCoroutineKind::YieldOnce);
765+
auto yieldResultIndex = resultIndex - firstYieldResultIndex;
766+
auto yieldResult = originalFnTy->getYields()[yieldResultIndex];
767+
auto resultParamTanType = getAutoDiffTangentTypeForLinearMap(
768+
yieldResult.getInterfaceType(), lookupConformance,
769+
substGenericParams, substReplacements, ctx);
770+
ParameterConvention paramTanConvention = yieldResult.getConvention();
771+
assert(yieldResult.getConvention() == ParameterConvention::Indirect_Inout);
772+
pullbackYields.emplace_back(resultParamTanType, paramTanConvention);
716773
}
717-
// Handle original semantic result parameters.
718-
auto resultParamIndex = resultIndex - originalFnTy->getNumResults();
719-
auto resultParamIt = std::next(
720-
originalFnTy->getAutoDiffSemanticResultsParameters().begin(),
721-
resultParamIndex);
722-
auto paramIndex =
723-
std::distance(originalFnTy->getParameters().begin(), &*resultParamIt);
724-
auto resultParam = originalFnTy->getParameters()[paramIndex];
725-
// The pullback parameter convention depends on whether the original `inout`
726-
// parameter is a differentiability parameter.
727-
// - If yes, the pullback parameter convention is `@inout`.
728-
// - If no, the pullback parameter convention is `@in_guaranteed`.
729-
auto resultParamTanType = getAutoDiffTangentTypeForLinearMap(
730-
resultParam.getInterfaceType(), lookupConformance,
731-
substGenericParams, substReplacements, ctx);
732-
ParameterConvention paramTanConvention = resultParam.getConvention();
733-
if (!parameterIndices->contains(paramIndex))
734-
paramTanConvention = ParameterConvention::Indirect_In_Guaranteed;
735-
736-
pullbackParams.emplace_back(resultParamTanType, paramTanConvention);
737774
}
738775

739776
// Collect pullback results.
@@ -756,6 +793,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
756793
param.getConvention());
757794
pullbackResults.push_back({paramTanType, resultTanConvention});
758795
}
796+
759797
SubstitutionMap substitutions;
760798
if (!substGenericParams.empty()) {
761799
auto genericSig =
@@ -768,7 +806,7 @@ static CanSILFunctionType getAutoDiffPullbackType(
768806
return SILFunctionType::get(
769807
GenericSignature(), SILFunctionType::ExtInfo(), originalFnTy->getCoroutineKind(),
770808
ParameterConvention::Direct_Guaranteed,
771-
pullbackParams, {}, pullbackResults, llvm::None, substitutions,
809+
pullbackParams, pullbackYields, pullbackResults, llvm::None, substitutions,
772810
/*invocationSubstitutions*/ SubstitutionMap(), ctx);
773811
}
774812

@@ -920,10 +958,9 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
920958
// Compute the derivative function results.
921959
SmallVector<SILResultInfo, 4> newResults;
922960
newResults.reserve(getNumResults() + 1);
923-
for (auto &result : constrainedOriginalFnTy->getResults()) {
961+
for (auto &result : constrainedOriginalFnTy->getResults())
924962
newResults.push_back(result);
925-
}
926-
newResults.push_back({closureType, ResultConvention::Owned});
963+
newResults.emplace_back(closureType, ResultConvention::Owned);
927964

928965
// Compute the derivative function ExtInfo.
929966
// If original function is `@convention(c)`, the derivative function should

0 commit comments

Comments
 (0)