Skip to content

Commit e3ff038

Browse files
committed
Enable a compiler plugin to use the async transform after patmat
Currently, the async transformation is performed during the typer phase, like all other macros. We have to levy a few artificial restrictions on whern an async boundary may be: for instance we don't support await within a pattern guard. A more natural home for the transform would be after patterns have been translated. The test case in this commit shows how to use the async transform from a custom compiler phase after patmat. The remainder of the commit updates the implementation to handle the new tree shapes. For states that correspond to a label definition, we use `-symbol.id` as the state ID. This made it easier to emit the forward jumps to when processing the label application before we had seen the label definition. I've also made the transformation more efficient in the way it checks whether a given tree encloses an `await` call: we traverse the input tree at the start of the macro, and decorate it with tree attachments containig the answer to this question. Even after the ANF and state machine transforms introduce new layers of synthetic trees, the `containsAwait` code need only traverse shallowly through those trees to find a child that has the cached answer from the original traversal. I had to special case the ANF transform for expressions that always lead to a label jump: we avoids trying to push an assignment to a result variable into `if (cond) jump1() else jump2()`, in trees of the form: ``` % cat sandbox/jump.scala class Test { def test = { (null: Any) match { case _: String => "" case _ => "" } } } % qscalac -Xprint:patmat -Xprint-types sandbox/jump.scala def test: String = { case <synthetic> val x1: Any = (null{Null(null)}: Any){Any}; case5(){ if (x1.isInstanceOf{[T0]=> Boolean}[String]{Boolean}) matchEnd4{(x: String)String}(""{String("")}){String} else case6{()String}(){String}{String} }{String}; case6(){ matchEnd4{(x: String)String}(""{String("")}){String} }{String}; matchEnd4(x: String){ x{String} }{String} }{String} ```
1 parent 93f207f commit e3ff038

11 files changed

+430
-54
lines changed

src/main/scala/scala/async/internal/AnfTransform.scala

+62-11
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,18 @@ private[async] trait AnfTransform {
1616
import c.internal._
1717
import decorators._
1818

19-
def anfTransform(tree: Tree): Block = {
19+
def anfTransform(tree: Tree, owner: Symbol): Block = {
2020
// Must prepend the () for issue #31.
21-
val block = c.typecheck(atPos(tree.pos)(Block(List(Literal(Constant(()))), tree))).setType(tree.tpe)
21+
val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe)
2222

2323
sealed abstract class AnfMode
2424
case object Anf extends AnfMode
2525
case object Linearizing extends AnfMode
2626

27+
val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner)
28+
2729
var mode: AnfMode = Anf
28-
typingTransform(block)((tree, api) => {
30+
typingTransform(tree1, owner)((tree, api) => {
2931
def blockToList(tree: Tree): List[Tree] = tree match {
3032
case Block(stats, expr) => stats :+ expr
3133
case t => t :: Nil
@@ -34,7 +36,7 @@ private[async] trait AnfTransform {
3436
def listToBlock(trees: List[Tree]): Block = trees match {
3537
case trees @ (init :+ last) =>
3638
val pos = trees.map(_.pos).reduceLeft(_ union _)
37-
Block(init, last).setType(last.tpe).setPos(pos)
39+
newBlock(init, last).setType(last.tpe).setPos(pos)
3840
}
3941

4042
object linearize {
@@ -66,6 +68,17 @@ private[async] trait AnfTransform {
6668
stats :+ valDef :+ atPos(tree.pos)(ref1)
6769

6870
case If(cond, thenp, elsep) =>
71+
// If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}`
72+
// as though it was typed with `Unit`.
73+
def isPatMatGeneratedJump(t: Tree): Boolean = t match {
74+
case Block(_, expr) => isPatMatGeneratedJump(expr)
75+
case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep)
76+
case _: Apply if isLabel(t.symbol) => true
77+
case _ => false
78+
}
79+
if (isPatMatGeneratedJump(expr)) {
80+
internal.setType(expr, definitions.UnitTpe)
81+
}
6982
// if type of if-else is Unit don't introduce assignment,
7083
// but add Unit value to bring it into form expected by async transform
7184
if (expr.tpe =:= definitions.UnitTpe) {
@@ -77,7 +90,7 @@ private[async] trait AnfTransform {
7790
def branchWithAssign(orig: Tree) = api.typecheck(atPos(orig.pos) {
7891
def cast(t: Tree) = mkAttributedCastPreservingAnnotations(t, tpe(varDef.symbol))
7992
orig match {
80-
case Block(thenStats, thenExpr) => Block(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
93+
case Block(thenStats, thenExpr) => newBlock(thenStats, Assign(Ident(varDef.symbol), cast(thenExpr)))
8194
case _ => Assign(Ident(varDef.symbol), cast(orig))
8295
}
8396
})
@@ -115,7 +128,7 @@ private[async] trait AnfTransform {
115128
}
116129
}
117130

118-
private def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
131+
def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
119132
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
120133
valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos)
121134
}
@@ -152,8 +165,7 @@ private[async] trait AnfTransform {
152165
}
153166

154167
def _transformToList(tree: Tree): List[Tree] = trace(tree) {
155-
val containsAwait = tree exists isAwait
156-
if (!containsAwait) {
168+
if (!containsAwait(tree)) {
157169
tree match {
158170
case Block(stats, expr) =>
159171
// avoids nested block in `while(await(false)) ...`.
@@ -207,10 +219,11 @@ private[async] trait AnfTransform {
207219
funStats ++ argStatss.flatten.flatten :+ typedNewApply
208220

209221
case Block(stats, expr) =>
210-
(stats :+ expr).flatMap(linearize.transformToList)
222+
val trees = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) ::: linearize.transformToList(expr)
223+
eliminateLabelParameters(trees)
211224

212225
case ValDef(mods, name, tpt, rhs) =>
213-
if (rhs exists isAwait) {
226+
if (containsAwait(rhs)) {
214227
val stats :+ expr = api.atOwner(api.currentOwner.owner)(linearize.transformToList(rhs))
215228
stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner))
216229
stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr)
@@ -247,7 +260,7 @@ private[async] trait AnfTransform {
247260
scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs)
248261

249262
case LabelDef(name, params, rhs) =>
250-
List(LabelDef(name, params, Block(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
263+
List(LabelDef(name, params, newBlock(linearize.transformToList(rhs), Literal(Constant(())))).setSymbol(tree.symbol))
251264

252265
case TypeApply(fun, targs) =>
253266
val funStats :+ simpleFun = linearize.transformToList(fun)
@@ -259,6 +272,44 @@ private[async] trait AnfTransform {
259272
}
260273
}
261274

275+
// Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable
276+
def eliminateLabelParameters(statsExpr: List[Tree]): List[Tree] = {
277+
import internal.{methodType, setInfo}
278+
val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]()
279+
280+
val matchResults = collection.mutable.Buffer[Tree]()
281+
val statsExpr0 = statsExpr.reverseMap {
282+
case ld @ LabelDef(_, param :: Nil, body) =>
283+
val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
284+
matchResults += matchResult
285+
caseDefToMatchResult(ld.symbol) = matchResult.symbol
286+
val ld2 = treeCopy.LabelDef(ld, ld.name, Nil, body.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil))
287+
setInfo(ld.symbol, methodType(Nil, ld.symbol.info.resultType))
288+
ld2
289+
case t =>
290+
if (caseDefToMatchResult.isEmpty) t
291+
else typingTransform(t)((tree, api) =>
292+
tree match {
293+
case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) =>
294+
api.typecheck(atPos(tree.pos)(newBlock(Assign(Ident(caseDefToMatchResult(fun.symbol)), api.recur(arg)) :: Nil, treeCopy.Apply(tree, fun, Nil))))
295+
case Block(stats, expr) =>
296+
api.default(tree) match {
297+
case Block(stats, Block(stats1, expr)) =>
298+
treeCopy.Block(tree, stats ::: stats1, expr)
299+
case t => t
300+
}
301+
case _ =>
302+
api.default(tree)
303+
}
304+
)
305+
}
306+
matchResults.toList match {
307+
case Nil => statsExpr0.reverse
308+
case r1 :: Nil => (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol))
309+
case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr
310+
}
311+
}
312+
262313
def anfLinearize(tree: Tree): Block = {
263314
val trees: List[Tree] = mode match {
264315
case Anf => anf._transformToList(tree)

src/main/scala/scala/async/internal/AsyncBase.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ abstract class AsyncBase {
4343
(body: c.Expr[T])
4444
(execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = {
4545
import c.universe._, c.internal._, decorators._
46-
val asyncMacro = AsyncMacro(c, self)
46+
val asyncMacro = AsyncMacro(c, self)(body.tree)
4747

48-
val code = asyncMacro.asyncTransform[T](body.tree, execContext.tree)(c.weakTypeTag[T])
48+
val code = asyncMacro.asyncTransform[T](execContext.tree)(c.weakTypeTag[T])
4949
AsyncUtils.vprintln(s"async state machine transform expands to:\n ${code}")
5050

5151
// Mark range positions for synthetic code as transparent to allow some wiggle room for overlapping ranges

src/main/scala/scala/async/internal/AsyncId.scala

+5-5
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ object AsyncTestLV extends AsyncBase {
4141
* A trivial implementation of [[FutureSystem]] that performs computations
4242
* on the current thread. Useful for testing.
4343
*/
44+
class Box[A] {
45+
var a: A = _
46+
}
4447
object IdentityFutureSystem extends FutureSystem {
45-
46-
class Prom[A] {
47-
var a: A = _
48-
}
48+
type Prom[A] = Box[A]
4949

5050
type Fut[A] = A
5151
type ExecContext = Unit
@@ -57,7 +57,7 @@ object IdentityFutureSystem extends FutureSystem {
5757

5858
def execContext: Expr[ExecContext] = c.Expr[Unit](Literal(Constant(())))
5959

60-
def promType[A: WeakTypeTag]: Type = weakTypeOf[Prom[A]]
60+
def promType[A: WeakTypeTag]: Type = weakTypeOf[Box[A]]
6161
def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]]
6262
def execContextType: Type = weakTypeOf[Unit]
6363

src/main/scala/scala/async/internal/AsyncMacro.scala

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
package scala.async.internal
22

33
object AsyncMacro {
4-
def apply(c0: reflect.macros.Context, base: AsyncBase): AsyncMacro { val c: c0.type } = {
4+
def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = {
55
import language.reflectiveCalls
66
new AsyncMacro { self =>
77
val c: c0.type = c0
8+
val body: c.Tree = body0
89
// This member is required by `AsyncTransform`:
910
val asyncBase: AsyncBase = base
1011
// These members are required by `ExprBuilder`:
1112
val futureSystem: FutureSystem = base.futureSystem
1213
val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c)
14+
val containsAwait: c.Tree => Boolean = containsAwaitCached(body0)
1315
}
1416
}
1517
}
@@ -19,7 +21,10 @@ private[async] trait AsyncMacro
1921
with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables {
2022

2123
val c: scala.reflect.macros.Context
24+
val body: c.Tree
25+
val containsAwait: c.Tree => Boolean
2226

2327
lazy val macroPos = c.macroApplication.pos.makeTransparent
2428
def atMacroPos(t: c.Tree) = c.universe.atPos(macroPos)(t)
29+
2530
}

src/main/scala/scala/async/internal/AsyncTransform.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ trait AsyncTransform {
99

1010
val asyncBase: AsyncBase
1111

12-
def asyncTransform[T](body: Tree, execContext: Tree)
12+
def asyncTransform[T](execContext: Tree)
1313
(resultType: WeakTypeTag[T]): Tree = {
1414

1515
// We annotate the type of the whole expression as `T @uncheckedBounds` so as not to introduce
@@ -22,7 +22,7 @@ trait AsyncTransform {
2222
// Transform to A-normal form:
2323
// - no await calls in qualifiers or arguments,
2424
// - if/match only used in statement position.
25-
val anfTree0: Block = anfTransform(body)
25+
val anfTree0: Block = anfTransform(body, c.internal.enclosingOwner)
2626

2727
val anfTree = futureSystemOps.postAnfTransform(anfTree0)
2828

@@ -35,15 +35,15 @@ trait AsyncTransform {
3535
val stateMachine: ClassDef = {
3636
val body: List[Tree] = {
3737
val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial)))
38-
val result = ValDef(NoMods, name.result, TypeTree(futureSystemOps.promType[T](uncheckedBoundsResultTag)), futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
38+
val resultAndAccessors = mkMutableField(futureSystemOps.promType[T](uncheckedBoundsResultTag), name.result, futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree)
3939
val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext)
4040

4141
val apply0DefDef: DefDef = {
4242
// We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`.
4343
// See SI-1247 for the the optimization that avoids creation.
4444
DefDef(NoMods, name.apply, Nil, Nil, TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil))
4545
}
46-
List(emptyConstructor, stateVar, result, execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
46+
List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef)
4747
}
4848

4949
val tryToUnit = appliedType(definitions.FunctionClass(1), futureSystemOps.tryType[Any], typeOf[Unit])

0 commit comments

Comments
 (0)