From 4d33d8ae99b7c19765d51dbbd1cc0cc09e679e89 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Fri, 20 Mar 2020 13:18:15 +1000 Subject: [PATCH] Use compiler integrated async phase under -Xasync --- .travis.yml | 4 +- README.md | 126 ++-- build.sbt | 2 +- pending/run/fallback0/MinimalScalaTest.scala | 0 src/main/scala/scala/async/Async.scala | 35 +- .../scala/async/FutureStateMachine.scala | 80 +++ .../scala/async/internal/AnfTransform.scala | 424 ------------ .../scala/async/internal/AsyncAnalysis.scala | 110 --- .../scala/async/internal/AsyncBase.scala | 78 --- .../scala/scala/async/internal/AsyncId.scala | 107 --- .../scala/async/internal/AsyncMacro.scala | 51 -- .../scala/async/internal/AsyncNames.scala | 121 ---- .../scala/async/internal/AsyncTransform.scala | 257 ------- .../scala/async/internal/AsyncUtils.scala | 26 - .../scala/async/internal/ExprBuilder.scala | 650 ------------------ .../scala/async/internal/FutureSystem.scala | 156 ----- .../scala/scala/async/internal/Lifter.scala | 179 ----- .../scala/async/internal/LiveVariables.scala | 313 --------- .../async/internal/ScalaConcurrentAsync.scala | 29 - .../scala/async/internal/StateAssigner.scala | 23 - .../scala/scala/async/internal/StateSet.scala | 38 - .../scala/async/internal/TransformUtils.scala | 590 ---------------- src/test/scala/scala/async/FutureSpec.scala | 541 +++++++++++++++ src/test/scala/scala/async/SmokeTest.scala | 32 + src/test/scala/scala/async/TestLatch.scala | 48 -- src/test/scala/scala/async/TestUtil.scala | 66 ++ .../scala/scala/async/TreeInterrogation.scala | 112 --- .../scala/async/neg/LocalClasses0Spec.scala | 42 -- .../scala/scala/async/neg/NakedAwait.scala | 183 ----- .../scala/scala/async/neg/SampleNegSpec.scala | 27 - src/test/scala/scala/async/package.scala | 90 --- .../async/run/SyncOptimizationSpec.scala | 40 -- .../scala/scala/async/run/WarningsSpec.scala | 105 --- .../async/run/anf/AnfTransformSpec.scala | 459 ------------- .../scala/async/run/await0/Await0Spec.scala | 81 --- .../scala/async/run/block0/AsyncSpec.scala | 65 -- .../scala/scala/async/run/block1/block1.scala | 49 -- .../async/run/exceptions/ExceptionsSpec.scala | 66 -- .../scala/async/run/futures/FutureSpec.scala | 560 --------------- .../scala/async/run/hygiene/Hygiene.scala | 92 --- .../scala/async/run/ifelse0/IfElse0.scala | 64 -- .../scala/async/run/ifelse0/WhileSpec.scala | 127 ---- .../scala/async/run/ifelse1/IfElse1.scala | 212 ------ .../scala/async/run/ifelse2/ifelse2.scala | 55 -- .../scala/async/run/ifelse3/IfElse3.scala | 58 -- .../scala/async/run/ifelse4/IfElse4.scala | 71 -- .../scala/async/run/late/LateExpansion.scala | 612 ----------------- .../scala/async/run/lazyval/LazyValSpec.scala | 37 - .../async/run/live/LiveVariablesSpec.scala | 299 -------- .../scala/scala/async/run/match0/Match0.scala | 154 ----- .../scala/async/run/nesteddef/NestedDef.scala | 106 --- .../scala/async/run/noawait/NoAwaitSpec.scala | 44 -- .../run/stackoverflow/StackOverflowSpec.scala | 36 - .../scala/async/run/toughtype/ToughType.scala | 362 ---------- .../uncheckedBounds/UncheckedBoundsSpec.scala | 47 -- 55 files changed, 818 insertions(+), 7523 deletions(-) delete mode 100644 pending/run/fallback0/MinimalScalaTest.scala create mode 100644 src/main/scala/scala/async/FutureStateMachine.scala delete mode 100644 src/main/scala/scala/async/internal/AnfTransform.scala delete mode 100644 src/main/scala/scala/async/internal/AsyncAnalysis.scala delete mode 100644 src/main/scala/scala/async/internal/AsyncBase.scala delete mode 100644 src/main/scala/scala/async/internal/AsyncId.scala delete mode 100644 src/main/scala/scala/async/internal/AsyncMacro.scala delete mode 100644 src/main/scala/scala/async/internal/AsyncNames.scala delete mode 100644 src/main/scala/scala/async/internal/AsyncTransform.scala delete mode 100644 src/main/scala/scala/async/internal/AsyncUtils.scala delete mode 100644 src/main/scala/scala/async/internal/ExprBuilder.scala delete mode 100644 src/main/scala/scala/async/internal/FutureSystem.scala delete mode 100644 src/main/scala/scala/async/internal/Lifter.scala delete mode 100644 src/main/scala/scala/async/internal/LiveVariables.scala delete mode 100644 src/main/scala/scala/async/internal/ScalaConcurrentAsync.scala delete mode 100644 src/main/scala/scala/async/internal/StateAssigner.scala delete mode 100644 src/main/scala/scala/async/internal/StateSet.scala delete mode 100644 src/main/scala/scala/async/internal/TransformUtils.scala create mode 100644 src/test/scala/scala/async/FutureSpec.scala create mode 100644 src/test/scala/scala/async/SmokeTest.scala delete mode 100644 src/test/scala/scala/async/TestLatch.scala create mode 100644 src/test/scala/scala/async/TestUtil.scala delete mode 100644 src/test/scala/scala/async/TreeInterrogation.scala delete mode 100644 src/test/scala/scala/async/neg/LocalClasses0Spec.scala delete mode 100644 src/test/scala/scala/async/neg/NakedAwait.scala delete mode 100644 src/test/scala/scala/async/neg/SampleNegSpec.scala delete mode 100644 src/test/scala/scala/async/package.scala delete mode 100644 src/test/scala/scala/async/run/SyncOptimizationSpec.scala delete mode 100644 src/test/scala/scala/async/run/WarningsSpec.scala delete mode 100644 src/test/scala/scala/async/run/anf/AnfTransformSpec.scala delete mode 100644 src/test/scala/scala/async/run/await0/Await0Spec.scala delete mode 100644 src/test/scala/scala/async/run/block0/AsyncSpec.scala delete mode 100644 src/test/scala/scala/async/run/block1/block1.scala delete mode 100644 src/test/scala/scala/async/run/exceptions/ExceptionsSpec.scala delete mode 100644 src/test/scala/scala/async/run/futures/FutureSpec.scala delete mode 100644 src/test/scala/scala/async/run/hygiene/Hygiene.scala delete mode 100644 src/test/scala/scala/async/run/ifelse0/IfElse0.scala delete mode 100644 src/test/scala/scala/async/run/ifelse0/WhileSpec.scala delete mode 100644 src/test/scala/scala/async/run/ifelse1/IfElse1.scala delete mode 100644 src/test/scala/scala/async/run/ifelse2/ifelse2.scala delete mode 100644 src/test/scala/scala/async/run/ifelse3/IfElse3.scala delete mode 100644 src/test/scala/scala/async/run/ifelse4/IfElse4.scala delete mode 100644 src/test/scala/scala/async/run/late/LateExpansion.scala delete mode 100644 src/test/scala/scala/async/run/lazyval/LazyValSpec.scala delete mode 100644 src/test/scala/scala/async/run/live/LiveVariablesSpec.scala delete mode 100644 src/test/scala/scala/async/run/match0/Match0.scala delete mode 100644 src/test/scala/scala/async/run/nesteddef/NestedDef.scala delete mode 100644 src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala delete mode 100644 src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala delete mode 100644 src/test/scala/scala/async/run/toughtype/ToughType.scala delete mode 100644 src/test/scala/scala/async/run/uncheckedBounds/UncheckedBoundsSpec.scala diff --git a/.travis.yml b/.travis.yml index 2e0d04fc..79a052a4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,8 +5,8 @@ import: scala/scala-dev:travis/default.yml language: scala scala: - - 2.12.11 - - 2.13.2 + - 2.12.12 + - 2.13.3 env: - ADOPTOPENJDK=8 diff --git a/README.md b/README.md index d5f08434..afffc79c 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,20 @@ # scala-async [![Build Status](https://travis-ci.org/scala/scala-async.svg?branch=master)](https://travis-ci.org/scala/scala-async) [](http://search.maven.org/#search%7Cga%7C1%7Cg%3Aorg.scala-lang.modules%20a%3Ascala-async_2.12) [](http://search.maven.org/#search%7Cga%7C1%7Cg%3Aorg.scala-lang.modules%20a%3Ascala-async_2.13) -## Supported Scala versions - -This branch (version series 0.10.x) targets Scala 2.12 and 2.13. `scala-async` is no longer maintained for older versions. +A DSL to enable a direct style of programming with when composing values wrapped in Scala `Future`s. ## Quick start To include scala-async in an existing project use the library published on Maven Central. For sbt projects add the following to your build definition - build.sbt or project/Build.scala: +### Use a modern Scala compiler + +As of scala-async 1.0, Scala 2.12.12+ or 2.13.3+ are required. + +### Add dependency + +#### SBT Example + ```scala libraryDependencies += "org.scala-lang.modules" %% "scala-async" % "0.10.0" libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value % Provided @@ -17,28 +23,58 @@ libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value % For Maven projects add the following to your (make sure to use the correct Scala version suffix to match your project’s Scala binary version): +#### Maven Example + ```scala - org.scala-lang.modules - scala-async_2.12 - 0.10.0 + org.scala-lang.modules + scala-async_2.13 + 1.0.0 - org.scala-lang - scala-reflect - 2.12.11 - provided + org.scala-lang + scala-reflect + 2.13.3 + provided ``` -After adding scala-async to your classpath, write your first `async` block: +### Enable compiler support for `async` + +Add the `-Xasync` to the Scala compiler options. + +#### SBT Example +```scala +scalaOptions += "-Xasync" +``` + +#### Maven Example + +```xml + + ... + + net.alchim31.maven + scala-maven-plugin + 4.4.0 + + + -Xasync + + + + ... + +``` + +### Start coding ```scala import scala.concurrent.ExecutionContext.Implicits.global import scala.async.Async.{async, await} val future = async { - val f1 = async { ...; true } + val f1: Future[Boolean] = async { ...; true } val f2 = async { ...; 42 } if (await(f1)) await(f2) else 0 } @@ -93,6 +129,22 @@ def combined: Future[Int] = async { } ``` +## Limitations + +### `await` must be directly in the control flow of the async expression + +The `await` cannot be nested under a local method, object, class or lambda: + +``` +async { + List(1).foreach { x => await(f(x) } // invali +} +``` + +### `await` must be not be nested within `try` / `catch` / `finally`. + +This implementation restriction may be lifted in future versions. + ## Comparison with direct use of `Future` API This computation could also be expressed by directly using the @@ -119,53 +171,3 @@ The `async` approach has two advantages over the use of required at each generator (`<-`) in the for-comprehension. This reduces the size of generated code, and can avoid boxing of intermediate results. - -## Comparison with CPS plugin - -The existing continuations (CPS) plugin for Scala can also be used -to provide a syntactic layer like `async`. This approach has been -used in Akka's [Dataflow Concurrency](http://doc.akka.io/docs/akka/2.3-M1/scala/dataflow.html) -(now deprecated in favour of this library). - -CPS-based rewriting of asynchronous code also produces a closure -for each suspension. It can also lead to type errors that are -difficult to understand. - -## How it works - - - The `async` macro analyses the block of code, looking for control - structures and locations of `await` calls. It then breaks the code - into 'chunks'. Each chunk contains a linear sequence of statements - that concludes with a branching decision, or with the registration - of a subsequent state handler as the continuation. - - Before this analysis and transformation, the program is normalized - into a form amenable to this manipulation. This is called the - "A Normal Form" (ANF), and roughly means that: - - `if` and `match` constructs are only used as statements; - they cannot be used as an expression. - - calls to `await` are not allowed in compound expressions. - - Identify vals, vars and defs that are accessed from multiple - states. These will be lifted out to fields in the state machine - object. - - Synthesize a class that holds: - - an integer representing the current state ID. - - the lifted definitions. - - an `apply(value: Try[Any]): Unit` method that will be - called on completion of each future. The behavior of - this method is determined by the current state. It records - the downcast result of the future in a field, and calls the - `resume()` method. - - the `resume(): Unit` method that switches on the current state - and runs the users code for one 'chunk', and either: - a) registers the state machine as the handler for the next future - b) completes the result Promise of the `async` block, if at the terminal state. - - an `apply(): Unit` method that starts the computation. - -## Limitations - - - See the [neg](https://github.com/scala/async/tree/master/src/test/scala/scala/async/neg) test cases - for constructs that are not allowed in an `async` block. - - See the [issue list](https://github.com/scala/async/issues?state=open) for which of these restrictions are planned - to be dropped in the future. - - See [#32](https://github.com/scala/async/issues/32) for why `await` is not possible in closures, and for suggestions on - ways to structure the code to work around this limitation. diff --git a/build.sbt b/build.sbt index 026c76fe..2745bdd4 100644 --- a/build.sbt +++ b/build.sbt @@ -4,13 +4,13 @@ ScalaModulePlugin.scalaModuleOsgiSettings name := "scala-async" libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided" -libraryDependencies += "org.scala-lang" % "scala-compiler" % scalaVersion.value % "test" // for ToolBox libraryDependencies += "junit" % "junit" % "4.12" % "test" libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test" ScalaModulePlugin.enableOptimizer testOptions += Tests.Argument(TestFrameworks.JUnit, "-q", "-v", "-s") scalacOptions in Test ++= Seq("-Yrangepos") +scalacOptions ++= List("-deprecation" , "-Xasync") parallelExecution in Global := false diff --git a/pending/run/fallback0/MinimalScalaTest.scala b/pending/run/fallback0/MinimalScalaTest.scala deleted file mode 100644 index e69de29b..00000000 diff --git a/src/main/scala/scala/async/Async.scala b/src/main/scala/scala/async/Async.scala index e99891be..b4399e15 100644 --- a/src/main/scala/scala/async/Async.scala +++ b/src/main/scala/scala/async/Async.scala @@ -13,8 +13,9 @@ package scala.async import scala.language.experimental.macros -import scala.concurrent.{Future, ExecutionContext} +import scala.concurrent.{ExecutionContext, Future} import scala.annotation.compileTimeOnly +import scala.reflect.macros.whitebox /** * Async blocks provide a direct means to work with [[scala.concurrent.Future]]. @@ -50,7 +51,7 @@ object Async { * Run the block of code `body` asynchronously. `body` may contain calls to `await` when the results of * a `Future` are needed; this is translated into non-blocking code. */ - def async[T](body: => T)(implicit execContext: ExecutionContext): Future[T] = macro internal.ScalaConcurrentAsync.asyncImpl[T] + def async[T](body: => T)(implicit execContext: ExecutionContext): Future[T] = macro asyncImpl[T] /** * Non-blocking await the on result of `awaitable`. This may only be used directly within an enclosing `async` block. @@ -58,6 +59,34 @@ object Async { * Internally, this will register the remainder of the code in enclosing `async` block as a callback * in the `onComplete` handler of `awaitable`, and will *not* block a thread. */ - @compileTimeOnly("`await` must be enclosed in an `async` block") + @compileTimeOnly("[async] `await` must be enclosed in an `async` block") def await[T](awaitable: Future[T]): T = ??? // No implementation here, as calls to this are translated to `onComplete` by the macro. + + def asyncImpl[T: c.WeakTypeTag](c: whitebox.Context) + (body: c.Tree) + (execContext: c.Tree): c.Tree = { + import c.universe._ + if (!c.compilerSettings.contains("-Xasync")) { + c.abort(c.macroApplication.pos, "The async requires the compiler option -Xasync (supported only by Scala 2.12.12+ / 2.13.3+)") + } else try { + val awaitSym = typeOf[Async.type].decl(TermName("await")) + def mark(t: DefDef): Tree = { + import language.reflectiveCalls + c.internal.asInstanceOf[{ + def markForAsyncTransform(owner: Symbol, method: DefDef, awaitSymbol: Symbol, config: Map[String, AnyRef]): DefDef + }].markForAsyncTransform(c.internal.enclosingOwner, t, awaitSym, Map.empty) + } + val name = TypeName("stateMachine$async") + q""" + final class $name extends _root_.scala.async.FutureStateMachine(${execContext}) { + // FSM translated method + ${mark(q"""override def apply(tr$$async: _root_.scala.util.Try[_root_.scala.AnyRef]) = ${body}""")} + } + new $name().start() : ${c.macroApplication.tpe} + """ + } catch { + case e: ReflectiveOperationException => + c.abort(c.macroApplication.pos, "-Xasync is provided as a Scala compiler option, but the async macro is unable to call c.internal.markForAsyncTransform. " + e.getClass.getName + " " + e.getMessage) + } + } } diff --git a/src/main/scala/scala/async/FutureStateMachine.scala b/src/main/scala/scala/async/FutureStateMachine.scala new file mode 100644 index 00000000..48d2692b --- /dev/null +++ b/src/main/scala/scala/async/FutureStateMachine.scala @@ -0,0 +1,80 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ +package scala.async + +import java.util.Objects + +import scala.util.{Failure, Success, Try} +import scala.concurrent.{ExecutionContext, Future, Promise} + +/** The base class for state machines generated by the `scala.async.Async.async` macro. + * Not intended to be directly extended in user-written code. + */ +abstract class FutureStateMachine(execContext: ExecutionContext) extends Function1[Try[AnyRef], Unit] { + Objects.requireNonNull(execContext) + + type F = scala.concurrent.Future[AnyRef] + type R = scala.util.Try[AnyRef] + + private[this] val result$async: Promise[AnyRef] = Promise[AnyRef](); + private[this] var state$async: Int = 0 + + /** Retrieve the current value of the state variable */ + protected def state: Int = state$async + + /** Assign `i` to the state variable */ + protected def state_=(s: Int): Unit = state$async = s + + /** Complete the state machine with the given failure. */ + // scala-async accidentally started catching NonFatal exceptions in: + // https://github.com/scala/scala-async/commit/e3ff0382ae4e015fc69da8335450718951714982#diff-136ab0b6ecaee5d240cd109e2b17ccb2R411 + // This follows the new behaviour but should we fix the regression? + protected def completeFailure(t: Throwable): Unit = { + result$async.complete(Failure(t)) + } + + /** Complete the state machine with the given value. */ + protected def completeSuccess(value: AnyRef): Unit = { + result$async.complete(Success(value)) + } + + /** Register the state machine as a completion callback of the given future. */ + protected def onComplete(f: F): Unit = { + f.onComplete(this)(execContext) + } + + /** Extract the result of the given future if it is complete, or `null` if it is incomplete. */ + protected def getCompleted(f: F): Try[AnyRef] = { + if (f.isCompleted) { + f.value.get + } else null + } + + /** + * Extract the success value of the given future. If the state machine detects a failure it may + * complete the async block and return `this` as a sentinel value to indicate that the caller + * (the state machine dispatch loop) should immediately exit. + */ + protected def tryGet(tr: R): AnyRef = tr match { + case Success(value) => + value.asInstanceOf[AnyRef] + case Failure(throwable) => + completeFailure(throwable) + this // sentinel value to indicate the dispatch loop should exit. + } + + def start[T](): Future[T] = { + // This cast is safe because we know that `def apply` does not consult its argument when `state == 0`. + Future.unit.asInstanceOf[Future[AnyRef]].onComplete(this)(execContext) + result$async.future.asInstanceOf[Future[T]] + } +} diff --git a/src/main/scala/scala/async/internal/AnfTransform.scala b/src/main/scala/scala/async/internal/AnfTransform.scala deleted file mode 100644 index 86b347fb..00000000 --- a/src/main/scala/scala/async/internal/AnfTransform.scala +++ /dev/null @@ -1,424 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import scala.Predef._ -import scala.reflect.internal.util.Collections.map2 - -private[async] trait AnfTransform { - self: AsyncMacro => - - import c.universe._ - import Flag._ - import c.internal._ - import decorators._ - - def anfTransform(tree: Tree, owner: Symbol): Block = { - // Must prepend the () for issue #31. - val block = c.typecheck(atPos(tree.pos)(newBlock(List(Literal(Constant(()))), tree))).setType(tree.tpe) - - sealed abstract class AnfMode - case object Anf extends AnfMode - case object Linearizing extends AnfMode - - val tree1 = adjustTypeOfTranslatedPatternMatches(block, owner) - - var mode: AnfMode = Anf - - object trace { - private var indent = -1 - - private def indentString = " " * indent - - def apply[T](args: Any)(t: => T): T = { - def prefix = mode.toString.toLowerCase - indent += 1 - def oneLine(s: Any) = s.toString.replaceAll("""\n""", "\\\\n").take(127) - try { - if(AsyncUtils.trace) - AsyncUtils.trace(s"$indentString$prefix(${oneLine(args)})") - val result = t - if(AsyncUtils.trace) - AsyncUtils.trace(s"$indentString= ${oneLine(result)}") - result - } finally { - indent -= 1 - } - } - } - - typingTransform(tree1, owner)((tree, api) => { - def blockToList(tree: Tree): List[Tree] = tree match { - case Block(stats, expr) => stats :+ expr - case t => t :: Nil - } - - def listToBlock(trees: List[Tree]): Block = trees match { - case trees @ (init :+ last) => - val pos = trees.map(_.pos).reduceLeft{ - (p, q) => - if (!q.isRange) p - else if (p.isRange) p.withStart(p.start.min(q.start)).withEnd(p.end.max(q.end)) - else q - } - newBlock(init, last).setType(last.tpe).setPos(pos) - } - - object linearize { - def transformToList(tree: Tree): List[Tree] = { - mode = Linearizing; blockToList(api.recur(tree)) - } - - def transformToBlock(tree: Tree): Block = listToBlock(transformToList(tree)) - - def _transformToList(tree: Tree): List[Tree] = trace(tree) { - val stats :+ expr = _anf.transformToList(tree) - def statsExprUnit = - stats :+ expr :+ api.typecheck(atPos(expr.pos)(Literal(Constant(())))) - def statsExprThrow = - stats :+ expr :+ api.typecheck(atPos(expr.pos)(Throw(Apply(Select(New(gen.mkAttributedRef(defn.IllegalStateExceptionClass)), termNames.CONSTRUCTOR), Nil)))) - expr match { - case Apply(fun, args) if isAwait(fun) => - val valDef = defineVal(name.await(), expr, tree.pos) - val ref = gen.mkAttributedStableRef(valDef.symbol).setType(tree.tpe) - val ref1 = if (ref.tpe =:= definitions.UnitTpe) - // https://github.com/scala/async/issues/74 - // Use a cast to hide from "pure expression does nothing" error - // - // TODO avoid creating a ValDef for the result of this await to avoid this tree shape altogether. - // This will require some deeper changes to the later parts of the macro which currently assume regular - // tree structure around `await` calls. - api.typecheck(atPos(tree.pos)(gen.mkCast(ref, definitions.UnitTpe))) - else ref - stats :+ valDef :+ atPos(tree.pos)(ref1) - - case If(cond, thenp, elsep) => - // If we run the ANF transform post patmat, deal with trees like `(if (cond) jump1(){String} else jump2(){String}){String}` - // as though it was typed with `Unit`. - def isPatMatGeneratedJump(t: Tree): Boolean = t match { - case Block(_, expr) => isPatMatGeneratedJump(expr) - case If(_, thenp, elsep) => isPatMatGeneratedJump(thenp) && isPatMatGeneratedJump(elsep) - case _: Apply if isLabel(t.symbol) => true - case _ => false - } - if (isPatMatGeneratedJump(expr)) { - internal.setType(expr, definitions.UnitTpe) - } - // if type of if-else is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - statsExprUnit - } else if (expr.tpe =:= definitions.NothingTpe) { - statsExprThrow - } else { - val varDef = defineVar(name.ifRes(), expr.tpe, tree.pos) - def typedAssign(lhs: Tree) = - api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol))))) - - def branchWithAssign(t: Tree): Tree = { - t match { - case MatchEnd(ld) => - deriveLabelDef(ld, branchWithAssign) - case blk @ Block(thenStats, thenExpr) => - treeCopy.Block(blk, thenStats, branchWithAssign(thenExpr)).setType(definitions.UnitTpe) - case _ => - typedAssign(t) - } - } - val ifWithAssign = treeCopy.If(tree, cond, branchWithAssign(thenp), branchWithAssign(elsep)).setType(definitions.UnitTpe) - stats :+ varDef :+ ifWithAssign :+ atPos(tree.pos)(gen.mkAttributedStableRef(varDef.symbol)).setType(tree.tpe) - } - case ld @ LabelDef(name, params, rhs) => - if (ld.symbol.info.resultType.typeSymbol == definitions.UnitClass) - statsExprUnit - else - stats :+ expr - - case Match(scrut, cases) => - // if type of match is Unit don't introduce assignment, - // but add Unit value to bring it into form expected by async transform - if (expr.tpe =:= definitions.UnitTpe) { - statsExprUnit - } else if (expr.tpe =:= definitions.NothingTpe) { - statsExprThrow - } else { - val varDef = defineVar(name.matchRes(), expr.tpe, tree.pos) - def typedAssign(lhs: Tree) = - api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol))))) - val casesWithAssign = cases map { - case cd@CaseDef(pat, guard, body) => - def bodyWithAssign(t: Tree): Tree = { - t match { - case MatchEnd(ld) => deriveLabelDef(ld, bodyWithAssign) - case b@Block(caseStats, caseExpr) => treeCopy.Block(b, caseStats, bodyWithAssign(caseExpr)).setType(definitions.UnitTpe) - case _ => typedAssign(t) - } - } - treeCopy.CaseDef(cd, pat, guard, bodyWithAssign(body)).setType(definitions.UnitTpe) - } - val matchWithAssign = treeCopy.Match(tree, scrut, casesWithAssign).setType(definitions.UnitTpe) - require(matchWithAssign.tpe != null, matchWithAssign) - stats :+ varDef :+ matchWithAssign :+ atPos(tree.pos)(gen.mkAttributedStableRef(varDef.symbol)).setType(tree.tpe) - } - case _ => - stats :+ expr - } - } - - def defineVar(name: TermName, tp: Type, pos: Position): ValDef = { - val sym = api.currentOwner.newTermSymbol(name, pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp)) - valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos) - } - } - - def defineVal(name: TermName, lhs: Tree, pos: Position): ValDef = { - val sym = api.currentOwner.newTermSymbol(name, pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe)) - internal.valDef(sym, internal.changeOwner(lhs, api.currentOwner, sym)).setType(NoType).setPos(pos) - } - - object _anf { - def transformToList(tree: Tree): List[Tree] = { - mode = Anf; blockToList(api.recur(tree)) - } - - def _transformToList(tree: Tree): List[Tree] = trace(tree) { - if (!containsAwait(tree)) { - tree match { - case Block(stats, expr) => - // avoids nested block in `while(await(false)) ...`. - // TODO I think `containsAwait` really should return true if the code contains a label jump to an enclosing - // while/doWhile and there is an await *anywhere* inside that construct. - stats :+ expr - case _ => List(tree) - } - } else tree match { - case Select(qual, sel) => - val stats :+ expr = linearize.transformToList(qual) - stats :+ treeCopy.Select(tree, expr, sel) - - case Throw(expr) => - val stats :+ expr1 = linearize.transformToList(expr) - stats :+ treeCopy.Throw(tree, expr1) - - case Typed(expr, tpt) => - val stats :+ expr1 = linearize.transformToList(expr) - stats :+ treeCopy.Typed(tree, expr1, tpt) - - case Applied(fun, targs, argss) if argss.nonEmpty => - // we can assume that no await call appears in a by-name argument position, - // this has already been checked. - val funStats :+ simpleFun = linearize.transformToList(fun) - val (argStatss, argExprss): (List[List[List[Tree]]], List[List[Tree]]) = - mapArgumentss[List[Tree]](fun, argss) { - case Arg(expr, byName, _) if byName /*|| isPure(expr) TODO */ => (Nil, expr) - case Arg(expr, _, argName) => - linearize.transformToList(expr) match { - case stats :+ expr1 => - val valDef = defineVal(name.freshen(argName), expr1, expr1.pos) - require(valDef.tpe != null, valDef) - val stats1 = stats :+ valDef - (stats1, atPos(tree.pos.makeTransparent)(gen.stabilize(gen.mkAttributedIdent(valDef.symbol)))) - } - } - - def copyApplied(tree: Tree, depth: Int): Tree = { - tree match { - case TypeApply(_, targs) => treeCopy.TypeApply(tree, simpleFun, targs) - case _ if depth == 0 => simpleFun - case Apply(fun, args) => - val newTypedArgs = map2(args.map(_.pos), argExprss(depth - 1))((pos, arg) => api.typecheck(atPos(pos)(arg))) - treeCopy.Apply(tree, copyApplied(fun, depth - 1), newTypedArgs) - } - } - - val typedNewApply = copyApplied(tree, argss.length) - - funStats ++ argStatss.flatten.flatten :+ typedNewApply - - case Block(stats, expr) => - val stats1 = stats.flatMap(linearize.transformToList).filterNot(isLiteralUnit) - val exprs1 = linearize.transformToList(expr) - val trees = stats1 ::: exprs1 - def groupsEndingWith[T](ts: List[T])(f: T => Boolean): List[List[T]] = if (ts.isEmpty) Nil else { - ts.indexWhere(f) match { - case -1 => List(ts) - case i => - val (ts1, ts2) = ts.splitAt(i + 1) - ts1 :: groupsEndingWith(ts2)(f) - } - } - val matchGroups = groupsEndingWith(trees){ case MatchEnd(_) => true; case _ => false } - val trees1 = matchGroups.flatMap(eliminateMatchEndLabelParameter) - val result = trees1 flatMap { - case Block(stats, expr) => stats :+ expr - case t => t :: Nil - } - result - - case ValDef(mods, name, tpt, rhs) => - if (containsAwait(rhs)) { - val stats :+ expr = linearize.transformToList(rhs) - stats.foreach(_.changeOwner(api.currentOwner, api.currentOwner.owner)) - stats :+ treeCopy.ValDef(tree, mods, name, tpt, expr) - } else List(tree) - - case Assign(lhs, rhs) => - val stats :+ expr = linearize.transformToList(rhs) - stats :+ treeCopy.Assign(tree, lhs, expr) - - case If(cond, thenp, elsep) => - val condStats :+ condExpr = linearize.transformToList(cond) - val thenBlock = linearize.transformToBlock(thenp) - val elseBlock = linearize.transformToBlock(elsep) - condStats :+ treeCopy.If(tree, condExpr, thenBlock, elseBlock) - - case Match(scrut, cases) => - val scrutStats :+ scrutExpr = linearize.transformToList(scrut) - val caseDefs = cases map { - case CaseDef(pat, guard, body) => - // extract local variables for all names bound in `pat`, and rewrite `body` - // to refer to these. - // TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`. - val block = linearize.transformToBlock(body) - val (valDefs, mappings) = (pat collect { - case b@Bind(bindName, _) => - val vd = defineVal(name.freshen(bindName.toTermName), gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos) - vd.symbol.updateAttachment(SyntheticBindVal) - (vd, (b.symbol, vd.symbol)) - }).unzip - val (from, to) = mappings.unzip - val b@Block(stats1, expr1) = block.substituteSymbols(from, to).asInstanceOf[Block] - val newBlock = treeCopy.Block(b, valDefs ++ stats1, expr1) - treeCopy.CaseDef(tree, pat, guard, newBlock) - } - scrutStats :+ treeCopy.Match(tree, scrutExpr, caseDefs) - - case LabelDef(name, params, rhs) => - if (tree.symbol.info.typeSymbol == definitions.UnitClass) - List(treeCopy.LabelDef(tree, name, params, api.typecheck(newBlock(linearize.transformToList(rhs), Literal(Constant(()))))).setSymbol(tree.symbol)) - else - List(treeCopy.LabelDef(tree, name, params, api.typecheck(listToBlock(linearize.transformToList(rhs)))).setSymbol(tree.symbol)) - - case TypeApply(fun, targs) => - val funStats :+ simpleFun = linearize.transformToList(fun) - funStats :+ treeCopy.TypeApply(tree, simpleFun, targs) - - case _ => - List(tree) - } - } - } - - // Replace the label parameters on `matchEnd` with use of a `matchRes` temporary variable - // - // CaseDefs are translated to labels without parameters. A terminal label, `matchEnd`, accepts - // a parameter which is the result of the match (this is regular, so even Unit-typed matches have this). - // - // For our purposes, it is easier to: - // - extract a `matchRes` variable - // - rewrite the terminal label def to take no parameters, and instead read this temp variable - // - change jumps to the terminal label to an assignment and a no-arg label application - def eliminateMatchEndLabelParameter(statsExpr: List[Tree]): List[Tree] = { - import internal.{methodType, setInfo} - val caseDefToMatchResult = collection.mutable.Map[Symbol, Symbol]() - - val matchResults = collection.mutable.Buffer[Tree]() - def modifyLabelDef(ld: LabelDef): (Tree, Tree) = { - val param = ld.params.head - val ld2 = if (ld.params.head.tpe.typeSymbol == definitions.UnitClass) { - // Unit typed match: eliminate the label def parameter, but don't create a matchres temp variable to - // store the result for cleaner generated code. - caseDefToMatchResult(ld.symbol) = NoSymbol - val rhs2 = substituteTrees(ld.rhs, param.symbol :: Nil, api.typecheck(literalUnit) :: Nil) - (treeCopy.LabelDef(ld, ld.name, Nil, api.typecheck(literalUnit)), rhs2) - } else { - // Otherwise, create the matchres var. We'll callers of the label def below. - // Remember: we're iterating through the statement sequence in reverse, so we'll get - // to the LabelDef and mutate `matchResults` before we'll get to its callers. - val matchResult = linearize.defineVar(name.matchRes(), param.tpe, ld.pos) - matchResults += matchResult - caseDefToMatchResult(ld.symbol) = matchResult.symbol - val rhs2 = ld.rhs.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil) - (treeCopy.LabelDef(ld, ld.name, Nil, api.typecheck(literalUnit)), rhs2) - } - setInfo(ld.symbol, methodType(Nil, definitions.UnitTpe)) - ld2 - } - val statsExpr0 = statsExpr.reverse.flatMap { - case ld @ LabelDef(_, param :: Nil, _) => - val (ld1, after) = modifyLabelDef(ld) - List(after, ld1) - case a @ ValDef(mods, name, tpt, ld @ LabelDef(_, param :: Nil, _)) => - val (ld1, after) = modifyLabelDef(ld) - List(treeCopy.ValDef(a, mods, name, tpt, after), ld1) - case t => - if (caseDefToMatchResult.isEmpty) t :: Nil - else typingTransform(t)((tree, api) => { - def typedPos(pos: Position)(t: Tree): Tree = - api.typecheck(atPos(pos)(t)) - tree match { - case Apply(fun, arg :: Nil) if isLabel(fun.symbol) && caseDefToMatchResult.contains(fun.symbol) => - val temp = caseDefToMatchResult(fun.symbol) - if (temp == NoSymbol) - typedPos(tree.pos)(newBlock(api.recur(arg) :: Nil, treeCopy.Apply(tree, fun, Nil))) - else - // setType needed for LateExpansion.shadowingRefinedType test case. There seems to be an inconsistency - // in the trees after pattern matcher. - // TODO miminize the problem in patmat and fix in scalac. - typedPos(tree.pos)(newBlock(Assign(Ident(temp), api.recur(internal.setType(arg, fun.tpe.paramLists.head.head.info))) :: Nil, treeCopy.Apply(tree, fun, Nil))) - case Block(stats, expr: Apply) if isLabel(expr.symbol) => - api.default(tree) match { - case Block(stats0, Block(stats1, expr1)) => - // flatten the block returned by `case Apply` above into the enclosing block for - // cleaner generated code. - treeCopy.Block(tree, stats0 ::: stats1, expr1) - case t => t - } - case _ => - api.default(tree) - } - }) :: Nil - } - matchResults.toList match { - case _ if caseDefToMatchResult.isEmpty => - statsExpr // return the original trees if nothing changed - case Nil => - statsExpr0.reverse :+ literalUnit // must have been a unit-typed match, no matchRes variable to definne or refer to - case r1 :: Nil => - // { var matchRes = _; ....; matchRes } - (r1 +: statsExpr0.reverse) :+ atPos(tree.pos)(gen.mkAttributedIdent(r1.symbol)) - case _ => c.error(macroPos, "Internal error: unexpected tree encountered during ANF transform " + statsExpr); statsExpr - } - } - - def anfLinearize(tree: Tree): Block = { - val trees: List[Tree] = mode match { - case Anf => _anf._transformToList(tree) - case Linearizing => linearize._transformToList(tree) - } - listToBlock(trees) - } - - tree match { - case _: ValDef | _: DefDef | _: Function | _: ClassDef | _: TypeDef => - api.atOwner(tree.symbol)(anfLinearize(tree)) - case _: ModuleDef => - api.atOwner(tree.symbol.asModule.moduleClass orElse tree.symbol)(anfLinearize(tree)) - case _ => - anfLinearize(tree) - } - }).asInstanceOf[Block] - } -} - -object SyntheticBindVal diff --git a/src/main/scala/scala/async/internal/AsyncAnalysis.scala b/src/main/scala/scala/async/internal/AsyncAnalysis.scala deleted file mode 100644 index cb5a09fa..00000000 --- a/src/main/scala/scala/async/internal/AsyncAnalysis.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import scala.collection.mutable.ListBuffer - -trait AsyncAnalysis { - self: AsyncMacro => - - import c.universe._ - - /** - * Analyze the contents of an `async` block in order to: - * - Report unsupported `await` calls under nested templates, functions, by-name arguments. - * - * Must be called on the original tree, not on the ANF transformed tree. - */ - def reportUnsupportedAwaits(tree: Tree): Unit = { - val analyzer = new UnsupportedAwaitAnalyzer - analyzer.traverse(tree) - // analyzer.hasUnsupportedAwaits // XB: not used?! - } - - private class UnsupportedAwaitAnalyzer extends AsyncTraverser { - var hasUnsupportedAwaits = false - - override def nestedClass(classDef: ClassDef): Unit = { - val kind = if (classDef.symbol.asClass.isTrait) "trait" else "class" - reportUnsupportedAwait(classDef, s"nested $kind") - } - - override def nestedModule(module: ModuleDef): Unit = { - reportUnsupportedAwait(module, "nested object") - } - - override def nestedMethod(defDef: DefDef): Unit = { - reportUnsupportedAwait(defDef, "nested method") - } - - override def byNameArgument(arg: Tree): Unit = { - reportUnsupportedAwait(arg, "by-name argument") - } - - override def function(function: Function): Unit = { - reportUnsupportedAwait(function, "nested function") - } - - override def patMatFunction(tree: Match): Unit = { - reportUnsupportedAwait(tree, "nested function") - } - - override def traverse(tree: Tree): Unit = { - tree match { - case Try(_, _, _) if containsAwait(tree) => - reportUnsupportedAwait(tree, "try/catch") - super.traverse(tree) - case Return(_) => - c.abort(tree.pos, "return is illegal within a async block") - case DefDef(mods, _, _, _, _, _) if mods.hasFlag(Flag.LAZY) && containsAwait(tree) => - reportUnsupportedAwait(tree, "lazy val initializer") - case ValDef(mods, _, _, _) if mods.hasFlag(Flag.LAZY) && containsAwait(tree) => - reportUnsupportedAwait(tree, "lazy val initializer") - case CaseDef(_, guard, _) if guard exists isAwait => - // TODO lift this restriction - reportUnsupportedAwait(tree, "pattern guard") - case _ => - super.traverse(tree) - } - } - - /** - * @return true, if the tree contained an unsupported await. - */ - private def reportUnsupportedAwait(tree: Tree, whyUnsupported: String): Boolean = { - val badAwaits = ListBuffer[Tree]() - object traverser extends Traverser { - override def traverse(tree: Tree): Unit = { - if (!isAsync(tree)) - super.traverse(tree) - tree match { - case rt: RefTree if isAwait(rt) => - badAwaits += rt - case _ => - } - } - } - traverser(tree) - badAwaits foreach { - tree => - reportError(tree.pos, s"await must not be used under a $whyUnsupported.") - } - badAwaits.nonEmpty - } - - private def reportError(pos: Position, msg: String): Unit = { - hasUnsupportedAwaits = true - c.abort(pos, msg) - } - } -} diff --git a/src/main/scala/scala/async/internal/AsyncBase.scala b/src/main/scala/scala/async/internal/AsyncBase.scala deleted file mode 100644 index b7de62b5..00000000 --- a/src/main/scala/scala/async/internal/AsyncBase.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import scala.annotation.compileTimeOnly -import scala.reflect.macros.whitebox -import scala.reflect.api.Universe - -/** - * A base class for the `async` macro. Subclasses must provide: - * - * - Concrete types for a given future system - * - Tree manipulations to create and complete the equivalent of Future and Promise - * in that system. - * - The `async` macro declaration itself, and a forwarder for the macro implementation. - * (The latter is temporarily needed to workaround bug SI-6650 in the macro system) - * - * The default implementation, [[scala.async.Async]], binds the macro to `scala.concurrent._`. - */ -abstract class AsyncBase { - self => - - type FS <: FutureSystem - val futureSystem: FS - - /** - * A call to `await` must be nested in an enclosing `async` block. - * - * A call to `await` does not block the current thread, rather it is a delimiter - * used by the enclosing `async` macro. Code following the `await` - * call is executed asynchronously, when the argument of `await` has been completed. - * - * @param awaitable the future from which a value is awaited. - * @tparam T the type of that value. - * @return the value. - */ - @compileTimeOnly("`await` must be enclosed in an `async` block") - def await[T](awaitable: futureSystem.Fut[T]): T = ??? - - def asyncImpl[T: c.WeakTypeTag](c: whitebox.Context) - (body: c.Expr[T]) - (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[futureSystem.Fut[T]] = { - import c.internal._, decorators._ - val asyncMacro = AsyncMacro(c, self)(body.tree) - - val code = asyncMacro.asyncTransform[T](execContext.tree)(c.weakTypeTag[T]) - AsyncUtils.vprintln(s"async state machine transform expands to:\n $code") - - // Mark range positions for synthetic code as transparent to allow some wiggle room for overlapping ranges - for (t <- code) t.setPos(t.pos.makeTransparent) - c.Expr[futureSystem.Fut[T]](code) - } - - protected[async] def asyncMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = { - import u._ - if (asyncMacroSymbol == null) NoSymbol - else asyncMacroSymbol.owner.typeSignature.member(TermName("async")) - } - - protected[async] def awaitMethod(u: Universe)(asyncMacroSymbol: u.Symbol): u.Symbol = { - import u._ - if (asyncMacroSymbol == null) NoSymbol - else asyncMacroSymbol.owner.typeSignature.member(TermName("await")) - } - - protected[async] def nullOut(u: Universe)(name: u.Expr[String], v: u.Expr[Any]): u.Expr[Unit] = - u.reify { () } -} diff --git a/src/main/scala/scala/async/internal/AsyncId.scala b/src/main/scala/scala/async/internal/AsyncId.scala deleted file mode 100644 index aee3360f..00000000 --- a/src/main/scala/scala/async/internal/AsyncId.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import language.experimental.macros -import scala.reflect.macros.whitebox -import scala.reflect.api.Universe - -object AsyncId extends AsyncBase { - lazy val futureSystem = IdentityFutureSystem - type FS = IdentityFutureSystem.type - - def async[T](body: => T): T = macro asyncIdImpl[T] - - def asyncIdImpl[T: c.WeakTypeTag](c: whitebox.Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) -} - -object AsyncTestLV extends AsyncBase { - lazy val futureSystem = IdentityFutureSystem - type FS = IdentityFutureSystem.type - - def async[T](body: T): T = macro asyncIdImpl[T] - - def asyncIdImpl[T: c.WeakTypeTag](c: whitebox.Context)(body: c.Expr[T]): c.Expr[T] = asyncImpl[T](c)(body)(c.literalUnit) - - var log: List[(String, Any)] = Nil - def assertNulledOut(a: Any): Unit = assert(log.exists(_._2 == a), AsyncTestLV.log) - def assertNotNulledOut(a: Any): Unit = assert(!log.exists(_._2 == a), AsyncTestLV.log) - def clear(): Unit = log = Nil - - def apply(name: String, v: Any): Unit = - log ::= (name -> v) - - protected[async] override def nullOut(u: Universe)(name: u.Expr[String], v: u.Expr[Any]): u.Expr[Unit] = - u.reify { scala.async.internal.AsyncTestLV(name.splice, v.splice) } -} - -/** - * A trivial implementation of [[FutureSystem]] that performs computations - * on the current thread. Useful for testing. - */ -class Box[A] { - var a: A = _ -} -object IdentityFutureSystem extends FutureSystem { - type Prom[A] = Box[A] - - type Fut[A] = A - type ExecContext = Unit - type Tryy[A] = scala.util.Try[A] - - def mkOps(c0: whitebox.Context): Ops {val c: c0.type} = new Ops { - val c: c0.type = c0 - import c.universe._ - - def execContext: Expr[ExecContext] = c.Expr[Unit](Literal(Constant(()))) - - def promType[A: WeakTypeTag]: Type = weakTypeOf[Box[A]] - def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]] - def execContextType: Type = weakTypeOf[Unit] - - def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { - new Prom[A]() - } - - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { - prom.splice.a - } - - def future[A: WeakTypeTag](t: Expr[A])(execContext: Expr[ExecContext]) = t - - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[Tryy[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] = reify { - fun.splice.apply(util.Success(future.splice)) - c.Expr[Unit](Literal(Constant(()))).splice - } - - def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit] = reify { - prom.splice.a = value.splice.get - c.Expr[Unit](Literal(Constant(()))).splice - } - - def tryyIsFailure[A](tryy: Expr[Tryy[A]]): Expr[Boolean] = reify { - tryy.splice.isFailure - } - - def tryyGet[A](tryy: Expr[Tryy[A]]): Expr[A] = reify { - tryy.splice.get - } - def tryySuccess[A: WeakTypeTag](a: Expr[A]): Expr[Tryy[A]] = reify { - scala.util.Success[A](a.splice) - } - def tryyFailure[A: WeakTypeTag](a: Expr[Throwable]): Expr[Tryy[A]] = reify { - scala.util.Failure[A](a.splice) - } - } -} diff --git a/src/main/scala/scala/async/internal/AsyncMacro.scala b/src/main/scala/scala/async/internal/AsyncMacro.scala deleted file mode 100644 index 16150c6f..00000000 --- a/src/main/scala/scala/async/internal/AsyncMacro.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -object AsyncMacro { - def apply(c0: reflect.macros.whitebox.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = { - // Use an attachment on RootClass as a sneaky place for a per-Global cache - val att = c0.internal.attachments(c0.universe.rootMirror.RootClass) - val names = att.get[AsyncNames[_]].getOrElse { - val names = new AsyncNames[c0.universe.type](c0.universe) - att.update(names) - names - } - - new AsyncMacro { self => - val c: c0.type = c0 - val asyncNames: AsyncNames[c.universe.type] = names.asInstanceOf[AsyncNames[c.universe.type]] - val body: c.Tree = body0 - // This member is required by `AsyncTransform`: - val asyncBase: AsyncBase = base - // These members are required by `ExprBuilder`: - val futureSystem: FutureSystem = base.futureSystem - val futureSystemOps: futureSystem.Ops {val c: self.c.type} = futureSystem.mkOps(c) - var containsAwait: c.Tree => Boolean = containsAwaitCached(body0) - } - } -} - -private[async] trait AsyncMacro - extends AnfTransform with TransformUtils with Lifter - with ExprBuilder with AsyncTransform with AsyncAnalysis with LiveVariables { - - val c: scala.reflect.macros.whitebox.Context - val body: c.Tree - var containsAwait: c.Tree => Boolean - val asyncNames: AsyncNames[c.universe.type] - - lazy val macroPos: c.universe.Position = c.macroApplication.pos.makeTransparent - def atMacroPos(t: c.Tree): c.Tree = c.universe.atPos(macroPos)(t) - -} diff --git a/src/main/scala/scala/async/internal/AsyncNames.scala b/src/main/scala/scala/async/internal/AsyncNames.scala deleted file mode 100644 index 1828aa55..00000000 --- a/src/main/scala/scala/async/internal/AsyncNames.scala +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.reflect.api.Names - -/** - * A per-global cache of names needed by the Async macro. - */ -final class AsyncNames[U <: Names with Singleton](val u: U) { - self => - import u._ - - abstract class NameCache[N <: U#Name](base: String) { - val cached = new ArrayBuffer[N]() - protected def newName(s: String): N - def apply(i: Int): N = { - if (cached.isDefinedAt(i)) cached(i) - else { - assert(cached.length == i) - val name = newName(freshenString(base, i)) - cached += name - name - } - } - } - - final class TermNameCache(base: String) extends NameCache[U#TermName](base) { - override protected def newName(s: String): U#TermName = TermName(s) - } - final class TypeNameCache(base: String) extends NameCache[U#TypeName](base) { - override protected def newName(s: String): U#TypeName = TypeName(s) - } - private val matchRes: TermNameCache = new TermNameCache("match") - private val ifRes: TermNameCache = new TermNameCache("if") - private val await: TermNameCache = new TermNameCache("await") - - private val result = TermName("result$async") - private val completed: TermName = TermName("completed$async") - private val apply = TermName("apply") - private val stateMachine = TermName("stateMachine$async") - private val stateMachineT = stateMachine.toTypeName - private val state: u.TermName = TermName("state$async") - private val execContext = TermName("execContext$async") - private val tr: u.TermName = TermName("tr$async") - private val t: u.TermName = TermName("throwable$async") - - final class NameSource[N <: U#Name](cache: NameCache[N]) { - private val count = new AtomicInteger(0) - def apply(): N = cache(count.getAndIncrement()) - } - - class AsyncName { - final val matchRes = new NameSource[U#TermName](self.matchRes) - final val ifRes = new NameSource[U#TermName](self.matchRes) - final val await = new NameSource[U#TermName](self.await) - final val completed = self.completed - final val result = self.result - final val apply = self.apply - final val stateMachine = self.stateMachine - final val stateMachineT = self.stateMachineT - final val state: u.TermName = self.state - final val execContext = self.execContext - final val tr: u.TermName = self.tr - final val t: u.TermName = self.t - - private val seenPrefixes = mutable.AnyRefMap[Name, AtomicInteger]() - private val freshened = mutable.HashSet[Name]() - - final def freshenIfNeeded(name: TermName): TermName = { - seenPrefixes.getOrNull(name) match { - case null => - seenPrefixes.put(name, new AtomicInteger()) - name - case counter => - freshen(name, counter) - } - } - final def freshenIfNeeded(name: TypeName): TypeName = { - seenPrefixes.getOrNull(name) match { - case null => - seenPrefixes.put(name, new AtomicInteger()) - name - case counter => - freshen(name, counter) - } - } - final def freshen(name: TermName): TermName = { - val counter = seenPrefixes.getOrElseUpdate(name, new AtomicInteger()) - freshen(name, counter) - } - final def freshen(name: TypeName): TypeName = { - val counter = seenPrefixes.getOrElseUpdate(name, new AtomicInteger()) - freshen(name, counter) - } - private def freshen(name: TermName, counter: AtomicInteger): TermName = { - if (freshened.contains(name)) name - else TermName(freshenString(name.toString, counter.incrementAndGet())) - } - private def freshen(name: TypeName, counter: AtomicInteger): TypeName = { - if (freshened.contains(name)) name - else TypeName(freshenString(name.toString, counter.incrementAndGet())) - } - } - - private def freshenString(name: String, counter: Int): String = name.toString + "$async$" + counter -} diff --git a/src/main/scala/scala/async/internal/AsyncTransform.scala b/src/main/scala/scala/async/internal/AsyncTransform.scala deleted file mode 100644 index f60135bd..00000000 --- a/src/main/scala/scala/async/internal/AsyncTransform.scala +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -trait AsyncTransform { - self: AsyncMacro => - - import c.universe._ - import c.internal._ - import decorators._ - - val asyncBase: AsyncBase - - def asyncTransform[T](execContext: Tree) - (resultType: WeakTypeTag[T]): Tree = { - - // We annotate the type of the whole expression as `T @uncheckedBounds` so as not to introduce - // warnings about non-conformant LUBs. See SI-7694 - // This implicit propagates the annotated type in the type tag. - implicit val uncheckedBoundsResultTag: WeakTypeTag[T] = c.WeakTypeTag[T](uncheckedBounds(resultType.tpe)) - - reportUnsupportedAwaits(body) - - // Transform to A-normal form: - // - no await calls in qualifiers or arguments, - // - if/match only used in statement position. - val anfTree0: Block = anfTransform(body, c.internal.enclosingOwner) - - val anfTree = futureSystemOps.postAnfTransform(anfTree0) - - cleanupContainsAwaitAttachments(anfTree) - containsAwait = containsAwaitCached(anfTree) - - val applyDefDefDummyBody: DefDef = { - val applyVParamss = List(List(ValDef(Modifiers(Flag.PARAM), name.tr, TypeTree(futureSystemOps.tryType[Any]), EmptyTree))) - DefDef(NoMods, name.apply, Nil, applyVParamss, TypeTree(definitions.UnitTpe), literalUnit) - } - - // Create `ClassDef` of state machine with empty method bodies for `resume` and `apply`. - val stateMachine: ClassDef = { - val body: List[Tree] = { - val stateVar = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), name.state, TypeTree(definitions.IntTpe), Literal(Constant(StateAssigner.Initial))) - val resultAndAccessors = mkMutableField(futureSystemOps.promType[T](uncheckedBoundsResultTag), name.result, futureSystemOps.createProm[T](uncheckedBoundsResultTag).tree) - val execContextValDef = ValDef(NoMods, name.execContext, TypeTree(), execContext) - - val apply0DefDef: DefDef = { - // We extend () => Unit so we can pass this class as the by-name argument to `Future.apply`. - // See SI-1247 for the the optimization that avoids creation. - DefDef(NoMods, name.apply, Nil, List(Nil), TypeTree(definitions.UnitTpe), Apply(Ident(name.apply), literalNull :: Nil)) - } - List(emptyConstructor, stateVar) ++ resultAndAccessors ++ List(execContextValDef) ++ List(applyDefDefDummyBody, apply0DefDef) - } - - val customParents = futureSystemOps.stateMachineClassParents - val tycon = if (customParents.forall(_.typeSymbol.asClass.isTrait)) { - // prefer extending a class to reduce the class file size of the state machine. - symbolOf[scala.runtime.AbstractFunction1[Any, Any]] - } else { - // ... unless a custom future system already extends some class - symbolOf[scala.Function1[Any, Any]] - } - val tryToUnit = appliedType(tycon, futureSystemOps.tryType[Any], typeOf[Unit]) - val template = Template((futureSystemOps.stateMachineClassParents ::: List(tryToUnit, typeOf[() => Unit])).map(TypeTree(_)), noSelfType, body) - - val t = ClassDef(NoMods, name.stateMachineT, Nil, template) - typecheckClassDef(t) - } - - val stateMachineClass = stateMachine.symbol - val asyncBlock: AsyncBlock = { - val symLookup = SymLookup(stateMachineClass, applyDefDefDummyBody.vparamss.head.head.symbol) - buildAsyncBlock(anfTree, symLookup) - } - - val liftedFields: List[Tree] = liftables(asyncBlock.asyncStates) - - // live variables analysis - // the result map indicates in which states a given field should be nulled out - val assignsOf = fieldsToNullOut(asyncBlock.asyncStates, liftedFields) - - for ((state, flds) <- assignsOf) { - val assigns = flds.map { fld => - val fieldSym = fld.symbol - val assign = Assign(gen.mkAttributedStableRef(thisType(fieldSym.owner), fieldSym), mkZero(fieldSym.info)) - asyncBase.nullOut(c.universe)(c.Expr[String](Literal(Constant(fieldSym.name.toString))), c.Expr[Any](Ident(fieldSym))).tree match { - case Literal(Constant(value: Unit)) => assign - case x => Block(x :: Nil, assign) - } - } - val asyncState = asyncBlock.asyncStates.find(_.state == state).get - asyncState.stats = assigns ++ asyncState.stats - } - - def startStateMachine: Tree = { - val stateMachineSpliced: Tree = spliceMethodBodies( - liftedFields, - stateMachine, - atMacroPos(asyncBlock.onCompleteHandler[T]) - ) - - def selectStateMachine(selection: TermName) = Select(Ident(name.stateMachine), selection) - - Block(List[Tree]( - stateMachineSpliced, - ValDef(NoMods, name.stateMachine, TypeTree(), Apply(Select(New(Ident(stateMachine.symbol)), termNames.CONSTRUCTOR), Nil)), - futureSystemOps.spawn(Apply(selectStateMachine(name.apply), Nil), selectStateMachine(name.execContext)) - ), - futureSystemOps.promiseToFuture(c.Expr[futureSystem.Prom[T]](selectStateMachine(name.result))).tree) - } - - val isSimple = asyncBlock.asyncStates.size == 1 - val result = if (isSimple) - futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }` - else - startStateMachine - - if(AsyncUtils.verbose) { - logDiagnostics(anfTree, asyncBlock, asyncBlock.asyncStates.map(_.toString)) - } - futureSystemOps.dot(enclosingOwner, body).foreach(f => f(asyncBlock.toDot)) - cleanupContainsAwaitAttachments(result) - } - - def logDiagnostics(anfTree: Tree, block: AsyncBlock, states: Seq[String]): Unit = { - def location = try { - macroPos.source.path - } catch { - case _: UnsupportedOperationException => - macroPos.toString - } - - AsyncUtils.vprintln(s"In file '$location':") - AsyncUtils.vprintln(s"${c.macroApplication}") - AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree") - states foreach (s => AsyncUtils.vprintln(s)) - AsyncUtils.vprintln("===== DOT =====") - AsyncUtils.vprintln(block.toDot) - } - - /** - * Build final `ClassDef` tree of state machine class. - * - * @param liftables trees of definitions that are lifted to fields of the state machine class - * @param tree `ClassDef` tree of the state machine class - * @param applyBody tree of onComplete handler (`apply` method) - * @return transformed `ClassDef` tree of the state machine class - */ - def spliceMethodBodies(liftables: List[Tree], tree: ClassDef, applyBody: Tree): Tree = { - val liftedSyms = liftables.map(_.symbol).toSet - val stateMachineClass = tree.symbol - liftedSyms.foreach { - sym => - if (sym != null) { - sym.setOwner(stateMachineClass) - if (sym.isModule) - sym.asModule.moduleClass.setOwner(stateMachineClass) - } - } - - def adjustType(tree: Tree): Tree = { - val resultType = if (tree.tpe eq null) null else tree.tpe.map { - case TypeRef(pre, sym, args) if liftedSyms.contains(sym) => - val tp1 = internal.typeRef(thisType(sym.owner.asClass), sym, args) - tp1 - case SingleType(pre, sym) if liftedSyms.contains(sym) => - val tp1 = internal.singleType(thisType(sym.owner.asClass), sym) - tp1 - case tp => tp - } - setType(tree, resultType) - } - - // Replace the ValDefs in the splicee with Assigns to the corresponding lifted - // fields. Similarly, replace references to them with references to the field. - // - // This transform will only be run on the RHS of `def foo`. - val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => tree match { - case _ if api.currentOwner == stateMachineClass => - api.default(tree) - case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) => - api.atOwner(api.currentOwner) { - val fieldSym = tree.symbol - if (fieldSym.asTerm.isLazy) Literal(Constant(())) - else { - val lhs = atPos(tree.pos) { - gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym) - } - treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner) - } - } - case _: DefTree if liftedSyms(tree.symbol) => - EmptyTree - case Ident(name) if liftedSyms(tree.symbol) => - val fieldSym = tree.symbol - atPos(tree.pos) { - gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym).setType(tree.tpe) - } - case sel @ Select(n@New(tt: TypeTree), termNamesCONSTRUCTOR) => - adjustType(sel) - adjustType(n) - adjustType(tt) - sel - case _ => - api.default(tree) - } - - val liftablesUseFields = liftables.map { - case vd: ValDef if !vd.symbol.asTerm.isLazy => vd - case x => typingTransform(x, stateMachineClass)(useFields) - } - - tree.children.foreach(_.changeOwner(enclosingOwner, tree.symbol)) - val treeSubst = tree - - /* Fixes up DefDef: use lifted fields in `body` */ - def fixup(dd: DefDef, body: Tree, api: TypingTransformApi): Tree = { - val spliceeAnfFixedOwnerSyms = body - val newRhs = typingTransform(spliceeAnfFixedOwnerSyms, dd.symbol)(useFields) - val newRhsTyped = api.atOwner(dd, dd.symbol)(api.typecheck(newRhs)) - treeCopy.DefDef(dd, dd.mods, dd.name, dd.tparams, dd.vparamss, dd.tpt, newRhsTyped) - } - - liftablesUseFields.foreach(t => if (t.symbol != null) stateMachineClass.info.decls.enter(t.symbol)) - - val result0 = transformAt(treeSubst) { - case t@Template(parents, self, stats) => - (api: TypingTransformApi) => { - treeCopy.Template(t, parents, self, liftablesUseFields ++ stats) - } - } - val result = transformAt(result0) { - case dd@DefDef(_, name.apply, _, List(List(_)), _, _) if dd.symbol.owner == stateMachineClass => - (api: TypingTransformApi) => - val typedTree = fixup(dd, applyBody.changeOwner(enclosingOwner, dd.symbol), api) - typedTree - } - result - } - - def typecheckClassDef(cd: ClassDef): ClassDef = { - val Block(cd1 :: Nil, _) = typingTransform(atPos(macroPos)(Block(cd :: Nil, Literal(Constant(())))))( - (tree, api) => - api.typecheck(tree) - ) - cd1.asInstanceOf[ClassDef] - } -} diff --git a/src/main/scala/scala/async/internal/AsyncUtils.scala b/src/main/scala/scala/async/internal/AsyncUtils.scala deleted file mode 100644 index 81b296ca..00000000 --- a/src/main/scala/scala/async/internal/AsyncUtils.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -object AsyncUtils { - - - private def enabled(level: String) = sys.props.getOrElse(s"scala.async.$level", "false").equalsIgnoreCase("true") - - private[async] val verbose = enabled("debug") - private[async] val trace = enabled("trace") - - @inline private[async] def vprintln(s: => Any): Unit = if (verbose) println(s"[async] $s") - - @inline private[async] def trace(s: => Any): Unit = if (trace) println(s"[async] $s") -} diff --git a/src/main/scala/scala/async/internal/ExprBuilder.scala b/src/main/scala/scala/async/internal/ExprBuilder.scala deleted file mode 100644 index 9570af99..00000000 --- a/src/main/scala/scala/async/internal/ExprBuilder.scala +++ /dev/null @@ -1,650 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import java.util.function.IntUnaryOperator - -import scala.collection.mutable -import scala.collection.mutable.ListBuffer - -trait ExprBuilder { - builder: AsyncMacro => - - import c.universe._ - import c.internal._ - - val futureSystem: FutureSystem - val futureSystemOps: futureSystem.Ops { val c: builder.c.type } - - val stateAssigner = new StateAssigner - val labelDefStates = collection.mutable.Map[Symbol, Int]() - - trait AsyncState { - def state: Int - - def nextStates: Array[Int] - - def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef - - def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = None - - var stats: List[Tree] - - def treesThenStats(trees: List[Tree]): List[Tree] = { - (stats match { - case init :+ last if tpeOf(last) =:= definitions.NothingTpe => - adaptToUnit((trees ::: init) :+ Typed(last, TypeTree(definitions.AnyTpe))) - case _ => - adaptToUnit(trees ::: stats) - }) :: Nil - } - - final def allStats: List[Tree] = this match { - case a: AsyncStateWithAwait => treesThenStats(a.awaitable.resultValDef :: Nil) - case _ => stats - } - - final def body: Tree = stats match { - case stat :: Nil => stat - case init :+ last => Block(init, last) - } - } - - /** A sequence of statements that concludes with a unconditional transition to `nextState` */ - final class SimpleAsyncState(var stats: List[Tree], val state: Int, nextState: Int, symLookup: SymLookup) - extends AsyncState { - - val nextStates: Array[Int] = - Array(nextState) - - def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = { - mkHandlerCase(state, treesThenStats(mkStateTree(nextState, symLookup) :: Nil)) - } - - override val toString: String = - s"AsyncState #$state, next = $nextState" - } - - /** A sequence of statements with a conditional transition to the next state, which will represent - * a branch of an `if` or a `match`. - */ - final class AsyncStateWithoutAwait(var stats: List[Tree], val state: Int, val nextStates: Array[Int]) extends AsyncState { - override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = - mkHandlerCase(state, stats) - - override val toString: String = - s"AsyncStateWithoutAwait #$state, nextStates = ${nextStates.toList}" - } - - /** A sequence of statements that concludes with an `await` call. The `onComplete` - * handler will unconditionally transition to `nextState`. - */ - final class AsyncStateWithAwait(var stats: List[Tree], val state: Int, val onCompleteState: Int, nextState: Int, - val awaitable: Awaitable, symLookup: SymLookup) - extends AsyncState { - - val nextStates: Array[Int] = - Array(nextState) - - override def mkHandlerCaseForState[T: WeakTypeTag]: CaseDef = { - val fun = This(typeNames.EMPTY) - val callOnComplete = futureSystemOps.onComplete[Any, Unit](c.Expr[futureSystem.Fut[Any]](awaitable.expr), - c.Expr[futureSystem.Tryy[Any] => Unit](fun), c.Expr[futureSystem.ExecContext](Ident(name.execContext))).tree - val tryGetOrCallOnComplete: List[Tree] = - if (futureSystemOps.continueCompletedFutureOnSameThread) { - val tempName = name.completed - val initTemp = ValDef(NoMods, tempName, TypeTree(futureSystemOps.tryType[Any]), futureSystemOps.getCompleted[Any](c.Expr[futureSystem.Fut[Any]](awaitable.expr)).tree) - val ifTree = If(Apply(Select(Literal(Constant(null)), TermName("ne")), Ident(tempName) :: Nil), - adaptToUnit(ifIsFailureTree[T](Ident(tempName)) :: Nil), - Block(toList(callOnComplete), Return(literalUnit))) - initTemp :: ifTree :: Nil - } else - toList(callOnComplete) ::: Return(literalUnit) :: Nil - mkHandlerCase(state, stats ++ List(mkStateTree(onCompleteState, symLookup)) ++ tryGetOrCallOnComplete) - } - - private def tryGetTree(tryReference: => Tree) = - Assign( - Ident(awaitable.resultName), - TypeApply(Select(futureSystemOps.tryyGet[Any](c.Expr[futureSystem.Tryy[Any]](tryReference)).tree, TermName("asInstanceOf")), List(TypeTree(awaitable.resultType))) - ) - - /* if (tr.isFailure) - * result.complete(tr.asInstanceOf[Try[T]]) - * else { - * = tr.get.asInstanceOf[] - * - * - * } - */ - def ifIsFailureTree[T: WeakTypeTag](tryReference: => Tree) = { - val getAndUpdateState = Block(List(tryGetTree(tryReference)), mkStateTree(nextState, symLookup)) - if (asyncBase.futureSystem.emitTryCatch) { - If(futureSystemOps.tryyIsFailure(c.Expr[futureSystem.Tryy[T]](tryReference)).tree, - Block(toList(futureSystemOps.completeProm[T]( - c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), - c.Expr[futureSystem.Tryy[T]]( - TypeApply(Select(tryReference, TermName("asInstanceOf")), - List(TypeTree(futureSystemOps.tryType[T]))))).tree), - Return(literalUnit)), - getAndUpdateState - ) - } else { - getAndUpdateState - } - } - - override def mkOnCompleteHandler[T: WeakTypeTag]: Option[CaseDef] = { - Some(mkHandlerCase(onCompleteState, List(ifIsFailureTree[T](Ident(symLookup.applyTrParam))))) - } - - override val toString: String = - s"AsyncStateWithAwait #$state, next = $nextState" - } - - /* - * Builder for a single state of an async expression. - */ - final class AsyncStateBuilder(state: Int, private val symLookup: SymLookup) { - /* Statements preceding an await call. */ - private val stats = ListBuffer[Tree]() - /** The state of the target of a LabelDef application (while loop jump) */ - private var nextJumpState: Option[Int] = None - private var nextJumpSymbol: Symbol = NoSymbol - def effectiveNextState(nextState: Int) = nextJumpState.orElse(if (nextJumpSymbol == NoSymbol) None else Some(stateIdForLabel(nextJumpSymbol))).getOrElse(nextState) - - def +=(stat: Tree): this.type = { - stat match { - case Literal(Constant(())) => // This case occurs in do/while - case _ => - assert(nextJumpState.isEmpty, s"statement appeared after a label jump: $stat") - } - def addStat() = stats += stat - stat match { - case Apply(fun, args) if isLabel(fun.symbol) => - // labelDefStates belongs to the current ExprBuilder - labelDefStates get fun.symbol match { - case opt@Some(nextState) => - // A backward jump - nextJumpState = opt // re-use object - nextJumpSymbol = fun.symbol - case None => - // We haven't the corresponding LabelDef, this is a forward jump - nextJumpSymbol = fun.symbol - } - case _ => addStat() - } - this - } - - def resultWithAwait(awaitable: Awaitable, - onCompleteState: Int, - nextState: Int): AsyncState = { - new AsyncStateWithAwait(stats.toList, state, onCompleteState, effectiveNextState(nextState), awaitable, symLookup) - } - - def resultSimple(nextState: Int): AsyncState = { - new SimpleAsyncState(stats.toList, state, effectiveNextState(nextState), symLookup) - } - - def resultWithIf(condTree: Tree, thenState: Int, elseState: Int): AsyncState = { - def mkBranch(state: Int) = mkStateTree(state, symLookup) - this += If(condTree, mkBranch(thenState), mkBranch(elseState)) - new AsyncStateWithoutAwait(stats.toList, state, Array(thenState, elseState)) - } - - /** - * Build `AsyncState` ending with a match expression. - * - * The cases of the match simply resume at the state of their corresponding right-hand side. - * - * @param scrutTree tree of the scrutinee - * @param cases list of case definitions - * @param caseStates starting state of the right-hand side of the each case - * @return an `AsyncState` representing the match expression - */ - def resultWithMatch(scrutTree: Tree, cases: List[CaseDef], caseStates: Array[Int], symLookup: SymLookup): AsyncState = { - // 1. build list of changed cases - val newCases = for ((cas, num) <- cases.zipWithIndex) yield cas match { - case CaseDef(pat, guard, rhs) => - val bindAssigns = rhs.children.takeWhile(isSyntheticBindVal) - CaseDef(pat, guard, Block(bindAssigns, mkStateTree(caseStates(num), symLookup))) - } - // 2. insert changed match tree at the end of the current state - this += Match(scrutTree, newCases) - new AsyncStateWithoutAwait(stats.toList, state, caseStates) - } - - def resultWithLabel(startLabelState: Int, symLookup: SymLookup): AsyncState = { - this += mkStateTree(startLabelState, symLookup) - new AsyncStateWithoutAwait(stats.toList, state, Array(startLabelState)) - } - - override def toString: String = { - val statsBeforeAwait = stats.mkString("\n") - s"ASYNC STATE:\n$statsBeforeAwait" - } - } - - /** - * An `AsyncBlockBuilder` builds a `ListBuffer[AsyncState]` based on the expressions of a `Block(stats, expr)` (see `Async.asyncImpl`). - * - * @param stats a list of expressions - * @param expr the last expression of the block - * @param startState the start state - * @param endState the state to continue with - */ - final private class AsyncBlockBuilder(stats: List[Tree], expr: Tree, startState: Int, endState: Int, - private val symLookup: SymLookup) { - val asyncStates = ListBuffer[AsyncState]() - - var stateBuilder = new AsyncStateBuilder(startState, symLookup) - var currState = startState - - def checkForUnsupportedAwait(tree: Tree) = if (containsAwait(tree)) - c.abort(tree.pos, "await must not be used in this position") - - def nestedBlockBuilder(nestedTree: Tree, startState: Int, endState: Int) = { - val (nestedStats, nestedExpr) = statsAndExpr(nestedTree) - new AsyncBlockBuilder(nestedStats, nestedExpr, startState, endState, symLookup) - } - - import stateAssigner.nextState - def directlyAdjacentLabelDefs(t: Tree): List[Tree] = { - def isPatternCaseLabelDef(t: Tree) = t match { - case LabelDef(name, _, _) => name.toString.startsWith("case") - case _ => false - } - val span = (stats :+ expr).filterNot(isLiteralUnit).span(_ ne t) - span match { - case (before, _ :: after) => - before.reverse.takeWhile(isPatternCaseLabelDef) ::: after.takeWhile(isPatternCaseLabelDef) - case _ => - stats :+ expr - } - } - - // populate asyncStates - def add(stat: Tree, afterState: Option[Int] = None): Unit = stat match { - // the val name = await(..) pattern - case vd @ ValDef(mods, name, tpt, Apply(fun, arg :: Nil)) if isAwait(fun) => - val onCompleteState = nextState() - val afterAwaitState = afterState.getOrElse(nextState()) - val awaitable = Awaitable(arg, stat.symbol, tpt.tpe, vd) - asyncStates += stateBuilder.resultWithAwait(awaitable, onCompleteState, afterAwaitState) // complete with await - currState = afterAwaitState - stateBuilder = new AsyncStateBuilder(currState, symLookup) - - case If(cond, thenp, elsep) if containsAwait(stat) || containsForiegnLabelJump(stat) => - checkForUnsupportedAwait(cond) - - val thenStartState = nextState() - val elseStartState = nextState() - val afterIfState = afterState.getOrElse(nextState()) - - asyncStates += - // the two Int arguments are the start state of the then branch and the else branch, respectively - stateBuilder.resultWithIf(cond, thenStartState, elseStartState) - - List((thenp, thenStartState), (elsep, elseStartState)) foreach { - case (branchTree, state) => - val builder = nestedBlockBuilder(branchTree, state, afterIfState) - asyncStates ++= builder.asyncStates - } - - currState = afterIfState - stateBuilder = new AsyncStateBuilder(currState, symLookup) - - case Match(scrutinee, cases) if containsAwait(stat) => - checkForUnsupportedAwait(scrutinee) - - val caseStates = new Array[Int](cases.length) - java.util.Arrays.setAll(caseStates, new IntUnaryOperator { - override def applyAsInt(operand: Int): Int = nextState() - }) - val afterMatchState = afterState.getOrElse(nextState()) - - asyncStates += - stateBuilder.resultWithMatch(scrutinee, cases, caseStates, symLookup) - - for ((cas, num) <- cases.zipWithIndex) { - val (stats, expr) = statsAndExpr(cas.body) - val stats1 = stats.dropWhile(isSyntheticBindVal) - val builder = nestedBlockBuilder(Block(stats1, expr), caseStates(num), afterMatchState) - asyncStates ++= builder.asyncStates - } - - currState = afterMatchState - stateBuilder = new AsyncStateBuilder(currState, symLookup) - case ld @ LabelDef(name, params, rhs) - if containsAwait(rhs) || directlyAdjacentLabelDefs(ld).exists(containsAwait) => - - val startLabelState = stateIdForLabel(ld.symbol) - val afterLabelState = afterState.getOrElse(nextState()) - asyncStates += stateBuilder.resultWithLabel(startLabelState, symLookup) - labelDefStates(ld.symbol) = startLabelState - val builder = nestedBlockBuilder(rhs, startLabelState, afterLabelState) - asyncStates ++= builder.asyncStates - currState = afterLabelState - stateBuilder = new AsyncStateBuilder(currState, symLookup) - case b @ Block(stats, expr) => - for (stat <- stats) add(stat) - add(expr, afterState = Some(endState)) - case _ => - checkForUnsupportedAwait(stat) - stateBuilder += stat - } - for (stat <- (stats :+ expr)) add(stat) - val lastState = stateBuilder.resultSimple(endState) - asyncStates += lastState - } - - trait AsyncBlock { - def asyncStates: List[AsyncState] - - def onCompleteHandler[T: WeakTypeTag]: Tree - - def toDot: String - } - - case class SymLookup(stateMachineClass: Symbol, applyTrParam: Symbol) { - def stateMachineMember(name: TermName): Symbol = - stateMachineClass.info.member(name) - def memberRef(name: TermName): Tree = - gen.mkAttributedRef(stateMachineMember(name)) - } - - /** - * Uses `AsyncBlockBuilder` to create an instance of `AsyncBlock`. - * - * @param block a `Block` tree in ANF - * @param symLookup helper for looking up members of the state machine class - * @return an `AsyncBlock` - */ - def buildAsyncBlock(block: Block, symLookup: SymLookup): AsyncBlock = { - val Block(stats, expr) = block - val startState = stateAssigner.nextState() - val endState = Int.MaxValue - - val blockBuilder = new AsyncBlockBuilder(stats, expr, startState, endState, symLookup) - - new AsyncBlock { - val switchIds = mutable.AnyRefMap[Integer, Integer]() - - // render with http://graphviz.it/#/new - def toDot: String = { - val states = asyncStates - def toHtmlLabel(label: String, preText: String, builder: StringBuilder): Unit = { - val br = "
" - builder.append("").append(label).append("").append("
") - builder.append("") - preText.split("\n").foreach { - (line: String) => - builder.append(br) - builder.append(line.replaceAllLiterally("\"", """).replaceAllLiterally("<", "<").replaceAllLiterally(">", ">").replaceAllLiterally(" ", " ")) - } - builder.append(br) - builder.append("") - } - val dotBuilder = new StringBuilder() - dotBuilder.append("digraph {\n") - def stateLabel(s: Int) = { - if (s == 0) "INITIAL" else if (s == Int.MaxValue) "TERMINAL" else switchIds.get(s).map(_.toString).getOrElse(s.toString) - } - val length = states.size - for ((state, i) <- asyncStates.zipWithIndex) { - dotBuilder.append(s"""${stateLabel(state.state)} [label=""").append("<") - def show(t: Tree): String = { - (t match { - case Block(stats, expr) => stats ::: expr :: Nil - case t => t :: Nil - }).iterator.map(t => showCode(t)).mkString("\n") - } - if (i != length - 1) { - val CaseDef(_, _, body) = state.mkHandlerCaseForState - toHtmlLabel(stateLabel(state.state), show(compactStateTransform.transform(body)), dotBuilder) - } else { - toHtmlLabel(stateLabel(state.state), state.allStats.map(show(_)).mkString("\n"), dotBuilder) - } - dotBuilder.append("> ]\n") - state match { - case s: AsyncStateWithAwait => - val CaseDef(_, _, body) = s.mkOnCompleteHandler.get - dotBuilder.append(s"""${stateLabel(s.onCompleteState)} [label=""").append("<") - toHtmlLabel(stateLabel(s.onCompleteState), show(compactStateTransform.transform(body)), dotBuilder) - dotBuilder.append("> ]\n") - case _ => - } - } - for (state <- states) { - state match { - case s: AsyncStateWithAwait => - dotBuilder.append(s"""${stateLabel(state.state)} -> ${stateLabel(s.onCompleteState)} [style=dashed color=red]""") - dotBuilder.append("\n") - for (succ <- state.nextStates) { - dotBuilder.append(s"""${stateLabel(s.onCompleteState)} -> ${stateLabel(succ)}""") - dotBuilder.append("\n") - } - case _ => - for (succ <- state.nextStates) { - dotBuilder.append(s"""${stateLabel(state.state)} -> ${stateLabel(succ)}""") - dotBuilder.append("\n") - } - } - } - dotBuilder.append("}\n") - dotBuilder.toString - } - - lazy val asyncStates: List[AsyncState] = filterStates - - def filterStates = { - val all = blockBuilder.asyncStates.toList - val (initial :: rest) = all - val map = all.iterator.map(x => (x.state, x)).toMap - var seen = mutable.HashSet[Int]() - def loop(state: AsyncState): Unit = { - seen.add(state.state) - for (i <- state.nextStates) { - if (i != Int.MaxValue && !seen.contains(i)) { - loop(map(i)) - } - } - } - loop(initial) - val live = rest.filter(state => seen(state.state)) - var nextSwitchId = 0 - (initial :: live).foreach { state => - val switchId = nextSwitchId - switchIds(state.state) = switchId - nextSwitchId += 1 - state match { - case state: AsyncStateWithAwait => - val switchId = nextSwitchId - switchIds(state.onCompleteState) = switchId - nextSwitchId += 1 - case _ => - } - } - initial :: live - - } - - def mkCombinedHandlerCases[T: WeakTypeTag]: List[CaseDef] = { - val caseForLastState: CaseDef = { - val lastState = asyncStates.last - val lastStateBody = c.Expr[T](lastState.body) - val rhs = futureSystemOps.completeWithSuccess( - c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), lastStateBody) - mkHandlerCase(lastState.state, Block(rhs.tree, Return(literalUnit))) - } - asyncStates match { - case s :: Nil => - List(caseForLastState) - case _ => - val initCases = for (state <- asyncStates.init) yield state.mkHandlerCaseForState[T] - initCases :+ caseForLastState - } - } - - val initStates = asyncStates.init - - /** - * Builds the definition of the `resume` method. - * - * The resulting tree has the following shape: - * - * def resume(): Unit = { - * try { - * state match { - * case 0 => { - * f11 = exprReturningFuture - * f11.onComplete(onCompleteHandler)(context) - * } - * ... - * } - * } catch { - * case NonFatal(t) => result.failure(t) - * } - * } - */ - private def resumeFunTree[T: WeakTypeTag]: Tree = { - val stateMemberRef = symLookup.memberRef(name.state) - val body = Match(stateMemberRef, mkCombinedHandlerCases[T] ++ initStates.flatMap(_.mkOnCompleteHandler[T]) ++ List(CaseDef(Ident(termNames.WILDCARD), EmptyTree, Throw(Apply(Select(New(Ident(defn.IllegalStateExceptionClass)), termNames.CONSTRUCTOR), List()))))) - val body1 = compactStates(body) - - maybeTry( - body1, - List( - CaseDef( - Bind(name.t, Typed(Ident(termNames.WILDCARD), Ident(defn.ThrowableClass))), - EmptyTree, { - val thenn = { - val t = c.Expr[Throwable](Ident(name.t)) - val complete = futureSystemOps.completeProm[T]( - c.Expr[futureSystem.Prom[T]](symLookup.memberRef(name.result)), futureSystemOps.tryyFailure[T](t)).tree - Block(toList(complete), Return(literalUnit)) - } - If(Apply(Ident(defn.NonFatalClass), List(Ident(name.t))), thenn, Throw(Ident(name.t))) - thenn - })), EmptyTree) - } - - private lazy val stateMemberSymbol = symLookup.stateMachineMember(name.state) - private val compactStateTransform = new Transformer { - override def transform(tree: Tree): Tree = tree match { - case as @ Assign(lhs, Literal(Constant(i: Integer))) if lhs.symbol == stateMemberSymbol => - val replacement = switchIds(i) - treeCopy.Assign(tree, lhs, Literal(Constant(replacement))) - case _: Match | _: CaseDef | _: Block | _: If => - super.transform(tree) - case _ => tree - } - } - - private def compactStates(m: Match): Tree = { - val cases1 = m.cases.flatMap { - case cd @ CaseDef(Literal(Constant(i: Integer)), EmptyTree, rhs) => - val replacement = switchIds(i) - val rhs1 = compactStateTransform.transform(rhs) - treeCopy.CaseDef(cd, Literal(Constant(replacement)), EmptyTree, rhs1) :: Nil - case x => x :: Nil - } - treeCopy.Match(m, m.selector, cases1) - } - - def forever(t: Tree): Tree = { - val termName = TermName(name.fresh("while$")) - LabelDef(termName, Nil, Block(toList(t), Apply(Ident(termName), Nil))) - } - - /** - * Builds a `match` expression used as an onComplete handler. - * - * Assumes `tr: Try[Any]` is in scope. The resulting tree has the following shape: - * - * state match { - * case 0 => - * x11 = tr.get.asInstanceOf[Double] - * state = 1 - * resume() - * } - */ - def onCompleteHandler[T: WeakTypeTag]: Tree = { - initStates.flatMap(_.mkOnCompleteHandler[T]) - forever { - adaptToUnit(toList(resumeFunTree)) - } - } - } - } - - private def isSyntheticBindVal(tree: Tree) = tree match { - case vd@ValDef(_, lname, _, Ident(rname)) => attachments(vd.symbol).contains[SyntheticBindVal.type] - case _ => false - } - - case class Awaitable(expr: Tree, resultName: Symbol, resultType: Type, resultValDef: ValDef) - - private def mkStateTree(nextState: Int, symLookup: SymLookup): Tree = - Assign(symLookup.memberRef(name.state), Literal(Constant(nextState))) - - private def mkHandlerCase(num: Int, rhs: List[Tree]): CaseDef = - mkHandlerCase(num, adaptToUnit(rhs)) - - // We use the convention that the state machine's ID for a state corresponding to - // a labeldef will a negative number be based on the symbol ID. This allows us - // to translate a forward jump to the label as a state transition to a known state - // ID, even though the state machine transform hasn't yet processed the target label - // def. Negative numbers are used so as as not to clash with regular state IDs, which - // are allocated in ascending order from 0. - private def stateIdForLabel(sym: Symbol): Int = -symId(sym) - - private def tpeOf(t: Tree): Type = t match { - case _ if t.tpe != null => t.tpe - case Try(body, Nil, _) => tpeOf(body) - case Block(_, expr) => tpeOf(expr) - case Literal(Constant(value)) if value == (()) => definitions.UnitTpe - case Return(_) => definitions.NothingTpe - case _ => NoType - } - - private def adaptToUnit(rhs: List[Tree]): c.universe.Block = { - rhs match { - case (rhs: Block) :: Nil if tpeOf(rhs) <:< definitions.UnitTpe => - rhs - case init :+ last if tpeOf(last) <:< definitions.UnitTpe => - Block(init, last) - case init :+ (last @ Literal(Constant(()))) => - Block(init, last) - case init :+ (last @ Block(_, Return(_) | Literal(Constant(())))) => - Block(init, last) - case init :+ (Block(stats, expr)) => - Block(init, Block(stats :+ expr, literalUnit)) - case _ => - Block(rhs, literalUnit) - } - } - - private def mkHandlerCase(num: Int, rhs: Tree): CaseDef = - CaseDef(Literal(Constant(num)), EmptyTree, rhs) - - def literalUnit = Literal(Constant(())) // a def to avoid sharing trees - - def toList(tree: Tree): List[Tree] = tree match { - case Block(stats, Literal(Constant(value))) if value == (()) => stats - case _ => tree :: Nil - } - - def literalNull = Literal(Constant(null)) -} diff --git a/src/main/scala/scala/async/internal/FutureSystem.scala b/src/main/scala/scala/async/internal/FutureSystem.scala deleted file mode 100644 index 11c57ef4..00000000 --- a/src/main/scala/scala/async/internal/FutureSystem.scala +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import scala.language.higherKinds -import scala.reflect.macros.whitebox - -/** - * An abstraction over a future system. - * - * Used by the macro implementations in [[scala.async.internal.AsyncBase]] to - * customize the code generation. - * - * The API mirrors that of `scala.concurrent.Future`, see the instance - * [[ScalaConcurrentFutureSystem]] for an example of how - * to implement this. - */ -trait FutureSystem { - /** A container to receive the final value of the computation */ - type Prom[A] - /** A (potentially in-progress) computation */ - type Fut[A] - /** An execution context, required to create or register an on completion callback on a Future. */ - type ExecContext - /** Any data type isomorphic to scala.util.Try. */ - type Tryy[T] - - trait Ops { - val c: whitebox.Context - import c.universe._ - - def promType[A: WeakTypeTag]: Type - def tryType[A: WeakTypeTag]: Type - def execContextType: Type - def stateMachineClassParents: List[Type] = Nil - - /** Create an empty promise */ - def createProm[A: WeakTypeTag]: Expr[Prom[A]] - - /** Extract a future from the given promise. */ - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]): Expr[Fut[A]] - - /** Construct a future to asynchronously compute the given expression */ - def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]): Expr[Fut[A]] - - /** Register an call back to run on completion of the given future */ - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[Tryy[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] - - def continueCompletedFutureOnSameThread = false - - /** Return `null` if this future is not yet completed, or `Tryy[A]` with the completed result - * otherwise - */ - def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = - throw new UnsupportedOperationException("getCompleted not supported by this FutureSystem") - - /** Complete a promise with a value */ - def completeProm[A](prom: Expr[Prom[A]], value: Expr[Tryy[A]]): Expr[Unit] - def completeWithSuccess[A: WeakTypeTag](prom: Expr[Prom[A]], value: Expr[A]): Expr[Unit] = completeProm(prom, tryySuccess(value)) - - def spawn(tree: Tree, execContext: Tree): Tree = - future(c.Expr[Unit](tree))(c.Expr[ExecContext](execContext)).tree - - def tryyIsFailure[A](tryy: Expr[Tryy[A]]): Expr[Boolean] - - def tryyGet[A](tryy: Expr[Tryy[A]]): Expr[A] - def tryySuccess[A: WeakTypeTag](a: Expr[A]): Expr[Tryy[A]] - def tryyFailure[A: WeakTypeTag](a: Expr[Throwable]): Expr[Tryy[A]] - - /** A hook for custom macros to transform the tree post-ANF transform */ - def postAnfTransform(tree: Block): Block = tree - - /** A hook for custom macros to selectively generate and process a Graphviz visualization of the transformed state machine */ - def dot(enclosingOwner: Symbol, macroApplication: Tree): Option[(String => Unit)] = None - } - - def mkOps(c0: whitebox.Context): Ops { val c: c0.type } - - @deprecated("No longer honoured by the macro, all generated names now contain $async to avoid accidental clashes with lambda lifted names", "0.9.7") - def freshenAllNames: Boolean = false - def emitTryCatch: Boolean = true - @deprecated("No longer honoured by the macro, all generated names now contain $async to avoid accidental clashes with lambda lifted names", "0.9.7") - def resultFieldName: String = "result" -} - -object ScalaConcurrentFutureSystem extends FutureSystem { - - import scala.concurrent._ - - type Prom[A] = Promise[A] - type Fut[A] = Future[A] - type ExecContext = ExecutionContext - type Tryy[A] = scala.util.Try[A] - - def mkOps(c0: whitebox.Context): Ops {val c: c0.type} = new Ops { - val c: c0.type = c0 - import c.universe._ - - def promType[A: WeakTypeTag]: Type = weakTypeOf[Promise[A]] - def tryType[A: WeakTypeTag]: Type = weakTypeOf[scala.util.Try[A]] - def execContextType: Type = weakTypeOf[ExecutionContext] - - def createProm[A: WeakTypeTag]: Expr[Prom[A]] = reify { - Promise[A]() - } - - def promiseToFuture[A: WeakTypeTag](prom: Expr[Prom[A]]) = reify { - prom.splice.future - } - - def future[A: WeakTypeTag](a: Expr[A])(execContext: Expr[ExecContext]) = reify { - Future(a.splice)(execContext.splice) - } - - def onComplete[A, U](future: Expr[Fut[A]], fun: Expr[scala.util.Try[A] => U], - execContext: Expr[ExecContext]): Expr[Unit] = reify { - future.splice.onComplete(fun.splice)(execContext.splice) - } - - override def continueCompletedFutureOnSameThread: Boolean = true - - override def getCompleted[A: WeakTypeTag](future: Expr[Fut[A]]): Expr[Tryy[A]] = reify { - if (future.splice.isCompleted) future.splice.value.get else null - } - - def completeProm[A](prom: Expr[Prom[A]], value: Expr[scala.util.Try[A]]): Expr[Unit] = reify { - prom.splice.complete(value.splice) - c.Expr[Unit](Literal(Constant(()))).splice - } - - def tryyIsFailure[A](tryy: Expr[scala.util.Try[A]]): Expr[Boolean] = reify { - tryy.splice.isFailure - } - - def tryyGet[A](tryy: Expr[Tryy[A]]): Expr[A] = reify { - tryy.splice.get - } - def tryySuccess[A: WeakTypeTag](a: Expr[A]): Expr[Tryy[A]] = reify { - scala.util.Success[A](a.splice) - } - def tryyFailure[A: WeakTypeTag](a: Expr[Throwable]): Expr[Tryy[A]] = reify { - scala.util.Failure[A](a.splice) - } - } -} diff --git a/src/main/scala/scala/async/internal/Lifter.scala b/src/main/scala/scala/async/internal/Lifter.scala deleted file mode 100644 index 57fefa20..00000000 --- a/src/main/scala/scala/async/internal/Lifter.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import scala.collection.mutable -import scala.collection.mutable.ListBuffer - -trait Lifter { - self: AsyncMacro => - import c.universe._ - import Flag._ - import c.internal._ - import decorators._ - - /** - * Identify which DefTrees are used (including transitively) which are declared - * in some state but used (including transitively) in another state. - * - * These will need to be lifted to class members of the state machine. - */ - def liftables(asyncStates: List[AsyncState]): List[Tree] = { - object companionship { - private val companions = collection.mutable.Map[Symbol, Symbol]() - private val companionsInverse = collection.mutable.Map[Symbol, Symbol]() - private def record(sym1: Symbol, sym2: Symbol): Unit = { - companions(sym1) = sym2 - companions(sym2) = sym1 - } - - def record(defs: List[Tree]): Unit = { - // Keep note of local companions so we rename them consistently - // when lifting. - for { - cd@ClassDef(_, _, _, _) <- defs - md@ModuleDef(_, _, _) <- defs - if (cd.name.toTermName == md.name) - } record(cd.symbol, md.symbol) - } - def companionOf(sym: Symbol): Symbol = { - companions.get(sym).orElse(companionsInverse.get(sym)).getOrElse(NoSymbol) - } - } - - - val defs: mutable.LinkedHashMap[Tree, Int] = { - /** Collect the DefTrees directly enclosed within `t` that have the same owner */ - def collectDirectlyEnclosedDefs(t: Tree): List[DefTree] = t match { - case ld: LabelDef => Nil - case dt: DefTree => dt :: Nil - case _: Function => Nil - case t => - val childDefs = t.children.flatMap(collectDirectlyEnclosedDefs(_)) - companionship.record(childDefs) - childDefs - } - mutable.LinkedHashMap(asyncStates.flatMap { - asyncState => - val defs = collectDirectlyEnclosedDefs(Block(asyncState.allStats: _*)) - defs.map((_, asyncState.state)) - }: _*) - } - - // In which block are these symbols defined? - val symToDefiningState: mutable.LinkedHashMap[Symbol, Int] = defs.map { - case (k, v) => (k.symbol, v) - } - - // The definitions trees - val symToTree: mutable.LinkedHashMap[Symbol, Tree] = defs.map { - case (k, v) => (k.symbol, k) - } - - // The direct references of each definition tree - val defSymToReferenced: mutable.LinkedHashMap[Symbol, List[Symbol]] = defs.map { - case (tree, _) => (tree.symbol, tree.collect { - case rt: RefTree if symToDefiningState.contains(rt.symbol) => rt.symbol - }) - } - - // The direct references of each block, excluding references of `DefTree`-s which - // are already accounted for. - val stateIdToDirectlyReferenced: mutable.LinkedHashMap[Int, List[Symbol]] = { - val result = new mutable.LinkedHashMap[Int, ListBuffer[Symbol]]() - asyncStates.foreach( - asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).foreach(_.foreach { - case rt: RefTree - if symToDefiningState.contains(rt.symbol) => - result.getOrElseUpdate(asyncState.state, new ListBuffer) += rt.symbol - case tt: TypeTree => - tt.tpe.foreach { tp => - val termSym = tp.termSymbol - if (symToDefiningState.contains(termSym)) - result.getOrElseUpdate(asyncState.state, new ListBuffer) += termSym - val typeSym = tp.typeSymbol - if (symToDefiningState.contains(typeSym)) - result.getOrElseUpdate(asyncState.state, new ListBuffer) += typeSym - } - case _ => - }) - ) - result.map { case (a, b) => (a, b.result())} - } - - def liftableSyms: mutable.LinkedHashSet[Symbol] = { - val liftableMutableSet = mutable.LinkedHashSet[Symbol]() - def markForLift(sym: Symbol): Unit = { - if (!liftableMutableSet(sym)) { - liftableMutableSet += sym - - // Only mark transitive references of defs, modules and classes. The RHS of lifted vals/vars - // stays in its original location, so things that it refers to need not be lifted. - if (!(sym.isTerm && !sym.asTerm.isLazy && (sym.asTerm.isVal || sym.asTerm.isVar))) - defSymToReferenced(sym).foreach(sym2 => markForLift(sym2)) - } - } - // Start things with DefTrees directly referenced from statements from other states... - val liftableStatementRefs: List[Symbol] = stateIdToDirectlyReferenced.iterator.flatMap { - case (i, syms) => syms.filter(sym => symToDefiningState(sym) != i) - }.toList - // .. and likewise for DefTrees directly referenced by other DefTrees from other states - val liftableRefsOfDefTrees = defSymToReferenced.toList.flatMap { - case (referee, referents) => referents.filter(sym => symToDefiningState(sym) != symToDefiningState(referee)) - } - // Mark these for lifting, which will follow transitive references. - (liftableStatementRefs ++ liftableRefsOfDefTrees).foreach(markForLift) - liftableMutableSet - } - - liftableSyms.iterator.map(symToTree).map { - t => - val sym = t.symbol - val treeLifted = t match { - case vd@ValDef(_, _, tpt, rhs) => - sym.setFlag(MUTABLE | STABLE | PRIVATE | LOCAL) - sym.setName(name.fresh(sym.name.toTermName)) - sym.setInfo(deconst(sym.info)) - val rhs1 = if (sym.asTerm.isLazy) rhs else EmptyTree - treeCopy.ValDef(vd, Modifiers(sym.flags), sym.name, TypeTree(tpe(sym)).setPos(t.pos), rhs1) - case dd@DefDef(_, _, tparams, vparamss, tpt, rhs) => - sym.setName(this.name.freshen(sym.name.toTermName)) - sym.setFlag(PRIVATE | LOCAL) - // Was `DefDef(sym, rhs)`, but this ran afoul of `ToughTypeSpec.nestedMethodWithInconsistencyTreeAndInfoParamSymbols` - // due to the handling of type parameter skolems in `thisMethodType` in `Namers` - treeCopy.DefDef(dd, Modifiers(sym.flags), sym.name, tparams, vparamss, tpt, rhs) - case cd@ClassDef(_, _, tparams, impl) => - sym.setName(name.freshen(sym.name.toTypeName)) - companionship.companionOf(cd.symbol) match { - case NoSymbol => - case moduleSymbol => - moduleSymbol.setName(sym.name.toTermName) - moduleSymbol.asModule.moduleClass.setName(moduleSymbol.name.toTypeName) - } - treeCopy.ClassDef(cd, Modifiers(sym.flags), sym.name, tparams, impl) - case md@ModuleDef(_, _, impl) => - companionship.companionOf(md.symbol) match { - case NoSymbol => - sym.setName(name.freshen(sym.name.toTermName)) - sym.asModule.moduleClass.setName(sym.name.toTypeName) - case classSymbol => // will be renamed by `case ClassDef` above. - } - treeCopy.ModuleDef(md, Modifiers(sym.flags), sym.name, impl) - case td@TypeDef(_, _, tparams, rhs) => - sym.setName(name.freshen(sym.name.toTypeName)) - treeCopy.TypeDef(td, Modifiers(sym.flags), sym.name, tparams, rhs) - } - atPos(t.pos)(treeLifted) - }.toList - } -} diff --git a/src/main/scala/scala/async/internal/LiveVariables.scala b/src/main/scala/scala/async/internal/LiveVariables.scala deleted file mode 100644 index 2f7ecc29..00000000 --- a/src/main/scala/scala/async/internal/LiveVariables.scala +++ /dev/null @@ -1,313 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import scala.collection.mutable - -import java.util.function.IntConsumer - -import scala.collection.immutable.IntMap - -trait LiveVariables { - self: AsyncMacro => - import c.universe._ - import Flag._ - - /** - * Returns for a given state a list of fields (as trees) that should be nulled out - * upon resuming that state (at the beginning of `resume`). - * - * @param asyncStates the states of an `async` block - * @param liftables the lifted fields - * @return a map mapping a state to the fields that should be nulled out - * upon resuming that state - */ - def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): mutable.LinkedHashMap[Int, List[Tree]] = { - // live variables analysis: - // the result map indicates in which states a given field should be nulled out - val liveVarsMap: mutable.LinkedHashMap[Tree, StateSet] = liveVars(asyncStates, liftables) - - var assignsOf = mutable.LinkedHashMap[Int, List[Tree]]() - - for ((fld, where) <- liveVarsMap) { - where.foreach { new IntConsumer { def accept(state: Int): Unit = { - assignsOf get state match { - case None => - assignsOf += (state -> List(fld)) - case Some(trees) if !trees.exists(_.symbol == fld.symbol) => - assignsOf += (state -> (fld +: trees)) - case _ => - // do nothing - } - }}} - } - - assignsOf - } - - /** - * Live variables data-flow analysis. - * - * The goal is to find, for each lifted field, the last state where the field is used. - * In all direct successor states which are not (indirect) predecessors of that last state - * (possible through loops), the corresponding field should be nulled out (at the beginning of - * `resume`). - * - * @param asyncStates the states of an `async` block - * @param liftables the lifted fields - * @return a map which indicates for a given field (the key) the states in which it should be nulled out - */ - def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): mutable.LinkedHashMap[Tree, StateSet] = { - val liftedSyms: Set[Symbol] = // include only vars - liftables.iterator.filter { - case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE) - case _ => false - }.map(_.symbol).toSet - - // determine which fields should be live also at the end (will not be nulled out) - val noNull: Set[Symbol] = liftedSyms.filter { sym => - val typeSym = tpe(sym).typeSymbol - (typeSym.isClass && (typeSym.asClass.isPrimitive || typeSym == definitions.NothingClass)) || liftables.exists { tree => - !liftedSyms.contains(tree.symbol) && tree.exists(_.symbol == sym) - } - } - AsyncUtils.vprintln(s"fields never zero-ed out: ${noNull.mkString(", ")}") - - /** - * Traverse statements of an `AsyncState`, collect `Ident`-s referring to lifted fields. - * - * @param as a state of an `async` expression - * @return a set of lifted fields that are used within state `as` - */ - def fieldsUsedIn(as: AsyncState): ReferencedFields = { - class FindUseTraverser extends AsyncTraverser { - var usedFields: Set[Symbol] = Set[Symbol]() - var capturedFields: Set[Symbol] = Set[Symbol]() - private def capturing[A](body: => A): A = { - val saved = capturing - try { - capturing = true - body - } finally capturing = saved - } - private def capturingCheck(tree: Tree) = capturing(tree foreach check) - private var capturing: Boolean = false - private def check(tree: Tree): Unit = { - tree match { - case Ident(_) if liftedSyms(tree.symbol) => - if (capturing) - capturedFields += tree.symbol - else - usedFields += tree.symbol - case _ => - } - } - override def traverse(tree: Tree) = { - check(tree) - super.traverse(tree) - } - - override def nestedClass(classDef: ClassDef): Unit = capturingCheck(classDef) - - override def nestedModule(module: ModuleDef): Unit = capturingCheck(module) - - override def nestedMethod(defdef: DefDef): Unit = capturingCheck(defdef) - - override def byNameArgument(arg: Tree): Unit = capturingCheck(arg) - - override def function(function: Function): Unit = capturingCheck(function) - - override def patMatFunction(tree: Match): Unit = capturingCheck(tree) - } - - val findUses = new FindUseTraverser - findUses.traverse(Block(as.stats: _*)) - ReferencedFields(findUses.usedFields, findUses.capturedFields) - } - case class ReferencedFields(used: Set[Symbol], captured: Set[Symbol]) { - override def toString = s"used: ${used.mkString(",")}\ncaptured: ${captured.mkString(",")}" - } - - /* Build the control-flow graph. - * - * A state `i` is contained in the list that is the value to which - * key `j` maps iff control can flow from state `j` to state `i`. - */ - val cfg: IntMap[Array[Int]] = { - var res = IntMap.empty[Array[Int]] - - for (as <- asyncStates) res = res.updated(as.state, as.nextStates) - res - } - - /** Tests if `state1` is a predecessor of `state2`. - */ - def isPred(state1: Int, state2: Int): Boolean = { - val seen = new StateSet() - - def isPred0(state1: Int, state2: Int): Boolean = - if(state1 == state2) false - else if (seen.contains(state1)) false // breaks cycles in the CFG - else cfg getOrElse(state1, null) match { - case null => false - case nextStates => - seen += state1 - var i = 0 - while (i < nextStates.length) { - if (nextStates(i) == state2 || isPred0(nextStates(i), state2)) return true - i += 1 - } - false - } - - isPred0(state1, state2) - } - - val finalState = asyncStates.find(as => !asyncStates.exists(other => isPred(as.state, other.state))).get - - if(AsyncUtils.verbose) { - for (as <- asyncStates) - AsyncUtils.vprintln(s"fields used in state #${as.state}: ${fieldsUsedIn(as)}") - } - - /* Backwards data-flow analysis. Computes live variables information at entry and exit - * of each async state. - * - * Compute using a simple fixed point iteration: - * - * 1. currStates = List(finalState) - * 2. for each cs \in currStates, compute LVentry(cs) from LVexit(cs) and used fields information for cs - * 3. record if LVentry(cs) has changed for some cs. - * 4. obtain predecessors pred of each cs \in currStates - * 5. for each p \in pred, compute LVexit(p) as union of the LVentry of its successors - * 6. currStates = pred - * 7. repeat if something has changed - */ - - var LVentry = IntMap[Set[Symbol]]() withDefaultValue Set[Symbol]() - var LVexit: Map[Int, Set[Symbol]] = IntMap[Set[Symbol]]() withDefaultValue Set[Symbol]() - - // All fields are declared to be dead at the exit of the final async state, except for the ones - // that cannot be nulled out at all (those in noNull), because they have been captured by a nested def. - LVexit = LVexit + (finalState.state -> noNull) - - var currStates = List(finalState) // start at final state - var captured: Set[Symbol] = Set() - - def contains(as: Array[Int], a: Int): Boolean = { - var i = 0 - while (i < as.length) { - if (as(i) == a) return true - i += 1 - } - false - } - while (!currStates.isEmpty) { - var entryChanged: List[AsyncState] = Nil - - for (cs <- currStates) { - val LVentryOld = LVentry(cs.state) - val referenced = fieldsUsedIn(cs) - captured ++= referenced.captured - val LVentryNew = LVexit(cs.state) ++ referenced.used - if (!LVentryNew.sameElements(LVentryOld)) { - LVentry = LVentry.updated(cs.state, LVentryNew) - entryChanged ::= cs - } - } - - val pred = entryChanged.flatMap(cs => asyncStates.filter(state => contains(state.nextStates, cs.state))) - var exitChanged: List[AsyncState] = Nil - - for (p <- pred) { - val LVexitOld = LVexit(p.state) - val LVexitNew = p.nextStates.flatMap(succ => LVentry(succ)).toSet - if (!LVexitNew.sameElements(LVexitOld)) { - LVexit = LVexit.updated(p.state, LVexitNew) - exitChanged ::= p - } - } - - currStates = exitChanged - } - - if(AsyncUtils.verbose) { - for (as <- asyncStates) { - AsyncUtils.vprintln(s"LVentry at state #${as.state}: ${LVentry(as.state).mkString(", ")}") - AsyncUtils.vprintln(s"LVexit at state #${as.state}: ${LVexit(as.state).mkString(", ")}") - } - } - - def lastUsagesOf(field: Tree, at: AsyncState): StateSet = { - val avoid = scala.collection.mutable.HashSet[AsyncState]() - - val result = new StateSet - def lastUsagesOf0(field: Tree, at: AsyncState): Unit = { - if (avoid(at)) () - else if (captured(field.symbol)) { - () - } - else LVentry get at.state match { - case Some(fields) if fields.contains(field.symbol) => - result += at.state - case _ => - avoid += at - for (state <- asyncStates) { - if (contains(state.nextStates, at.state)) { - lastUsagesOf0(field, state) - } - } - } - } - - lastUsagesOf0(field, at) - result - } - - val lastUsages: mutable.LinkedHashMap[Tree, StateSet] = - mutable.LinkedHashMap(liftables.map(fld => fld -> lastUsagesOf(fld, finalState)): _*) - - if(AsyncUtils.verbose) { - for ((fld, lastStates) <- lastUsages) - AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.iterator.mkString(", ")}") - } - - val nullOutAt: mutable.LinkedHashMap[Tree, StateSet] = - for ((fld, lastStates) <- lastUsages) yield { - var result = new StateSet - lastStates.foreach(new IntConsumer { def accept(s: Int): Unit = { - if (s != finalState.state) { - val lastAsyncState = asyncStates.find(_.state == s).get - val succNums = lastAsyncState.nextStates - // all successor states that are not indirect predecessors - // filter out successor states where the field is live at the entry - var i = 0 - while (i < succNums.length) { - val num = succNums(i) - if (!isPred(num, s) && !LVentry(num).contains(fld.symbol)) - result += num - i += 1 - } - } - }}) - (fld, result) - } - - if(AsyncUtils.verbose) { - for ((fld, killAt) <- nullOutAt) - AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.iterator.mkString(", ")}") - } - - nullOutAt - } -} diff --git a/src/main/scala/scala/async/internal/ScalaConcurrentAsync.scala b/src/main/scala/scala/async/internal/ScalaConcurrentAsync.scala deleted file mode 100644 index 0b2b3711..00000000 --- a/src/main/scala/scala/async/internal/ScalaConcurrentAsync.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala -package async -package internal - -import scala.reflect.macros.whitebox -import scala.concurrent.Future - -object ScalaConcurrentAsync extends AsyncBase { - type FS = ScalaConcurrentFutureSystem.type - val futureSystem: FS = ScalaConcurrentFutureSystem - - override def asyncImpl[T: c.WeakTypeTag](c: whitebox.Context) - (body: c.Expr[T]) - (execContext: c.Expr[futureSystem.ExecContext]): c.Expr[Future[T]] = { - super.asyncImpl[T](c)(body)(execContext) - } -} diff --git a/src/main/scala/scala/async/internal/StateAssigner.scala b/src/main/scala/scala/async/internal/StateAssigner.scala deleted file mode 100644 index 5e6c45e7..00000000 --- a/src/main/scala/scala/async/internal/StateAssigner.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -private[async] final class StateAssigner { - private var current = StateAssigner.Initial - - def nextState(): Int = try current finally current += 1 -} - -object StateAssigner { - final val Initial = 0 -} diff --git a/src/main/scala/scala/async/internal/StateSet.scala b/src/main/scala/scala/async/internal/StateSet.scala deleted file mode 100644 index 7b7c8124..00000000 --- a/src/main/scala/scala/async/internal/StateSet.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import java.util -import java.util.function.{Consumer, IntConsumer} - -import scala.collection.JavaConverters.{asScalaIteratorConverter, iterableAsScalaIterableConverter} - -// Set for StateIds, which are either small positive integers or -symbolID. -final class StateSet { - private val bitSet = new java.util.BitSet() - private val caseSet = new util.HashSet[Integer]() - def +=(stateId: Int): Unit = if (storeInBitSet(stateId)) bitSet.set(stateId) else caseSet.add(stateId) - def contains(stateId: Int): Boolean = if (storeInBitSet(stateId)) bitSet.get(stateId) else caseSet.contains(stateId) - private def storeInBitSet(stateId: Int) = { - stateId > 0 && stateId < 1024 - } - def iterator: Iterator[Integer] = { - bitSet.stream().iterator().asScala ++ caseSet.asScala.iterator - } - def foreach(f: IntConsumer): Unit = { - bitSet.stream().forEach(f) - caseSet.stream().forEach(new Consumer[Integer] { - override def accept(value: Integer): Unit = f.accept(value) - }) - } -} diff --git a/src/main/scala/scala/async/internal/TransformUtils.scala b/src/main/scala/scala/async/internal/TransformUtils.scala deleted file mode 100644 index 1c1dd17a..00000000 --- a/src/main/scala/scala/async/internal/TransformUtils.scala +++ /dev/null @@ -1,590 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.internal - -import scala.collection.immutable.ListMap -import scala.collection.mutable -import scala.collection.mutable.ListBuffer - -/** - * Utilities used in both `ExprBuilder` and `AnfTransform`. - */ -private[async] trait TransformUtils { - self: AsyncMacro => - - import c.universe._ - import c.internal._ - import decorators._ - - object name extends asyncNames.AsyncName { - def fresh(name: TermName): TermName = freshenIfNeeded(name) - def fresh(name: String): String = c.freshName(name) - } - - def maybeTry(block: Tree, catches: List[CaseDef], finalizer: Tree) = if (asyncBase.futureSystem.emitTryCatch) Try(block, catches, finalizer) else block - - def isAsync(fun: Tree) = - fun.symbol == defn.Async_async - - def isAwait(fun: Tree) = - fun.symbol == defn.Async_await - - def newBlock(stats: List[Tree], expr: Tree): Block = { - Block(stats, expr) - } - - def isLiteralUnit(t: Tree) = t match { - case Literal(Constant(())) => - true - case _ => false - } - - def isPastTyper = - c.universe.asInstanceOf[scala.reflect.internal.SymbolTable].isPastTyper - - // Copy pasted from TreeInfo in the compiler. - // Using a quasiquote pattern like `case q"$fun[..$targs](...$args)" => is not - // sufficient since https://github.com/scala/scala/pull/3656 as it doesn't match - // constructor invocations. - class Applied(val tree: Tree) { - /** The tree stripped of the possibly nested applications. - * The original tree if it's not an application. - */ - def callee: Tree = { - def loop(tree: Tree): Tree = tree match { - case Apply(fn, _) => loop(fn) - case tree => tree - } - loop(tree) - } - - /** The `callee` unwrapped from type applications. - * The original `callee` if it's not a type application. - */ - def core: Tree = callee match { - case TypeApply(fn, _) => fn - case AppliedTypeTree(fn, _) => fn - case tree => tree - } - - /** The type arguments of the `callee`. - * `Nil` if the `callee` is not a type application. - */ - def targs: List[Tree] = callee match { - case TypeApply(_, args) => args - case AppliedTypeTree(_, args) => args - case _ => Nil - } - - /** (Possibly multiple lists of) value arguments of an application. - * `Nil` if the `callee` is not an application. - */ - def argss: List[List[Tree]] = { - def loop(tree: Tree): List[List[Tree]] = tree match { - case Apply(fn, args) => loop(fn) :+ args - case _ => Nil - } - loop(tree) - } - } - - /** Returns a wrapper that knows how to destructure and analyze applications. - */ - def dissectApplied(tree: Tree) = new Applied(tree) - - /** Destructures applications into important subparts described in `Applied` class, - * namely into: core, targs and argss (in the specified order). - * - * Trees which are not applications are also accepted. Their callee and core will - * be equal to the input, while targs and argss will be Nil. - * - * The provided extractors don't expose all the API of the `Applied` class. - * For advanced use, call `dissectApplied` explicitly and use its methods instead of pattern matching. - */ - object Applied { - def apply(tree: Tree): Applied = new Applied(tree) - - def unapply(applied: Applied): Option[(Tree, List[Tree], List[List[Tree]])] = - Some((applied.core, applied.targs, applied.argss)) - - def unapply(tree: Tree): Option[(Tree, List[Tree], List[List[Tree]])] = - unapply(dissectApplied(tree)) - } - private lazy val Boolean_ShortCircuits: Set[Symbol] = { - import definitions.BooleanClass - def BooleanTermMember(name: String) = BooleanClass.typeSignature.member(TermName(name).encodedName) - val Boolean_&& = BooleanTermMember("&&") - val Boolean_|| = BooleanTermMember("||") - Set(Boolean_&&, Boolean_||) - } - - private def isByName(fun: Tree): ((Int, Int) => Boolean) = { - if (Boolean_ShortCircuits contains fun.symbol) (i, j) => true - else if (fun.tpe == null) (x, y) => false - else { - val paramLists = fun.tpe.paramLists - val byNamess = paramLists.map(_.map(_.asTerm.isByNameParam)) - (i, j) => util.Try(byNamess(i)(j)).getOrElse(false) - } - } - private def argName(fun: Tree): ((Int, Int) => TermName) = { - val paramLists = fun.tpe.paramLists - val namess = paramLists.map(_.map(_.name.toTermName)) - (i, j) => util.Try(namess(i)(j)).getOrElse(TermName(s"arg_${i}_${j}")) - } - - object defn { - def mkList_apply[A](args: List[Expr[A]]): Expr[List[A]] = { - c.Expr(Apply(Ident(definitions.List_apply), args.map(_.tree))) - } - - def mkList_contains[A](self: Expr[List[A]])(elem: Expr[Any]) = reify { - self.splice.contains(elem.splice) - } - - def mkAny_==(self: Expr[Any])(other: Expr[Any]) = reify { - self.splice == other.splice - } - - def mkTry_get[A](self: Expr[util.Try[A]]) = reify { - self.splice.get - } - - val NonFatalClass = rootMirror.staticModule("scala.util.control.NonFatal") - val ThrowableClass = rootMirror.staticClass("java.lang.Throwable") - lazy val Async_async = asyncBase.asyncMethod(c.universe)(c.macroApplication.symbol) - lazy val Async_await = asyncBase.awaitMethod(c.universe)(c.macroApplication.symbol) - val IllegalStateExceptionClass = rootMirror.staticClass("java.lang.IllegalStateException") - } - - // `while(await(x))` ... or `do { await(x); ... } while(...)` contain an `If` that loops; - // we must break that `If` into states so that it convert the label jump into a state machine - // transition - final def containsForiegnLabelJump(t: Tree): Boolean = { - val labelDefs = t.collect { - case ld: LabelDef => ld.symbol - }.toSet - val result = t.exists { - case rt: RefTree => rt.symbol != null && isLabel(rt.symbol) && !(labelDefs contains rt.symbol) - case _ => false - } - result - } - - def isLabel(sym: Symbol): Boolean = { - val LABEL = 1L << 17 // not in the public reflection API. - (internal.flags(sym).asInstanceOf[Long] & LABEL) != 0L - } - def isSynth(sym: Symbol): Boolean = { - val SYNTHETIC = 1 << 21 // not in the public reflection API. - (internal.flags(sym).asInstanceOf[Long] & SYNTHETIC) != 0L - } - def symId(sym: Symbol): Int = { - val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable] - sym.asInstanceOf[symtab.Symbol].id - } - def substituteTrees(t: Tree, from: List[Symbol], to: List[Tree]): Tree = { - val symtab = this.c.universe.asInstanceOf[reflect.internal.SymbolTable] - val subst = new symtab.TreeSubstituter(from.asInstanceOf[List[symtab.Symbol]], to.asInstanceOf[List[symtab.Tree]]) - subst.transform(t.asInstanceOf[symtab.Tree]).asInstanceOf[Tree] - } - - - /** Map a list of arguments to: - * - A list of argument Trees - * - A list of auxillary results. - * - * The function unwraps and rewraps the `arg :_*` construct. - * - * @param args The original argument trees - * @param f A function from argument (with '_*' unwrapped) and argument index to argument. - * @tparam A The type of the auxillary result - */ - private def mapArguments[A](args: List[Tree])(f: (Tree, Int) => (A, Tree)): (List[A], List[Tree]) = { - args match { - case args :+ Typed(tree, Ident(typeNames.WILDCARD_STAR)) => - val (a, argExprs :+ lastArgExpr) = (args :+ tree).zipWithIndex.map(f.tupled).unzip - val exprs = argExprs :+ atPos(lastArgExpr.pos.makeTransparent)(Typed(lastArgExpr, Ident(typeNames.WILDCARD_STAR))) - (a, exprs) - case args => - args.zipWithIndex.map(f.tupled).unzip - } - } - - case class Arg(expr: Tree, isByName: Boolean, argName: TermName) - - /** - * Transform a list of argument lists, producing the transformed lists, and lists of auxillary - * results. - * - * The function `f` need not concern itself with varargs arguments e.g (`xs : _*`). It will - * receive `xs`, and it's result will be re-wrapped as `f(xs) : _*`. - * - * @param fun The function being applied - * @param argss The argument lists - * @return (auxillary results, mapped argument trees) - */ - def mapArgumentss[A](fun: Tree, argss: List[List[Tree]])(f: Arg => (A, Tree)): (List[List[A]], List[List[Tree]]) = { - val isByNamess: (Int, Int) => Boolean = isByName(fun) - val argNamess: (Int, Int) => TermName = argName(fun) - argss.zipWithIndex.map { case (args, i) => - mapArguments[A](args) { - (tree, j) => f(Arg(tree, isByNamess(i, j), argNamess(i, j))) - } - }.unzip - } - - - def statsAndExpr(tree: Tree): (List[Tree], Tree) = tree match { - case Block(stats, expr) => (stats, expr) - case _ => (List(tree), Literal(Constant(()))) - } - - def emptyConstructor: DefDef = { - val emptySuperCall = Apply(Select(Super(This(typeNames.EMPTY), typeNames.EMPTY), termNames.CONSTRUCTOR), Nil) - DefDef(NoMods, termNames.CONSTRUCTOR, List(), List(List()), TypeTree(), Block(List(emptySuperCall), Literal(Constant(())))) - } - - def applied(className: String, types: List[Type]): AppliedTypeTree = - AppliedTypeTree(Ident(rootMirror.staticClass(className)), types.map(TypeTree(_))) - - /** Descends into the regions of the tree that are subject to the - * translation to a state machine by `async`. When a nested template, - * function, or by-name argument is encountered, the descent stops, - * and `nestedClass` etc are invoked. - */ - trait AsyncTraverser extends Traverser { - def nestedClass(classDef: ClassDef): Unit = { - } - - def nestedModule(module: ModuleDef): Unit = { - } - - def nestedMethod(defdef: DefDef): Unit = { - } - - def byNameArgument(arg: Tree): Unit = { - } - - def function(function: Function): Unit = { - } - - def patMatFunction(tree: Match): Unit = { - } - - override def traverse(tree: Tree): Unit = { - tree match { - case _ if isAsync(tree) => - // Under -Ymacro-expand:discard, used in the IDE, nested async blocks will be visible to the outer blocks - case cd: ClassDef => nestedClass(cd) - case md: ModuleDef => nestedModule(md) - case dd: DefDef => nestedMethod(dd) - case fun: Function => function(fun) - case m@Match(EmptyTree, _) => patMatFunction(m) // Pattern matching anonymous function under -Xoldpatmat of after `restorePatternMatchingFunctions` - case q"$fun[..$targs](...$argss)" if argss.nonEmpty => - val isInByName = isByName(fun) - for ((args, i) <- argss.zipWithIndex) { - for ((arg, j) <- args.zipWithIndex) { - if (!isInByName(i, j)) traverse(arg) - else byNameArgument(arg) - } - } - traverse(fun) - case _ => super.traverse(tree) - } - } - } - - def transformAt(tree: Tree)(f: PartialFunction[Tree, (TypingTransformApi => Tree)]) = { - typingTransform(tree)((tree, api) => { - if (f.isDefinedAt(tree)) f(tree)(api) - else api.default(tree) - }) - } - - def toMultiMap[A, B](abs: Iterable[(A, B)]): mutable.LinkedHashMap[A, List[B]] = { - // LinkedHashMap for stable order of results. - val result = new mutable.LinkedHashMap[A, ListBuffer[B]]() - for ((a, b) <- abs) { - val buffer = result.getOrElseUpdate(a, new ListBuffer[B]) - buffer += b - } - result.map { case (a, b) => (a, b.toList) } - } - - // Attributed version of `TreeGen#mkCastPreservingAnnotations` - def mkAttributedCastPreservingAnnotations(tree: Tree, tp: Type): Tree = { - atPos(tree.pos) { - val casted = c.typecheck(gen.mkCast(tree, uncheckedBounds(withoutAnnotations(tp)).dealias)) - Typed(casted, TypeTree(tp)).setType(tp) - } - } - - def deconst(tp: Type): Type = tp match { - case AnnotatedType(anns, underlying) => annotatedType(anns, deconst(underlying)) - case ExistentialType(quants, underlying) => existentialType(quants, deconst(underlying)) - case ConstantType(value) => deconst(value.tpe) - case _ => tp - } - - def withAnnotation(tp: Type, ann: Annotation): Type = withAnnotations(tp, List(ann)) - - def withAnnotations(tp: Type, anns: List[Annotation]): Type = tp match { - case AnnotatedType(existingAnns, underlying) => annotatedType(anns ::: existingAnns, underlying) - case ExistentialType(quants, underlying) => existentialType(quants, withAnnotations(underlying, anns)) - case _ => annotatedType(anns, tp) - } - - def withoutAnnotations(tp: Type): Type = tp match { - case AnnotatedType(anns, underlying) => withoutAnnotations(underlying) - case ExistentialType(quants, underlying) => existentialType(quants, withoutAnnotations(underlying)) - case _ => tp - } - - def tpe(sym: Symbol): Type = { - if (sym.isType) sym.asType.toType - else sym.info - } - - def thisType(sym: Symbol): Type = { - if (sym.isClass) sym.asClass.thisPrefix - else NoPrefix - } - - private def derivedValueClassUnbox(cls: Symbol) = - (cls.info.decls.find(sym => sym.isMethod && sym.asTerm.isParamAccessor) getOrElse NoSymbol) - - def mkZero(tp: Type): Tree = { - val tpSym = tp.typeSymbol - if (tpSym.isClass && tpSym.asClass.isDerivedValueClass) { - val argZero = mkZero(derivedValueClassUnbox(tpSym).infoIn(tp).resultType) - val baseType = tp.baseType(tpSym) // use base type here to dealias / strip phantom "tagged types" etc. - - // By explicitly attributing the types and symbols here, we subvert privacy. - // Otherwise, ticket86PrivateValueClass would fail. - - // Approximately: - // q"new ${valueClass}[$..targs](argZero)" - val target: Tree = gen.mkAttributedSelect( - c.typecheck(atMacroPos( - New(TypeTree(baseType)))), tpSym.asClass.primaryConstructor) - - val zero = gen.mkMethodCall(target, argZero :: Nil) - // restore the original type which we might otherwise have weakened with `baseType` above - c.typecheck(atMacroPos(gen.mkCast(zero, tp))) - } else { - gen.mkZero(tp) - } - } - - // ===================================== - // Copy/Pasted from Scala 2.10.3. See scala/bug#7694 - private lazy val UncheckedBoundsClass = - c.mirror.staticClass("scala.reflect.internal.annotations.uncheckedBounds") - final def uncheckedBounds(tp: Type): Type = - if ((tp.typeArgs.isEmpty && (tp match { case _: TypeRef => true; case _ => false}))) tp - else withAnnotation(tp, Annotation(UncheckedBoundsClass.asType.toType, Nil, ListMap())) - // ===================================== - - /** - * Efficiently decorate each subtree within `t` with the result of `t exists isAwait`, - * and return a function that can be used on derived trees to efficiently test the - * same condition. - * - * If the derived tree contains synthetic wrapper trees, these will be recursed into - * in search of a sub tree that was decorated with the cached answer. - */ - final def containsAwaitCached(t: Tree): Tree => Boolean = { - if (c.macroApplication.symbol == null) return (t => false) - - def treeCannotContainAwait(t: Tree) = t match { - case _: Ident | _: TypeTree | _: Literal => true - case _ => isAsync(t) - } - def shouldAttach(t: Tree) = !treeCannotContainAwait(t) - val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] - def attachContainsAwait(t: Tree): Unit = if (shouldAttach(t)) { - val t1 = t.asInstanceOf[symtab.Tree] - t1.updateAttachment(ContainsAwait) - t1.removeAttachment[NoAwait.type] - } - def attachNoAwait(t: Tree): Unit = if (shouldAttach(t)) { - val t1 = t.asInstanceOf[symtab.Tree] - t1.updateAttachment(NoAwait) - } - object markContainsAwaitTraverser extends Traverser { - var stack: List[Tree] = Nil - - override def traverse(tree: Tree): Unit = { - stack ::= tree - try { - if (isAsync(tree)) { - ; - } else { - if (isAwait(tree)) - stack.foreach(attachContainsAwait) - else - attachNoAwait(tree) - super.traverse(tree) - } - } finally stack = stack.tail - } - } - markContainsAwaitTraverser.traverse(t) - - (t: Tree) => { - object traverser extends Traverser { - var containsAwait = false - override def traverse(tree: Tree): Unit = { - def castTree = tree.asInstanceOf[symtab.Tree] - if (!castTree.hasAttachment[NoAwait.type]) { - if (castTree.hasAttachment[ContainsAwait.type]) - containsAwait = true - else if (!treeCannotContainAwait(t)) - super.traverse(tree) - } - } - } - traverser.traverse(t) - traverser.containsAwait - } - } - - final def cleanupContainsAwaitAttachments(t: Tree): t.type = { - val symtab = c.universe.asInstanceOf[scala.reflect.internal.SymbolTable] - t.foreach {t => - t.asInstanceOf[symtab.Tree].removeAttachment[ContainsAwait.type] - t.asInstanceOf[symtab.Tree].removeAttachment[NoAwait.type] - } - t - } - - // First modification to translated patterns: - // - Set the type of label jumps to `Unit` - // - Propagate this change to trees known to directly enclose them: - // ``If` / `Block`) adjust types of enclosing - final def adjustTypeOfTranslatedPatternMatches(t: Tree, owner: Symbol): Tree = { - import definitions.UnitTpe - typingTransform(t, owner) { - (tree, api) => - tree match { - case LabelDef(name, params, rhs) => - val rhs1 = api.recur(rhs) - if (rhs1.tpe =:= UnitTpe) { - internal.setInfo(tree.symbol, internal.methodType(tree.symbol.info.paramLists.head, UnitTpe)) - treeCopy.LabelDef(tree, name, params, rhs1) - } else { - treeCopy.LabelDef(tree, name, params, rhs1) - } - case Block(stats, expr) => - val stats1 = stats map api.recur - val expr1 = api.recur(expr) - if (expr1.tpe =:= UnitTpe) - internal.setType(treeCopy.Block(tree, stats1, expr1), UnitTpe) - else - treeCopy.Block(tree, stats1, expr1) - case If(cond, thenp, elsep) => - val cond1 = api.recur(cond) - val thenp1 = api.recur(thenp) - val elsep1 = api.recur(elsep) - if (thenp1.tpe =:= definitions.UnitTpe && elsep.tpe =:= UnitTpe) - internal.setType(treeCopy.If(tree, cond1, thenp1, elsep1), UnitTpe) - else - treeCopy.If(tree, cond1, thenp1, elsep1) - case Apply(fun, args) if isLabel(fun.symbol) => - internal.setType(treeCopy.Apply(tree, api.recur(fun), args map api.recur), UnitTpe) - case vd @ ValDef(mods, name, tpt, rhs) if isCaseTempVal(vd.symbol) => - def addUncheckedBounds(t: Tree) = { - typingTransform(t, owner) { - (tree, api) => - if (tree.tpe == null) tree else internal.setType(api.default(tree), uncheckedBoundsIfNeeded(tree.tpe)) - } - - } - val uncheckedRhs = addUncheckedBounds(api.recur(rhs)) - val uncheckedTpt = addUncheckedBounds(tpt) - internal.setInfo(vd.symbol, uncheckedBoundsIfNeeded(vd.symbol.info)) - treeCopy.ValDef(vd, mods, name, uncheckedTpt, uncheckedRhs) - case t => api.default(t) - } - } - } - - private def isExistentialSkolem(s: Symbol) = { - val EXISTENTIAL: Long = 1L << 35 - internal.isSkolem(s) && (internal.flags(s).asInstanceOf[Long] & EXISTENTIAL) != 0 - } - private def isCaseTempVal(s: Symbol) = { - s.isTerm && s.asTerm.isVal && s.isSynthetic && s.name.toString.startsWith("x") - } - - def uncheckedBoundsIfNeeded(t: Type): Type = { - var quantified: List[Symbol] = Nil - var badSkolemRefs: List[Symbol] = Nil - t.foreach { - case et: ExistentialType => - quantified :::= et.quantified - case TypeRef(pre, sym, args) => - val illScopedSkolems = args.map(_.typeSymbol).filter(arg => isExistentialSkolem(arg) && !quantified.contains(arg)) - badSkolemRefs :::= illScopedSkolems - case _ => - } - if (badSkolemRefs.isEmpty) t - else t.map { - case tp @ TypeRef(pre, sym, args) if args.exists(a => badSkolemRefs.contains(a.typeSymbol)) => - uncheckedBounds(tp) - case t => t - } - } - - - final def mkMutableField(tpt: Type, name: TermName, init: Tree): List[Tree] = { - if (isPastTyper) { - // If we are running after the typer phase (ie being called from a compiler plugin) - // we have to create the trio of members manually. - val ACCESSOR = (1L << 27).asInstanceOf[FlagSet] - val STABLE = (1L << 22).asInstanceOf[FlagSet] - val field = ValDef(Modifiers(Flag.MUTABLE | Flag.PRIVATE | Flag.LOCAL), TermName(name.toString + " "), TypeTree(tpt), init) - val getter = DefDef(Modifiers(ACCESSOR | STABLE), name, Nil, Nil, TypeTree(tpt), Select(This(typeNames.EMPTY), field.name)) - val setter = DefDef(Modifiers(ACCESSOR), TermName(name.toString + "_="), Nil, List(List(ValDef(NoMods, TermName("x"), TypeTree(tpt), EmptyTree))), TypeTree(definitions.UnitTpe), Assign(Select(This(typeNames.EMPTY), field.name), Ident(TermName("x")))) - field :: getter :: setter :: Nil - } else { - val result = ValDef(NoMods, name, TypeTree(tpt), init) - result :: Nil - } - } - - def deriveLabelDef(ld: LabelDef, applyToRhs: Tree => Tree): LabelDef = { - val rhs2 = applyToRhs(ld.rhs) - val ld2 = treeCopy.LabelDef(ld, ld.name, ld.params, rhs2) - if (ld eq ld2) ld - else { - val info2 = ld2.symbol.info match { - case MethodType(params, p) => internal.methodType(params, rhs2.tpe) - case t => t - } - internal.setInfo(ld2.symbol, info2) - ld2 - } - } - object MatchEnd { - def unapply(t: Tree): Option[LabelDef] = t match { - case ValDef(_, _, _, t) => unapply(t) - case ld: LabelDef if ld.name.toString.startsWith("matchEnd") => Some(ld) - case _ => None - } - } -} - -case object ContainsAwait -case object NoAwait diff --git a/src/test/scala/scala/async/FutureSpec.scala b/src/test/scala/scala/async/FutureSpec.scala new file mode 100644 index 00000000..d692f621 --- /dev/null +++ b/src/test/scala/scala/async/FutureSpec.scala @@ -0,0 +1,541 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ +package scala.async + +import java.util.concurrent.ConcurrentHashMap + +import org.junit.Test + +import scala.async.Async.{async, await} +import scala.async.TestUtil._ +import scala.concurrent.duration.Duration.Inf +import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future, Promise, _} +import scala.language.postfixOps +import scala.util.{Failure, Success} + +class FutureSpec { + + def testAsync(s: String)(implicit ec: ExecutionContext): Future[String] = s match { + case "Hello" => Future { "World" } + case "Failure" => Future.failed(new RuntimeException("Expected exception; to test fault-tolerance")) + case "NoReply" => Promise[String]().future + } + + val defaultTimeout = 5 seconds + + /* future specification */ + + @Test def `A future with custom ExecutionContext should handle Throwables`(): Unit = { + val ms = new ConcurrentHashMap[Throwable, Unit] + implicit val ec = scala.concurrent.ExecutionContext.fromExecutor(new java.util.concurrent.ForkJoinPool(), { + t => + ms.put(t, ()) + }) + + class ThrowableTest(m: String) extends Throwable(m) + + val f1 = Future[Any] { + throw new ThrowableTest("test") + } + + intercept[ThrowableTest] { + Await.result(f1, defaultTimeout) + } + + val latch = new TestLatch + val f2 = Future { + Await.ready(latch, 5 seconds) + "success" + } + val f3 = async { + val s = await(f2) + s.toUpperCase + } + + f2 foreach { _ => throw new ThrowableTest("dispatcher foreach") } + f2 onComplete { case Success(_) => throw new ThrowableTest("dispatcher receive") case _ => } + + latch.open() + + Await.result(f2, defaultTimeout) mustBe ("success") + + f2 foreach { _ => throw new ThrowableTest("current thread foreach") } + f2 onComplete { case Success(_) => throw new ThrowableTest("current thread receive"); case _ => } + + Await.result(f3, defaultTimeout) mustBe ("SUCCESS") + + val waiting = Future { + Thread.sleep(1000) + } + Await.ready(waiting, 2000 millis) + + ms.size mustBe (4) + } + + import ExecutionContext.Implicits._ + + @Test def `A future with global ExecutionContext should compose with for-comprehensions`(): Unit = { + + def asyncInt(x: Int) = Future { (x * 2).toString } + val future0 = Future[Any] { + "five!".length + } + + val future1 = async { + val a = await(future0.mapTo[Int]) // returns 5 + val b = await(asyncInt(a)) // returns "10" + val c = await(asyncInt(7)) // returns "14" + b + "-" + c + } + + val future2 = async { + val a = await(future0.mapTo[Int]) + val b = await((Future { (a * 2).toString }).mapTo[Int]) + val c = await(Future { (7 * 2).toString }) + b + "-" + c + } + + Await.result(future1, defaultTimeout) mustBe ("10-14") + //assert(checkType(future1, manifest[String])) + intercept[ClassCastException] { Await.result(future2, defaultTimeout) } + } + + //TODO this is not yet supported by Async + @Test def `support pattern matching within a for-comprehension`(): Unit = { + case class Req[T](req: T) + case class Res[T](res: T) + def asyncReq[T](req: Req[T]) = (req: @unchecked) match { + case Req(s: String) => Future { Res(s.length) } + case Req(i: Int) => Future { Res((i * 2).toString) } + } + + val future1 = for { + Res(a: Int) <- asyncReq(Req("Hello")) + Res(b: String) <- asyncReq(Req(a)) + Res(c: String) <- asyncReq(Req(7)) + } yield b + "-" + c + + val future2 = for { + Res(a: Int) <- asyncReq(Req("Hello")) + Res(b: Int) <- asyncReq(Req(a)) + Res(c: Int) <- asyncReq(Req(7)) + } yield b + "-" + c + + Await.result(future1, defaultTimeout) mustBe ("10-14") + intercept[NoSuchElementException] { Await.result(future2, defaultTimeout) } + } + + @Test def mini(): Unit = { + val future4 = async { + await(Future.successful(0)).toString + } + Await.result(future4, defaultTimeout) + } + + @Test def `recover from exceptions`(): Unit = { + val future1 = Future(5) + val future2 = async { await(future1) / 0 } + val future3 = async { await(future2).toString } + + val future1Recovered = future1 recover { + case e: ArithmeticException => 0 + } + val future4 = async { await(future1Recovered).toString } + + val future2Recovered = future2 recover { + case e: ArithmeticException => 0 + } + val future5 = async { await(future2Recovered).toString } + + val future2Recovered2 = future2 recover { + case e: MatchError => 0 + } + val future6 = async { await(future2Recovered2).toString } + + val future7 = future3 recover { + case e: ArithmeticException => "You got ERROR" + } + + val future8 = testAsync("Failure") + val future9 = testAsync("Failure") recover { + case e: RuntimeException => "FAIL!" + } + val future10 = testAsync("Hello") recover { + case e: RuntimeException => "FAIL!" + } + val future11 = testAsync("Failure") recover { + case _ => "Oops!" + } + + Await.result(future1, defaultTimeout) mustBe (5) + intercept[ArithmeticException] { Await.result(future2, defaultTimeout) } + intercept[ArithmeticException] { Await.result(future3, defaultTimeout) } + Await.result(future4, defaultTimeout) mustBe ("5") + Await.result(future5, defaultTimeout) mustBe ("0") + intercept[ArithmeticException] { Await.result(future6, defaultTimeout) } + Await.result(future7, defaultTimeout) mustBe ("You got ERROR") + intercept[RuntimeException] { Await.result(future8, defaultTimeout) } + Await.result(future9, defaultTimeout) mustBe ("FAIL!") + Await.result(future10, defaultTimeout) mustBe ("World") + Await.result(future11, defaultTimeout) mustBe ("Oops!") + } + + @Test def `recoverWith from exceptions`(): Unit = { + val o = new IllegalStateException("original") + val r = new IllegalStateException("recovered") + + intercept[IllegalStateException] { + val failed = Future.failed[String](o) recoverWith { + case _ if false == true => Future.successful("yay!") + } + Await.result(failed, defaultTimeout) + } mustBe (o) + + val recovered = Future.failed[String](o) recoverWith { + case _ => Future.successful("yay!") + } + Await.result(recovered, defaultTimeout) mustBe ("yay!") + + intercept[IllegalStateException] { + val refailed = Future.failed[String](o) recoverWith { + case _ => Future.failed[String](r) + } + Await.result(refailed, defaultTimeout) + } mustBe (r) + } + + @Test def `andThen like a boss`(): Unit = { + val q = new java.util.concurrent.LinkedBlockingQueue[Int] + for (i <- 1 to 1000) { + val chained = Future { + q.add(1); 3 + } andThen { + case _ => q.add(2) + } andThen { + case Success(0) => q.add(Int.MaxValue) + } andThen { + case _ => q.add(3); + } + Await.result(chained, defaultTimeout) mustBe (3) + q.poll() mustBe (1) + q.poll() mustBe (2) + q.poll() mustBe (3) + q.clear() + } + } + + @Test def `firstCompletedOf`(): Unit = { + def futures = Vector.fill[Future[Int]](10) { + Promise[Int]().future + } :+ Future.successful[Int](5) + + Await.result(Future.firstCompletedOf(futures), defaultTimeout) mustBe (5) + Await.result(Future.firstCompletedOf(futures.iterator), defaultTimeout) mustBe (5) + } + + @Test def `find`(): Unit = { + val futures = for (i <- 1 to 10) yield Future { + i + } + + val result = Future.find[Int](futures)(_ == 3) + Await.result(result, defaultTimeout) mustBe (Some(3)) + + val notFound = Future.find[Int](futures)(_ == 11) + Await.result(notFound, defaultTimeout) mustBe (None) + } + + @Test def `zip`(): Unit = { + val timeout = 10000 millis + val f = new IllegalStateException("test") + intercept[IllegalStateException] { + val failed = Future.failed[String](f) zip Future.successful("foo") + Await.result(failed, timeout) + } mustBe (f) + + intercept[IllegalStateException] { + val failed = Future.successful("foo") zip Future.failed[String](f) + Await.result(failed, timeout) + } mustBe (f) + + intercept[IllegalStateException] { + val failed = Future.failed[String](f) zip Future.failed[String](f) + Await.result(failed, timeout) + } mustBe (f) + + val successful = Future.successful("foo") zip Future.successful("foo") + Await.result(successful, timeout) mustBe (("foo", "foo")) + } + + @Test def `fold`(): Unit = { + val timeout = 10000 millis + def async(add: Int, wait: Int) = Future { + Thread.sleep(wait) + add + } + + val futures = (0 to 9) map { + idx => async(idx, idx * 20) + } + val folded = Future.foldLeft(futures)(0)(_ + _) + Await.result(folded, timeout) mustBe (45) + + val futuresit = (0 to 9) map { + idx => async(idx, idx * 20) + } + val foldedit = Future.foldLeft(futures)(0)(_ + _) + Await.result(foldedit, timeout) mustBe (45) + } + + @Test def `fold by composing`(): Unit = { + val timeout = 10000 millis + def async(add: Int, wait: Int) = Future { + Thread.sleep(wait) + add + } + def futures = (0 to 9) map { + idx => async(idx, idx * 20) + } + val folded = futures.foldLeft(Future(0)) { + case (fr, fa) => for (r <- fr; a <- fa) yield (r + a) + } + Await.result(folded, timeout) mustBe (45) + } + + @Test def `fold with an exception`(): Unit = { + val timeout = 10000 millis + def async(add: Int, wait: Int) = Future { + Thread.sleep(wait) + if (add == 6) throw new IllegalArgumentException("shouldFoldResultsWithException: expected") + add + } + def futures = (0 to 9) map { + idx => async(idx, idx * 10) + } + val folded = Future.foldLeft(futures)(0)(_ + _) + intercept[IllegalArgumentException] { + Await.result(folded, timeout) + }.getMessage mustBe ("shouldFoldResultsWithException: expected") + } + + @Test def `fold mutable zeroes safely`(): Unit = { + import scala.collection.mutable.ArrayBuffer + def test(testNumber: Int): Unit = { + val fs = (0 to 1000) map (i => Future(i)) + val f = Future.foldLeft(fs)(ArrayBuffer.empty[AnyRef]) { + case (l, i) if i % 2 == 0 => l += i.asInstanceOf[AnyRef] + case (l, _) => l + } + val result = Await.result(f.mapTo[ArrayBuffer[Int]], 10000 millis).sum + + assert(result == 250500) + } + + (1 to 100) foreach test //Make sure it tries to provoke the problem + } + + @Test def `return zero value if folding empty list`(): Unit = { + val zero = Future.foldLeft(List[Future[Int]]())(0)(_ + _) + Await.result(zero, defaultTimeout) mustBe (0) + } + + @Test def `shouldReduceResults`(): Unit = { + def async(idx: Int) = Future { + Thread.sleep(idx * 20) + idx + } + val timeout = 10000 millis + + val futures = (0 to 9) map { async } + val reduced = Future.reduceLeft(futures)(_ + _) + Await.result(reduced, timeout) mustBe (45) + + val futuresit = (0 to 9) map { async } + val reducedit = Future.reduceLeft(futuresit)(_ + _) + Await.result(reducedit, timeout) mustBe (45) + } + + @Test def `shouldReduceResultsWithException`(): Unit = { + def async(add: Int, wait: Int) = Future { + Thread.sleep(wait) + if (add == 6) throw new IllegalArgumentException("shouldFoldResultsWithException: expected") + else add + } + val timeout = 10000 millis + def futures = (1 to 10) map { + idx => async(idx, idx * 10) + } + val failed = Future.reduceLeft(futures)(_ + _) + intercept[IllegalArgumentException] { + Await.result(failed, timeout) + }.getMessage mustBe ("shouldFoldResultsWithException: expected") + } + + @Test def `shouldReduceThrowNSEEOnEmptyInput`(): Unit = { + intercept[java.util.NoSuchElementException] { + val emptyreduced = Future.reduceLeft(List[Future[Int]]())(_ + _) + Await.result(emptyreduced, defaultTimeout) + } + } + + @Test def `shouldTraverseFutures`(): Unit = { + object counter { + var count = -1 + def incAndGet() = counter.synchronized { + count += 2 + count + } + } + + val oddFutures = List.fill(100)(Future { counter.incAndGet() }).iterator + val traversed = Future.sequence(oddFutures) + Await.result(traversed, defaultTimeout).sum mustBe (10000) + + val list = (1 to 100).toList + val traversedList = Future.traverse(list)(x => Future(x * 2 - 1)) + Await.result(traversedList, defaultTimeout).sum mustBe (10000) + + val iterator = (1 to 100).toList.iterator + val traversedIterator = Future.traverse(iterator)(x => Future(x * 2 - 1)) + Await.result(traversedIterator, defaultTimeout).sum mustBe (10000) + } + + @Test def `shouldBlockUntilResult`(): Unit = { + val latch = new TestLatch + + val f = Future { + Await.ready(latch, 5 seconds) + 5 + } + val f2 = Future { + val res = Await.result(f, Inf) + res + 9 + } + + intercept[TimeoutException] { + Await.ready(f2, 100 millis) + } + + latch.open() + + Await.result(f2, defaultTimeout) mustBe (14) + + val f3 = Future { + Thread.sleep(100) + 5 + } + + intercept[TimeoutException] { + Await.ready(f3, 0 millis) + } + } + + @Test def `run callbacks async`(): Unit = { + val latch = Vector.fill(10)(new TestLatch) + + val f1 = Future { + latch(0).open() + Await.ready(latch(1), TestLatch.DefaultTimeout) + "Hello" + } + val f2 = async { + val s = await(f1) + latch(2).open() + Await.ready(latch(3), TestLatch.DefaultTimeout) + s.length + } + for (_ <- f2) latch(4).open() + + Await.ready(latch(0), TestLatch.DefaultTimeout) + + f1.isCompleted mustBe (false) + f2.isCompleted mustBe (false) + + latch(1).open() + Await.ready(latch(2), TestLatch.DefaultTimeout) + + f1.isCompleted mustBe (true) + f2.isCompleted mustBe (false) + + val f3 = async { + val s = await(f1) + latch(5).open() + Await.ready(latch(6), TestLatch.DefaultTimeout) + s.length * 2 + } + for (_ <- f3) latch(3).open() + + Await.ready(latch(5), TestLatch.DefaultTimeout) + + f3.isCompleted mustBe (false) + + latch(6).open() + Await.ready(latch(4), TestLatch.DefaultTimeout) + + f2.isCompleted mustBe (true) + f3.isCompleted mustBe (true) + + val p1 = Promise[String]() + val f4 = async { + val s = await(p1.future) + latch(7).open() + Await.ready(latch(8), TestLatch.DefaultTimeout) + s.length + } + for (_ <- f4) latch(9).open() + + p1.future.isCompleted mustBe (false) + f4.isCompleted mustBe (false) + + p1 complete Success("Hello") + + Await.ready(latch(7), TestLatch.DefaultTimeout) + + p1.future.isCompleted mustBe (true) + f4.isCompleted mustBe (false) + + latch(8).open() + Await.ready(latch(9), TestLatch.DefaultTimeout) + + Await.ready(f4, defaultTimeout).isCompleted mustBe (true) + } + + @Test def `should not deadlock with nested await (ticket 1313)`(): Unit = { + val simple = async { + await { Future { } } + val unit = Future(()) + val umap = unit map { _ => () } + Await.result(umap, Inf) + } + Await.ready(simple, Inf).isCompleted mustBe (true) + + val l1, l2 = new TestLatch + val complex = async { + await{ Future { } } + blocking { + val nested = Future(()) + for (_ <- nested) l1.open() + Await.ready(l1, TestLatch.DefaultTimeout) // make sure nested is completed + for (_ <- nested) l2.open() + Await.ready(l2, TestLatch.DefaultTimeout) + } + } + Await.ready(complex, defaultTimeout).isCompleted mustBe (true) + } + + @Test def `should not throw when Await.ready`(): Unit = { + val expected = try Success(5 / 0) catch { case a: ArithmeticException => Failure(a) } + val f = async { await(Future(5)) / 0 } + Await.ready(f, defaultTimeout).value.get.toString mustBe expected.toString + } +} diff --git a/src/test/scala/scala/async/SmokeTest.scala b/src/test/scala/scala/async/SmokeTest.scala new file mode 100644 index 00000000..204481d1 --- /dev/null +++ b/src/test/scala/scala/async/SmokeTest.scala @@ -0,0 +1,32 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ +package scala.async + +import org.junit.{Assert, Test} + +import scala.async.Async._ +import scala.concurrent._ +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future.{successful => f} +import scala.concurrent.duration.Duration + +class SmokeTest { + def block[T](f: Future[T]): T = Await.result(f, Duration.Inf) + + @Test def testBasic(): Unit = { + val result = async { + await(f(1)) + await(f(2)) + } + Assert.assertEquals(3, block(result)) + } + +} diff --git a/src/test/scala/scala/async/TestLatch.scala b/src/test/scala/scala/async/TestLatch.scala deleted file mode 100644 index 011a8323..00000000 --- a/src/test/scala/scala/async/TestLatch.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async - -import concurrent.{CanAwait, Awaitable} -import concurrent.duration.Duration -import java.util.concurrent.{TimeoutException, CountDownLatch, TimeUnit} - -object TestLatch { - val DefaultTimeout = Duration(5, TimeUnit.SECONDS) - - def apply(count: Int = 1) = new TestLatch(count) -} - - -class TestLatch(count: Int = 1) extends Awaitable[Unit] { - private var latch = new CountDownLatch(count) - - def countDown() = latch.countDown() - - def isOpen: Boolean = latch.getCount == 0 - - def open() = while (!isOpen) countDown() - - def reset() = latch = new CountDownLatch(count) - - @throws(classOf[TimeoutException]) - def ready(atMost: Duration)(implicit permit: CanAwait) = { - val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS) - if (!opened) throw new TimeoutException(s"Timeout of ${(atMost.toString)}.") - this - } - - @throws(classOf[Exception]) - def result(atMost: Duration)(implicit permit: CanAwait): Unit = { - ready(atMost) - } -} diff --git a/src/test/scala/scala/async/TestUtil.scala b/src/test/scala/scala/async/TestUtil.scala new file mode 100644 index 00000000..ac44de96 --- /dev/null +++ b/src/test/scala/scala/async/TestUtil.scala @@ -0,0 +1,66 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ +package scala.async + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import scala.concurrent.{Awaitable, CanAwait, TimeoutException} +import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.reflect.{ClassTag, classTag} + +object TestUtil { + object TestLatch { + val DefaultTimeout: FiniteDuration = Duration(5, TimeUnit.SECONDS) + + def apply(count: Int = 1) = new TestLatch(count) + } + + class TestLatch(count: Int = 1) extends Awaitable[Unit] { + private var latch = new CountDownLatch(count) + + def countDown(): Unit = latch.countDown() + + def isOpen: Boolean = latch.getCount == 0 + + def open(): Unit = while (!isOpen) countDown() + + def reset(): Unit = latch = new CountDownLatch(count) + + @throws(classOf[TimeoutException]) + def ready(atMost: Duration)(implicit permit: CanAwait): TestLatch.this.type = { + val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS) + if (!opened) throw new TimeoutException(s"Timeout of ${(atMost.toString)}.") + this + } + + @throws(classOf[Exception]) + def result(atMost: Duration)(implicit permit: CanAwait): Unit = { + ready(atMost) + } + } + def intercept[T <: Throwable : ClassTag](body: => Any): T = { + try { + body + throw new Exception(s"Exception of type ${classTag[T]} was not thrown") + } catch { + case t: Throwable => + if (!classTag[T].runtimeClass.isAssignableFrom(t.getClass)) throw t + else t.asInstanceOf[T] + } + } + + implicit class objectops(obj: Any) { + def mustBe(other: Any): Unit = assert(obj == other, obj + " is not " + other) + + def mustEqual(other: Any): Unit = mustBe(other) + } +} diff --git a/src/test/scala/scala/async/TreeInterrogation.scala b/src/test/scala/scala/async/TreeInterrogation.scala deleted file mode 100644 index 2317d088..00000000 --- a/src/test/scala/scala/async/TreeInterrogation.scala +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async - -import org.junit.Test -import scala.async.internal.AsyncId -import AsyncId._ -import tools.reflect.ToolBox - -class TreeInterrogation { - @Test - def `a minimal set of vals are lifted to vars`(): Unit = { - val cm = reflect.runtime.currentMirror - val tb = mkToolbox(s"-cp $toolboxClasspath") - val tree = tb.parse( - """| import _root_.scala.async.internal.AsyncId._ - | async { - | val x = await(1) - | val y = x * 2 - | def foo(a: Int) = { def nested = 0; a } // don't lift `nested`. - | val z = await(x * 3) - | foo(z) - | z - | }""".stripMargin) - val tree1 = tb.typeCheck(tree) - - //println(cm.universe.show(tree1)) - - import tb.u._ - val functions = tree1.collect { - case f: Function => f - case t: Template => t - } - functions.size mustBe 1 - - val varDefs = tree1.collect { - case vd @ ValDef(mods, name, _, _) if mods.hasFlag(Flag.MUTABLE) && vd.symbol.owner.isClass => name - } - varDefs.map(_.decoded.trim).toSet.toList.sorted mustStartWith (List("await$async$", "await$async", "state$async")) - - val defDefs = tree1.collect { - case t: Template => - val stats: List[Tree] = t.body - stats.collect { - case dd : DefDef - if !dd.symbol.isImplementationArtifact - && !dd.symbol.asTerm.isAccessor && !dd.symbol.asTerm.isSetter => dd.name - } - }.flatten - defDefs.map(_.decoded.trim) mustStartWith List("foo$async$", "", "apply", "apply") - } -} - -object TreeInterrogationApp extends App { - def withDebug[T](t: => T): T = { - def set(level: String, value: Boolean) = System.setProperty(s"scala.async.$level", value.toString) - val levels = Seq("trace", "debug") - def setAll(value: Boolean) = levels.foreach(set(_, value)) - - setAll(value = true) - try t finally setAll(value = false) - } - - withDebug { - val cm = reflect.runtime.currentMirror - val tb = mkToolbox(s"-cp ${toolboxClasspath} -Xprint:typer") - import scala.async.internal.AsyncId._ - val tree = tb.parse( - """ - | import scala.async.internal.AsyncId._ - | trait QBound { type D; trait ResultType { case class Inner() }; def toResult: ResultType = ??? } - | trait QD[Q <: QBound] { - | val operation: Q - | type D = operation.D - | } - | - | async { - | if (!"".isEmpty) { - | val treeResult = null.asInstanceOf[QD[QBound]] - | await(0) - | val y = treeResult.operation - | type RD = treeResult.operation.D - | (null: Object) match { - | case (_, _: RD) => ??? - | case _ => val x = y.toResult; x.Inner() - | } - | await(1) - | (y, null.asInstanceOf[RD]) - | "" - | } - | - | } - | - | """.stripMargin) - println(tree) - val tree1 = tb.typeCheck(tree.duplicate) - println(cm.universe.show(tree1)) - - println(tb.eval(tree)) - } - -} diff --git a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala b/src/test/scala/scala/async/neg/LocalClasses0Spec.scala deleted file mode 100644 index bbf3c11e..00000000 --- a/src/test/scala/scala/async/neg/LocalClasses0Spec.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package neg - -import org.junit.Test -import scala.async.internal.AsyncId - -class LocalClasses0Spec { - @Test - def localClassCrashIssue16(): Unit = { - import AsyncId.{async, await} - async { - class B { def f = 1 } - await(new B()).f - } mustBe 1 - } - - @Test - def nestedCaseClassAndModuleAllowed(): Unit = { - import AsyncId.{await, async} - async { - trait Base { def base = 0} - await(0) - case class Person(name: String) extends Base - val fut = async { "bob" } - val x = Person(await(fut)) - x.base - x.name - } mustBe "bob" - } -} diff --git a/src/test/scala/scala/async/neg/NakedAwait.scala b/src/test/scala/scala/async/neg/NakedAwait.scala deleted file mode 100644 index 4dbd0fa1..00000000 --- a/src/test/scala/scala/async/neg/NakedAwait.scala +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package neg - -import org.junit.Test - -class NakedAwait { - @Test - def `await only allowed in async neg`(): Unit = { - expectError("`await` must be enclosed in an `async` block") { - """ - | import _root_.scala.async.Async._ - | await[Any](null) - """.stripMargin - } - } - - @Test - def `await not allowed in by-name argument`(): Unit = { - expectError("await must not be used under a by-name argument.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | def foo(a: Int)(b: => Int) = 0 - | async { foo(0)(await(0)) } - """.stripMargin - } - } - - @Test - def `await not allowed in boolean short circuit argument 1`(): Unit = { - expectError("await must not be used under a by-name argument.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { true && await(false) } - """.stripMargin - } - } - - @Test - def `await not allowed in boolean short circuit argument 2`(): Unit = { - expectError("await must not be used under a by-name argument.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { true || await(false) } - """.stripMargin - } - } - - @Test - def nestedObject(): Unit = { - expectError("await must not be used under a nested object.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { object Nested { await(false) } } - """.stripMargin - } - } - - @Test - def nestedTrait(): Unit = { - expectError("await must not be used under a nested trait.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { trait Nested { await(false) } } - """.stripMargin - } - } - - @Test - def nestedClass(): Unit = { - expectError("await must not be used under a nested class.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { class Nested { await(false) } } - """.stripMargin - } - } - - @Test - def nestedFunction(): Unit = { - expectError("await must not be used under a nested function.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { () => { await(false) } } - """.stripMargin - } - } - - @Test - def nestedPatMatFunction(): Unit = { - expectError("await must not be used under a nested class.") { // TODO more specific error message - """ - | import _root_.scala.async.internal.AsyncId._ - | async { { case x => { await(false) } } : PartialFunction[Any, Any] } - """.stripMargin - } - } - - @Test - def tryBody(): Unit = { - expectError("await must not be used under a try/catch.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { try { await(false) } catch { case _ => } } - """.stripMargin - } - } - - @Test - def catchBody(): Unit = { - expectError("await must not be used under a try/catch.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { try { () } catch { case _ => await(false) } } - """.stripMargin - } - } - - @Test - def finallyBody(): Unit = { - expectError("await must not be used under a try/catch.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { try { () } finally { await(false) } } - """.stripMargin - } - } - - @Test - def guard(): Unit = { - expectError("await must not be used under a pattern guard.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { 1 match { case _ if await(true) => } } - """.stripMargin - } - } - - @Test - def nestedMethod(): Unit = { - expectError("await must not be used under a nested method.") { - """ - | import _root_.scala.async.internal.AsyncId._ - | async { def foo = await(false) } - """.stripMargin - } - } - - @Test - def returnIllegal(): Unit = { - expectError("return is illegal") { - """ - | import _root_.scala.async.internal.AsyncId._ - | def foo(): Any = async { return false } - | () - | - |""".stripMargin - } - } - - @Test - def lazyValIllegal(): Unit = { - expectError("await must not be used under a lazy val initializer") { - """ - | import _root_.scala.async.internal.AsyncId._ - | def foo(): Any = async { val x = { lazy val y = await(0); y } } - | () - | - |""".stripMargin - } - } -} diff --git a/src/test/scala/scala/async/neg/SampleNegSpec.scala b/src/test/scala/scala/async/neg/SampleNegSpec.scala deleted file mode 100644 index cf2c8394..00000000 --- a/src/test/scala/scala/async/neg/SampleNegSpec.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package neg - -import org.junit.Test - -class SampleNegSpec { - @Test - def `missing symbol`(): Unit = { - expectError("not found: value kaboom") { - """ - | kaboom - """.stripMargin - } - } -} diff --git a/src/test/scala/scala/async/package.scala b/src/test/scala/scala/async/package.scala deleted file mode 100644 index e27a3cf5..00000000 --- a/src/test/scala/scala/async/package.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala - -import reflect._ -import tools.reflect.{ToolBox, ToolBoxError} - -package object async { - - - implicit class objectops(obj: Any) { - def mustBe(other: Any) = assert(obj == other, obj + " is not " + other) - - def mustEqual(other: Any) = mustBe(other) - } - - implicit class stringops(text: String) { - def mustContain(substring: String) = assert(text contains substring, text) - - def mustStartWith(prefix: String) = assert(text startsWith prefix, text) - } - - implicit class listops(list: List[String]) { - def mustStartWith(prefixes: List[String]) = { - assert(list.length == prefixes.size, ("expected = " + prefixes.length + ", actual = " + list.length, list)) - list.zip(prefixes).foreach{ case (el, prefix) => el mustStartWith prefix } - } - } - - def intercept[T <: Throwable : ClassTag](body: => Any): T = { - try { - body - throw new Exception(s"Exception of type ${classTag[T]} was not thrown") - } catch { - case t: Throwable => - if (!classTag[T].runtimeClass.isAssignableFrom(t.getClass)) throw t - else t.asInstanceOf[T] - } - } - - def eval(code: String, compileOptions: String = ""): Any = { - val tb = mkToolbox(compileOptions) - tb.eval(tb.parse(code)) - } - - def mkToolbox(compileOptions: String = ""): ToolBox[_ <: scala.reflect.api.Universe] = { - val m = scala.reflect.runtime.currentMirror - import scala.tools.reflect.ToolBox - m.mkToolBox(options = compileOptions) - } - - import scala.tools.nsc._, reporters._ - def mkGlobal(compileOptions: String = ""): Global = { - val settings = new Settings() - settings.processArgumentString(compileOptions) - val initClassPath = settings.classpath.value - settings.embeddedDefaults(getClass.getClassLoader) - if (initClassPath == settings.classpath.value) - settings.usejavacp.value = true // not running under SBT, try to use the Java claspath instead - val reporter = new StoreReporter - new Global(settings, reporter) - } - - // returns e.g. target/scala-2.12/classes - // implementation is kludgy, but it's just test code. Scala version number formats and their - // relation to Scala binary versions are too diverse to attempt to do that mapping ourselves here, - // as we learned from experience. and we could use sbt-buildinfo to have sbt tell us, but that - // complicates the build since it does source generation (which may e.g. confuse IntelliJ). - // so this is, uh, fine? (crosses fingers) - def toolboxClasspath = - new java.io.File(this.getClass.getProtectionDomain.getCodeSource.getLocation.toURI) - .getParentFile.getParentFile - - def expectError(errorSnippet: String, compileOptions: String = "", - baseCompileOptions: String = s"-cp ${toolboxClasspath}")(code: String): Unit = { - intercept[ToolBoxError] { - eval(code, compileOptions + " " + baseCompileOptions) - }.getMessage mustContain errorSnippet - } -} diff --git a/src/test/scala/scala/async/run/SyncOptimizationSpec.scala b/src/test/scala/scala/async/run/SyncOptimizationSpec.scala deleted file mode 100644 index b5cd6539..00000000 --- a/src/test/scala/scala/async/run/SyncOptimizationSpec.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.run - -import org.junit.Test -import scala.async.Async._ -import scala.concurrent._ -import scala.concurrent.duration._ -import ExecutionContext.Implicits._ - -class SyncOptimizationSpec { - @Test - def awaitOnCompletedFutureRunsOnSameThread: Unit = { - - def stackDepth = Thread.currentThread().getStackTrace.length - - val future = async { - val thread1 = Thread.currentThread - val stackDepth1 = stackDepth - - val f = await(Future.successful(1)) - val thread2 = Thread.currentThread - val stackDepth2 = stackDepth - assert(thread1 == thread2) - assert(stackDepth1 == stackDepth2) - } - Await.result(future, 10.seconds) - } - -} diff --git a/src/test/scala/scala/async/run/WarningsSpec.scala b/src/test/scala/scala/async/run/WarningsSpec.scala deleted file mode 100644 index 155794d3..00000000 --- a/src/test/scala/scala/async/run/WarningsSpec.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run - -import org.junit.Test - -import scala.language.{postfixOps, reflectiveCalls} -import scala.tools.nsc.reporters.StoreReporter - - -class WarningsSpec { - - @Test - // https://github.com/scala/async/issues/74 - def noPureExpressionInStatementPositionWarning_t74(): Unit = { - val tb = mkToolbox(s"-cp ${toolboxClasspath} -Xfatal-warnings") - // was: "a pure expression does nothing in statement position; you may be omitting necessary parentheses" - tb.eval(tb.parse { - """ - | import scala.async.internal.AsyncId._ - | async { - | if ("".isEmpty) { - | await(println("hello")) - | () - | } else 42 - | } - """.stripMargin - }) - } - - @Test - // https://github.com/scala/async/issues/74 - def noDeadCodeWarningForAsyncThrow(): Unit = { - val global = mkGlobal("-cp ${toolboxClasspath} -Yrangepos -Ywarn-dead-code -Xfatal-warnings -Ystop-after:refchecks") - // was: "a pure expression does nothing in statement position; you may be omitting necessary parentheses" - val source = - """ - | class Test { - | import scala.async.Async._ - | import scala.concurrent.ExecutionContext.Implicits.global - | async { throw new Error() } - | } - """.stripMargin - val run = new global.Run - val sourceFile = global.newSourceFile(source) - run.compileSources(sourceFile :: Nil) - assert(!global.reporter.hasErrors, global.reporter.asInstanceOf[StoreReporter].infos) - } - - @Test - def noDeadCodeWarningInMacroExpansion(): Unit = { - val global = mkGlobal("-cp ${toolboxClasspath} -Yrangepos -Ywarn-dead-code -Xfatal-warnings -Ystop-after:refchecks") - val source = """ - | class Test { - | def test = { - | import scala.async.Async._, scala.concurrent._, ExecutionContext.Implicits.global - | async { - | val opt = await(async(Option.empty[String => Future[Unit]])) - | opt match { - | case None => - | throw new RuntimeException("case a") - | case Some(f) => - | await(f("case b")) - | } - | } - | } - |} - """.stripMargin - val run = new global.Run - val sourceFile = global.newSourceFile(source) - run.compileSources(sourceFile :: Nil) - assert(!global.reporter.hasErrors, global.reporter.asInstanceOf[StoreReporter].infos) - } - - @Test - def ignoreNestedAwaitsInIDE_t1002561(): Unit = { - // https://www.assembla.com/spaces/scala-ide/tickets/1002561 - val global = mkGlobal("-cp ${toolboxClasspath} -Yrangepos -Ystop-after:typer ") - val source = """ - | class Test { - | def test = { - | import scala.async.Async._, scala.concurrent._, ExecutionContext.Implicits.global - | async { - | 1 + await({def foo = (async(await(async(2)))); foo}) - | } - | } - |} - """.stripMargin - val run = new global.Run - val sourceFile = global.newSourceFile(source) - run.compileSources(sourceFile :: Nil) - assert(!global.reporter.hasErrors, global.reporter.asInstanceOf[StoreReporter].infos) - } -} diff --git a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala b/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala deleted file mode 100644 index 2d133b02..00000000 --- a/src/test/scala/scala/async/run/anf/AnfTransformSpec.scala +++ /dev/null @@ -1,459 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package anf - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test -import scala.async.internal.AsyncId - - -class AnfTestClass { - - import ExecutionContext.Implicits.global - - def base(x: Int): Future[Int] = Future { - x + 2 - } - - def m(y: Int): Future[Int] = async { - val blerg = base(y) - await(blerg) - } - - def m2(y: Int): Future[Int] = async { - val f = base(y) - val f2 = base(y + 1) - await(f) + await(f2) - } - - def m3(y: Int): Future[Int] = async { - val f = base(y) - var z = 0 - if (y > 0) { - z = await(f) + 2 - } else { - z = await(f) - 2 - } - z - } - - def m4(y: Int): Future[Int] = async { - val f = base(y) - val z = if (y > 0) { - await(f) + 2 - } else { - await(f) - 2 - } - z + 1 - } - - def futureUnitIfElse(y: Int): Future[Unit] = async { - val f = base(y) - if (y > 0) { - State.result = await(f) + 2 - } else { - State.result = await(f) - 2 - } - } -} - -object State { - @volatile var result: Int = 0 -} - -class AnfTransformSpec { - - @Test - def `simple ANF transform`(): Unit = { - val o = new AnfTestClass - val fut = o.m(10) - val res = Await.result(fut, 2 seconds) - res mustBe (12) - } - - @Test - def `simple ANF transform 2`(): Unit = { - val o = new AnfTestClass - val fut = o.m2(10) - val res = Await.result(fut, 2 seconds) - res mustBe (25) - } - - @Test - def `simple ANF transform 3`(): Unit = { - val o = new AnfTestClass - val fut = o.m3(10) - val res = Await.result(fut, 2 seconds) - res mustBe (14) - } - - @Test - def `ANF transform of assigning the result of an if-else`(): Unit = { - val o = new AnfTestClass - val fut = o.m4(10) - val res = Await.result(fut, 2 seconds) - res mustBe (15) - } - - @Test - def `Unit-typed if-else in tail position`(): Unit = { - val o = new AnfTestClass - val fut = o.futureUnitIfElse(10) - Await.result(fut, 2 seconds) - State.result mustBe (14) - } - - @Test - def `inlining block does not produce duplicate definition`(): Unit = { - AsyncId.async { - val f = 12 - val x = AsyncId.await(f) - - { - type X = Int - val x: X = 42 - println(x) - } - type X = Int - x: X - } - } - - @Test - def `inlining block in tail position does not produce duplicate definition`(): Unit = { - AsyncId.async { - val f = 12 - val x = AsyncId.await(f) - - { - val x = 42 - x - } - } mustBe (42) - } - - @Test - def `match as expression 1`(): Unit = { - import ExecutionContext.Implicits.global - val result = AsyncId.async { - val x = "" match { - case _ => AsyncId.await(1) + 1 - } - x - } - result mustBe (2) - } - - @Test - def `match as expression 2`(): Unit = { - import ExecutionContext.Implicits.global - val result = AsyncId.async { - val x = "" match { - case "" if false => AsyncId.await(1) + 1 - case _ => 2 + AsyncId.await(1) - } - val y = x - "" match { - case _ => AsyncId.await(y) + 100 - } - } - result mustBe (103) - } - - @Test - def nestedAwaitAsBareExpression(): Unit = { - import ExecutionContext.Implicits.global - import AsyncId.{async, await} - val result = async { - await(await("").isEmpty) - } - result mustBe (true) - } - - @Test - def nestedAwaitInBlock(): Unit = { - import ExecutionContext.Implicits.global - import AsyncId.{async, await} - val result = async { - () - await(await("").isEmpty) - } - result mustBe (true) - } - - @Test - def nestedAwaitInIf(): Unit = { - import ExecutionContext.Implicits.global - import AsyncId.{async, await} - val result = async { - if ("".isEmpty) - await(await("").isEmpty) - else 0 - } - result mustBe (true) - } - - @Test - def byNameExpressionsArentLifted(): Unit = { - import AsyncId.{async, await} - def foo(ignored: => Any, b: Int) = b - val result = async { - foo(???, await(1)) - } - result mustBe (1) - } - - @Test - def evaluationOrderRespected(): Unit = { - import AsyncId.{async, await} - def foo(a: Int, b: Int) = (a, b) - val result = async { - var i = 0 - def next() = { - i += 1 - i - } - foo(next(), await(next())) - } - result mustBe ((1, 2)) - } - - @Test - def awaitInNonPrimaryParamSection1(): Unit = { - import AsyncId.{async, await} - def foo(a0: Int)(b0: Int) = s"a0 = $a0, b0 = $b0" - val res = async { - var i = 0 - def get = {i += 1; i} - foo(get)(await(get)) - } - res mustBe "a0 = 1, b0 = 2" - } - - @Test - def awaitInNonPrimaryParamSection2(): Unit = { - import AsyncId.{async, await} - def foo[T](a0: Int)(b0: Int*) = s"a0 = $a0, b0 = ${b0.head}" - val res = async { - var i = 0 - def get = async {i += 1; i} - foo[Int](await(get))(await(get) :: await(async(Nil)) : _*) - } - res mustBe "a0 = 1, b0 = 2" - } - - @Test - def awaitInNonPrimaryParamSectionWithLazy1(): Unit = { - import AsyncId.{async, await} - def foo[T](a: => Int)(b: Int) = b - val res = async { - def get = async {0} - foo[Int](???)(await(get)) - } - res mustBe 0 - } - - @Test - def awaitInNonPrimaryParamSectionWithLazy2(): Unit = { - import AsyncId.{async, await} - def foo[T](a: Int)(b: => Int) = a - val res = async { - def get = async {0} - foo[Int](await(get))(???) - } - res mustBe 0 - } - - @Test - def awaitWithLazy(): Unit = { - import AsyncId.{async, await} - def foo[T](a: Int, b: => Int) = a - val res = async { - def get = async {0} - foo[Int](await(get), ???) - } - res mustBe 0 - } - - @Test - def awaitOkInReciever(): Unit = { - import AsyncId.{async, await} - class Foo { def bar(a: Int)(b: Int) = a + b } - async { - await(async(new Foo)).bar(1)(2) - } - } - - @Test - def namedArgumentsRespectEvaluationOrder(): Unit = { - import AsyncId.{async, await} - def foo(a: Int, b: Int) = (a, b) - val result = async { - var i = 0 - def next() = { - i += 1 - i - } - foo(b = next(), a = await(next())) - } - result mustBe ((2, 1)) - } - - @Test - def namedAndDefaultArgumentsRespectEvaluationOrder(): Unit = { - import AsyncId.{async, await} - var i = 0 - def next() = { - i += 1 - i - } - def foo(a: Int = next(), b: Int = next()) = (a, b) - async { - foo(b = await(next())) - } mustBe ((2, 1)) - i = 0 - async { - foo(a = await(next())) - } mustBe ((1, 2)) - } - - @Test - def repeatedParams1(): Unit = { - import AsyncId.{async, await} - var i = 0 - def foo(a: Int, b: Int*) = b.toList - def id(i: Int) = i - async { - foo(await(0), id(1), id(2), id(3), await(4)) - } mustBe (List(1, 2, 3, 4)) - } - - @Test - def repeatedParams2(): Unit = { - import AsyncId.{async, await} - var i = 0 - def foo(a: Int, b: Int*) = b.toList - def id(i: Int) = i - async { - foo(await(0), List(id(1), id(2), id(3)): _*) - } mustBe (List(1, 2, 3)) - } - - @Test - def awaitInThrow(): Unit = { - import _root_.scala.async.internal.AsyncId.{async, await} - intercept[Exception]( - async { - throw new Exception("msg: " + await(0)) - } - ).getMessage mustBe "msg: 0" - } - - @Test - def awaitInTyped(): Unit = { - import _root_.scala.async.internal.AsyncId.{async, await} - async { - (("msg: " + await(0)): String).toString - } mustBe "msg: 0" - } - - - @Test - def awaitInAssign(): Unit = { - import _root_.scala.async.internal.AsyncId.{async, await} - async { - var x = 0 - x = await(1) - x - } mustBe 1 - } - - @Test - def caseBodyMustBeTypedAsUnit(): Unit = { - import _root_.scala.async.internal.AsyncId.{async, await} - val Up = 1 - val Down = 2 - val sign = async { - await(1) match { - case Up => 1.0 - case Down => -1.0 - } - } - sign mustBe 1.0 - } - - @Test - def awaitInImplicitApply(): Unit = { - val tb = mkToolbox(s"-cp ${toolboxClasspath}") - val tree = tb.typeCheck(tb.parse { - """ - | import language.implicitConversions - | import _root_.scala.async.internal.AsyncId.{async, await} - | implicit def view(a: Int): String = "" - | async { - | await(0).length - | } - """.stripMargin - }) - val applyImplicitView = tree.collect { case x if x.getClass.getName.endsWith("ApplyImplicitView") => x } - println(applyImplicitView) - applyImplicitView.map(_.toString) mustStartWith List("view(") - } - - @Test - def nothingTypedIf(): Unit = { - import scala.async.internal.AsyncId.{async, await} - val result = util.Try(async { - if (true) { - val n = await(1) - if (n < 2) { - throw new RuntimeException("case a") - } - else { - throw new RuntimeException("case b") - } - } - else { - "case c" - } - }) - - assert(result.asInstanceOf[util.Failure[_]].exception.getMessage == "case a") - } - - @Test - def nothingTypedMatch(): Unit = { - import scala.async.internal.AsyncId.{async, await} - val result = util.Try(async { - 0 match { - case _ if "".isEmpty => - val n = await(1) - n match { - case _ if n < 2 => - throw new RuntimeException("case a") - case _ => - throw new RuntimeException("case b") - } - case _ => - "case c" - } - }) - - assert(result.asInstanceOf[util.Failure[_]].exception.getMessage == "case a") - } -} diff --git a/src/test/scala/scala/async/run/await0/Await0Spec.scala b/src/test/scala/scala/async/run/await0/Await0Spec.scala deleted file mode 100644 index e70a811e..00000000 --- a/src/test/scala/scala/async/run/await0/Await0Spec.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package await0 - -/** - * Copyright (C) 2012-2014 Lightbend Inc. - */ - -import language.{reflectiveCalls, postfixOps} - -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test - -class Await0Class { - - import ExecutionContext.Implicits.global - - def m1(x: Double): Future[Double] = Future { - x + 2.0 - } - - def m2(x: Float): Future[Float] = Future { - x + 2.0f - } - - def m3(x: Char): Future[Char] = Future { - (x.toInt + 2).toChar - } - - def m4(x: Short): Future[Short] = Future { - (x + 2).toShort - } - - def m5(x: Byte): Future[Byte] = Future { - (x + 2).toByte - } - - def m0(y: Int): Future[Double] = async { - val f1 = m1(y.toDouble) - val x1: Double = await(f1) - - val f2 = m2(y.toFloat) - val x2: Float = await(f2) - - val f3 = m3(y.toChar) - val x3: Char = await(f3) - - val f4 = m4(y.toShort) - val x4: Short = await(f4) - - val f5 = m5(y.toByte) - val x5: Byte = await(f5) - - x1 + x2 + 2.0 - } -} - -class Await0Spec { - - @Test - def `An async method support a simple await`(): Unit = { - val o = new Await0Class - val fut = o.m0(10) - val res = Await.result(fut, 10 seconds) - res mustBe (26.0) - } -} diff --git a/src/test/scala/scala/async/run/block0/AsyncSpec.scala b/src/test/scala/scala/async/run/block0/AsyncSpec.scala deleted file mode 100644 index 6284dbdb..00000000 --- a/src/test/scala/scala/async/run/block0/AsyncSpec.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package block0 - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test - - -class Test1Class { - - import ExecutionContext.Implicits.global - - def m1(x: Int): Future[Int] = Future { - x + 2 - } - - def m2(y: Int): Future[Int] = async { - val f = m1(y) - val x = await(f) - x + 2 - } - - def m3(y: Int): Future[Int] = async { - val f1 = m1(y) - val x1 = await(f1) - val f2 = m1(y + 2) - val x2 = await(f2) - x1 + x2 - } -} - - -class AsyncSpec { - - @Test - def `simple await`(): Unit = { - val o = new Test1Class - val fut = o.m2(10) - val res = Await.result(fut, 2 seconds) - res mustBe (14) - } - - @Test - def `several awaits in sequence`(): Unit = { - val o = new Test1Class - val fut = o.m3(10) - val res = Await.result(fut, 4 seconds) - res mustBe (26) - } -} diff --git a/src/test/scala/scala/async/run/block1/block1.scala b/src/test/scala/scala/async/run/block1/block1.scala deleted file mode 100644 index 7247c244..00000000 --- a/src/test/scala/scala/async/run/block1/block1.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package block1 - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test - - -class Test1Class { - - import ExecutionContext.Implicits.global - - def m1(x: Int): Future[Int] = Future { - x + 2 - } - - def m4(y: Int): Future[Int] = async { - val f1 = m1(y) - val f2 = m1(y + 2) - val x1 = await(f1) - val x2 = await(f2) - x1 + x2 - } -} - -class Block1Spec { - - @Test def `support a simple await`(): Unit = { - val o = new Test1Class - val fut = o.m4(10) - val res = Await.result(fut, 2 seconds) - res mustBe (26) - } -} diff --git a/src/test/scala/scala/async/run/exceptions/ExceptionsSpec.scala b/src/test/scala/scala/async/run/exceptions/ExceptionsSpec.scala deleted file mode 100644 index e75594ab..00000000 --- a/src/test/scala/scala/async/run/exceptions/ExceptionsSpec.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package exceptions - -import scala.async.Async.{async, await} - -import scala.concurrent.{Future, ExecutionContext, Await} -import ExecutionContext.Implicits._ -import scala.concurrent.duration._ -import scala.reflect.ClassTag - -import org.junit.Test - -class ExceptionsSpec { - - @Test - def `uncaught exception within async`(): Unit = { - val fut = async { throw new Exception("problem") } - intercept[Exception] { Await.result(fut, 2.seconds) } - } - - @Test - def `uncaught exception within async after await`(): Unit = { - val base = Future { "five!".length } - val fut = async { - val len = await(base) - throw new Exception(s"illegal length: $len") - } - intercept[Exception] { Await.result(fut, 2.seconds) } - } - - @Test - def `await failing future within async`(): Unit = { - val base = Future[Int] { throw new Exception("problem") } - val fut = async { - val x = await(base) - x * 2 - } - intercept[Exception] { Await.result(fut, 2.seconds) } - } - - @Test - def `await failing future within async after await`(): Unit = { - val base = Future[Any] { "five!".length } - val fut = async { - val a = await(base.mapTo[Int]) // result: 5 - val b = await((Future { (a * 2).toString }).mapTo[Int]) // result: ClassCastException - val c = await(Future { (7 * 2).toString }) // result: "14" - b + "-" + c - } - intercept[ClassCastException] { Await.result(fut, 2.seconds) } - } - -} diff --git a/src/test/scala/scala/async/run/futures/FutureSpec.scala b/src/test/scala/scala/async/run/futures/FutureSpec.scala deleted file mode 100644 index 52566894..00000000 --- a/src/test/scala/scala/async/run/futures/FutureSpec.scala +++ /dev/null @@ -1,560 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package futures - -import java.util.concurrent.ConcurrentHashMap - -import scala.language.postfixOps -import scala.concurrent._ -import scala.concurrent.duration._ -import scala.concurrent.duration.Duration.Inf -import scala.collection._ -import scala.runtime.NonLocalReturnControl -import scala.util.{Try,Success,Failure} - -import scala.async.Async.{async, await} - -import org.junit.Test - -class FutureSpec { - - /* some utils */ - - def testAsync(s: String)(implicit ec: ExecutionContext): Future[String] = s match { - case "Hello" => Future { "World" } - case "Failure" => Future.failed(new RuntimeException("Expected exception; to test fault-tolerance")) - case "NoReply" => Promise[String]().future - } - - val defaultTimeout = 5 seconds - - /* future specification */ - - @Test def `A future with custom ExecutionContext should handle Throwables`(): Unit = { - val ms = new ConcurrentHashMap[Throwable, Unit] - implicit val ec = scala.concurrent.ExecutionContext.fromExecutor(new java.util.concurrent.ForkJoinPool(), { - t => - ms.put(t, ()) - }) - - class ThrowableTest(m: String) extends Throwable(m) - - val f1 = Future[Any] { - throw new ThrowableTest("test") - } - - intercept[ThrowableTest] { - Await.result(f1, defaultTimeout) - } - - val latch = new TestLatch - val f2 = Future { - Await.ready(latch, 5 seconds) - "success" - } - val f3 = async { - val s = await(f2) - s.toUpperCase - } - - f2 foreach { _ => throw new ThrowableTest("dispatcher foreach") } - f2 onComplete { case Success(_) => throw new ThrowableTest("dispatcher receive") } - - latch.open() - - Await.result(f2, defaultTimeout) mustBe ("success") - - f2 foreach { _ => throw new ThrowableTest("current thread foreach") } - f2 onComplete { case Success(_) => throw new ThrowableTest("current thread receive") } - - Await.result(f3, defaultTimeout) mustBe ("SUCCESS") - - val waiting = Future { - Thread.sleep(1000) - } - Await.ready(waiting, 2000 millis) - - ms.size mustBe (4) - } - - import ExecutionContext.Implicits._ - - @Test def `A future with global ExecutionContext should compose with for-comprehensions`(): Unit = { - import scala.reflect.ClassTag - - def asyncInt(x: Int) = Future { (x * 2).toString } - val future0 = Future[Any] { - "five!".length - } - - val future1 = async { - val a = await(future0.mapTo[Int]) // returns 5 - val b = await(asyncInt(a)) // returns "10" - val c = await(asyncInt(7)) // returns "14" - b + "-" + c - } - - val future2 = async { - val a = await(future0.mapTo[Int]) - val b = await((Future { (a * 2).toString }).mapTo[Int]) - val c = await(Future { (7 * 2).toString }) - b + "-" + c - } - - Await.result(future1, defaultTimeout) mustBe ("10-14") - //assert(checkType(future1, manifest[String])) - intercept[ClassCastException] { Await.result(future2, defaultTimeout) } - } - - //TODO this is not yet supported by Async - @Test def `support pattern matching within a for-comprehension`(): Unit = { - case class Req[T](req: T) - case class Res[T](res: T) - def asyncReq[T](req: Req[T]) = req match { - case Req(s: String) => Future { Res(s.length) } - case Req(i: Int) => Future { Res((i * 2).toString) } - } - - val future1 = for { - Res(a: Int) <- asyncReq(Req("Hello")) - Res(b: String) <- asyncReq(Req(a)) - Res(c: String) <- asyncReq(Req(7)) - } yield b + "-" + c - - val future2 = for { - Res(a: Int) <- asyncReq(Req("Hello")) - Res(b: Int) <- asyncReq(Req(a)) - Res(c: Int) <- asyncReq(Req(7)) - } yield b + "-" + c - - Await.result(future1, defaultTimeout) mustBe ("10-14") - intercept[NoSuchElementException] { Await.result(future2, defaultTimeout) } - } - - @Test def mini(): Unit = { - val future4 = async { - await(Future.successful(0)).toString - } - Await.result(future4, defaultTimeout) - } - - @Test def `recover from exceptions`(): Unit = { - val future1 = Future(5) - val future2 = async { await(future1) / 0 } - val future3 = async { await(future2).toString } - - val future1Recovered = future1 recover { - case e: ArithmeticException => 0 - } - val future4 = async { await(future1Recovered).toString } - - val future2Recovered = future2 recover { - case e: ArithmeticException => 0 - } - val future5 = async { await(future2Recovered).toString } - - val future2Recovered2 = future2 recover { - case e: MatchError => 0 - } - val future6 = async { await(future2Recovered2).toString } - - val future7 = future3 recover { - case e: ArithmeticException => "You got ERROR" - } - - val future8 = testAsync("Failure") - val future9 = testAsync("Failure") recover { - case e: RuntimeException => "FAIL!" - } - val future10 = testAsync("Hello") recover { - case e: RuntimeException => "FAIL!" - } - val future11 = testAsync("Failure") recover { - case _ => "Oops!" - } - - Await.result(future1, defaultTimeout) mustBe (5) - intercept[ArithmeticException] { Await.result(future2, defaultTimeout) } - intercept[ArithmeticException] { Await.result(future3, defaultTimeout) } - Await.result(future4, defaultTimeout) mustBe ("5") - Await.result(future5, defaultTimeout) mustBe ("0") - intercept[ArithmeticException] { Await.result(future6, defaultTimeout) } - Await.result(future7, defaultTimeout) mustBe ("You got ERROR") - intercept[RuntimeException] { Await.result(future8, defaultTimeout) } - Await.result(future9, defaultTimeout) mustBe ("FAIL!") - Await.result(future10, defaultTimeout) mustBe ("World") - Await.result(future11, defaultTimeout) mustBe ("Oops!") - } - - @Test def `recoverWith from exceptions`(): Unit = { - val o = new IllegalStateException("original") - val r = new IllegalStateException("recovered") - - intercept[IllegalStateException] { - val failed = Future.failed[String](o) recoverWith { - case _ if false == true => Future.successful("yay!") - } - Await.result(failed, defaultTimeout) - } mustBe (o) - - val recovered = Future.failed[String](o) recoverWith { - case _ => Future.successful("yay!") - } - Await.result(recovered, defaultTimeout) mustBe ("yay!") - - intercept[IllegalStateException] { - val refailed = Future.failed[String](o) recoverWith { - case _ => Future.failed[String](r) - } - Await.result(refailed, defaultTimeout) - } mustBe (r) - } - - @Test def `andThen like a boss`(): Unit = { - val q = new java.util.concurrent.LinkedBlockingQueue[Int] - for (i <- 1 to 1000) { - val chained = Future { - q.add(1); 3 - } andThen { - case _ => q.add(2) - } andThen { - case Success(0) => q.add(Int.MaxValue) - } andThen { - case _ => q.add(3); - } - Await.result(chained, defaultTimeout) mustBe (3) - q.poll() mustBe (1) - q.poll() mustBe (2) - q.poll() mustBe (3) - q.clear() - } - } - - @Test def `firstCompletedOf`(): Unit = { - def futures = Vector.fill[Future[Int]](10) { - Promise[Int]().future - } :+ Future.successful[Int](5) - - Await.result(Future.firstCompletedOf(futures), defaultTimeout) mustBe (5) - Await.result(Future.firstCompletedOf(futures.iterator), defaultTimeout) mustBe (5) - } - - @Test def `find`(): Unit = { - val futures = for (i <- 1 to 10) yield Future { - i - } - - val result = Future.find[Int](futures)(_ == 3) - Await.result(result, defaultTimeout) mustBe (Some(3)) - - val notFound = Future.find[Int](futures)(_ == 11) - Await.result(notFound, defaultTimeout) mustBe (None) - } - - @Test def `zip`(): Unit = { - val timeout = 10000 millis - val f = new IllegalStateException("test") - intercept[IllegalStateException] { - val failed = Future.failed[String](f) zip Future.successful("foo") - Await.result(failed, timeout) - } mustBe (f) - - intercept[IllegalStateException] { - val failed = Future.successful("foo") zip Future.failed[String](f) - Await.result(failed, timeout) - } mustBe (f) - - intercept[IllegalStateException] { - val failed = Future.failed[String](f) zip Future.failed[String](f) - Await.result(failed, timeout) - } mustBe (f) - - val successful = Future.successful("foo") zip Future.successful("foo") - Await.result(successful, timeout) mustBe (("foo", "foo")) - } - - @Test def `fold`(): Unit = { - val timeout = 10000 millis - def async(add: Int, wait: Int) = Future { - Thread.sleep(wait) - add - } - - val futures = (0 to 9) map { - idx => async(idx, idx * 20) - } - // TODO: change to `foldLeft` after support for 2.11 is dropped - val folded = Future.fold(futures)(0)(_ + _) - Await.result(folded, timeout) mustBe (45) - - val futuresit = (0 to 9) map { - idx => async(idx, idx * 20) - } - // TODO: change to `foldLeft` after support for 2.11 is dropped - val foldedit = Future.fold(futures)(0)(_ + _) - Await.result(foldedit, timeout) mustBe (45) - } - - @Test def `fold by composing`(): Unit = { - val timeout = 10000 millis - def async(add: Int, wait: Int) = Future { - Thread.sleep(wait) - add - } - def futures = (0 to 9) map { - idx => async(idx, idx * 20) - } - val folded = futures.foldLeft(Future(0)) { - case (fr, fa) => for (r <- fr; a <- fa) yield (r + a) - } - Await.result(folded, timeout) mustBe (45) - } - - @Test def `fold with an exception`(): Unit = { - val timeout = 10000 millis - def async(add: Int, wait: Int) = Future { - Thread.sleep(wait) - if (add == 6) throw new IllegalArgumentException("shouldFoldResultsWithException: expected") - add - } - def futures = (0 to 9) map { - idx => async(idx, idx * 10) - } - // TODO: change to `foldLeft` after support for 2.11 is dropped - val folded = Future.fold(futures)(0)(_ + _) - intercept[IllegalArgumentException] { - Await.result(folded, timeout) - }.getMessage mustBe ("shouldFoldResultsWithException: expected") - } - - @Test def `fold mutable zeroes safely`(): Unit = { - import scala.collection.mutable.ArrayBuffer - def test(testNumber: Int): Unit = { - val fs = (0 to 1000) map (i => Future(i)) - // TODO: change to `foldLeft` after support for 2.11 is dropped - val f = Future.fold(fs)(ArrayBuffer.empty[AnyRef]) { - case (l, i) if i % 2 == 0 => l += i.asInstanceOf[AnyRef] - case (l, _) => l - } - val result = Await.result(f.mapTo[ArrayBuffer[Int]], 10000 millis).sum - - assert(result == 250500) - } - - (1 to 100) foreach test //Make sure it tries to provoke the problem - } - - @Test def `return zero value if folding empty list`(): Unit = { - // TODO: change to `foldLeft` after support for 2.11 is dropped - val zero = Future.fold(List[Future[Int]]())(0)(_ + _) - Await.result(zero, defaultTimeout) mustBe (0) - } - - @Test def `shouldReduceResults`(): Unit = { - def async(idx: Int) = Future { - Thread.sleep(idx * 20) - idx - } - val timeout = 10000 millis - - val futures = (0 to 9) map { async } - // TODO: change to `reduceLeft` after support for 2.11 is dropped - val reduced = Future.reduce(futures)(_ + _) - Await.result(reduced, timeout) mustBe (45) - - val futuresit = (0 to 9) map { async } - // TODO: change to `reduceLeft` after support for 2.11 is dropped - val reducedit = Future.reduce(futuresit)(_ + _) - Await.result(reducedit, timeout) mustBe (45) - } - - @Test def `shouldReduceResultsWithException`(): Unit = { - def async(add: Int, wait: Int) = Future { - Thread.sleep(wait) - if (add == 6) throw new IllegalArgumentException("shouldFoldResultsWithException: expected") - else add - } - val timeout = 10000 millis - def futures = (1 to 10) map { - idx => async(idx, idx * 10) - } - // TODO: change to `reduceLeft` after support for 2.11 is dropped - val failed = Future.reduce(futures)(_ + _) - intercept[IllegalArgumentException] { - Await.result(failed, timeout) - }.getMessage mustBe ("shouldFoldResultsWithException: expected") - } - - @Test def `shouldReduceThrowNSEEOnEmptyInput`(): Unit = { - intercept[java.util.NoSuchElementException] { - // TODO: change to `reduceLeft` after support for 2.11 is dropped - val emptyreduced = Future.reduce(List[Future[Int]]())(_ + _) - Await.result(emptyreduced, defaultTimeout) - } - } - - @Test def `shouldTraverseFutures`(): Unit = { - object counter { - var count = -1 - def incAndGet() = counter.synchronized { - count += 2 - count - } - } - - val oddFutures = List.fill(100)(Future { counter.incAndGet() }).iterator - val traversed = Future.sequence(oddFutures) - Await.result(traversed, defaultTimeout).sum mustBe (10000) - - val list = (1 to 100).toList - val traversedList = Future.traverse(list)(x => Future(x * 2 - 1)) - Await.result(traversedList, defaultTimeout).sum mustBe (10000) - - val iterator = (1 to 100).toList.iterator - val traversedIterator = Future.traverse(iterator)(x => Future(x * 2 - 1)) - Await.result(traversedIterator, defaultTimeout).sum mustBe (10000) - } - - @Test def `shouldBlockUntilResult`(): Unit = { - val latch = new TestLatch - - val f = Future { - Await.ready(latch, 5 seconds) - 5 - } - val f2 = Future { - val res = Await.result(f, Inf) - res + 9 - } - - intercept[TimeoutException] { - Await.ready(f2, 100 millis) - } - - latch.open() - - Await.result(f2, defaultTimeout) mustBe (14) - - val f3 = Future { - Thread.sleep(100) - 5 - } - - intercept[TimeoutException] { - Await.ready(f3, 0 millis) - } - } - - @Test def `run callbacks async`(): Unit = { - val latch = Vector.fill(10)(new TestLatch) - - val f1 = Future { - latch(0).open() - Await.ready(latch(1), TestLatch.DefaultTimeout) - "Hello" - } - val f2 = async { - val s = await(f1) - latch(2).open() - Await.ready(latch(3), TestLatch.DefaultTimeout) - s.length - } - for (_ <- f2) latch(4).open() - - Await.ready(latch(0), TestLatch.DefaultTimeout) - - f1.isCompleted mustBe (false) - f2.isCompleted mustBe (false) - - latch(1).open() - Await.ready(latch(2), TestLatch.DefaultTimeout) - - f1.isCompleted mustBe (true) - f2.isCompleted mustBe (false) - - val f3 = async { - val s = await(f1) - latch(5).open() - Await.ready(latch(6), TestLatch.DefaultTimeout) - s.length * 2 - } - for (_ <- f3) latch(3).open() - - Await.ready(latch(5), TestLatch.DefaultTimeout) - - f3.isCompleted mustBe (false) - - latch(6).open() - Await.ready(latch(4), TestLatch.DefaultTimeout) - - f2.isCompleted mustBe (true) - f3.isCompleted mustBe (true) - - val p1 = Promise[String]() - val f4 = async { - val s = await(p1.future) - latch(7).open() - Await.ready(latch(8), TestLatch.DefaultTimeout) - s.length - } - for (_ <- f4) latch(9).open() - - p1.future.isCompleted mustBe (false) - f4.isCompleted mustBe (false) - - p1 complete Success("Hello") - - Await.ready(latch(7), TestLatch.DefaultTimeout) - - p1.future.isCompleted mustBe (true) - f4.isCompleted mustBe (false) - - latch(8).open() - Await.ready(latch(9), TestLatch.DefaultTimeout) - - Await.ready(f4, defaultTimeout).isCompleted mustBe (true) - } - - @Test def `should not deadlock with nested await (ticket 1313)`(): Unit = { - val simple = async { - await { Future { } } - val unit = Future(()) - val umap = unit map { _ => () } - Await.result(umap, Inf) - } - Await.ready(simple, Inf).isCompleted mustBe (true) - - val l1, l2 = new TestLatch - val complex = async { - await{ Future { } } - blocking { - val nested = Future(()) - for (_ <- nested) l1.open() - Await.ready(l1, TestLatch.DefaultTimeout) // make sure nested is completed - for (_ <- nested) l2.open() - Await.ready(l2, TestLatch.DefaultTimeout) - } - } - Await.ready(complex, defaultTimeout).isCompleted mustBe (true) - } - - @Test def `should not throw when Await.ready`(): Unit = { - val expected = try Success(5 / 0) catch { case a: ArithmeticException => Failure(a) } - val f = async { await(Future(5)) / 0 } - Await.ready(f, defaultTimeout).value.get.toString mustBe expected.toString - } -} - - diff --git a/src/test/scala/scala/async/run/hygiene/Hygiene.scala b/src/test/scala/scala/async/run/hygiene/Hygiene.scala deleted file mode 100644 index 78afecaf..00000000 --- a/src/test/scala/scala/async/run/hygiene/Hygiene.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package hygiene - -import org.junit.Test -import scala.async.internal.AsyncId - -class HygieneSpec { - - import AsyncId.{async, await} - - @Test - def `is hygenic`(): Unit = { - val state = 23 - val result: Any = "result" - def resume(): Any = "resume" - val res = async { - val f1 = state + 2 - val x = await(f1) - val y = await(result) - val z = await(resume()) - (x, y, z) - } - res mustBe ((25, "result", "resume")) - } - - @Test - def `external var as result of await`(): Unit = { - var ext = 0 - async { - ext = await(12) - } - ext mustBe (12) - } - - @Test - def `external var as result of await 2`(): Unit = { - var ext = 0 - val inp = 10 - async { - if (inp > 0) - ext = await(12) - else - ext = await(10) - } - ext mustBe (12) - } - - @Test - def `external var as result of await 3`(): Unit = { - var ext = 0 - val inp = 10 - async { - val x = if (inp > 0) - await(12) - else - await(10) - ext = x + await(2) - } - ext mustBe (14) - } - - @Test - def `is hygenic nested`(): Unit = { - val state = 23 - val result: Any = "result" - def resume(): Any = "resume" - import AsyncId.{await, async} - val res = async { - val f1 = async { state + 2 } - val x = await(f1) - val y = await(async { result }) - val z = await(async(await(async { resume() }))) - (x, y, z) - } - res._1 mustBe (25) - res._2 mustBe ("result") - res._3 mustBe ("resume") - } -} diff --git a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala b/src/test/scala/scala/async/run/ifelse0/IfElse0.scala deleted file mode 100644 index 7603f3a3..00000000 --- a/src/test/scala/scala/async/run/ifelse0/IfElse0.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package ifelse0 - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test -import scala.async.internal.AsyncId - - -class TestIfElseClass { - - import ExecutionContext.Implicits.global - - def m1(x: Int): Future[Int] = Future { - x + 2 - } - - def m2(y: Int): Future[Int] = async { - val f = m1(y) - var z = 0 - if (y > 0) { - val x1 = await(f) - z = x1 + 2 - } else { - val x2 = await(f) - z = x2 - 2 - } - z - } -} - - -class IfElseSpec { - - @Test def `support await in a simple if-else expression`(): Unit = { - val o = new TestIfElseClass - val fut = o.m2(10) - val res = Await.result(fut, 2 seconds) - res mustBe (14) - } - - @Test def `await in condition`(): Unit = { - import AsyncId.{async, await} - val result = async { - if ({await(true); await(true)}) await(1) else ??? - } - result mustBe (1) - } -} diff --git a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala b/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala deleted file mode 100644 index cfd08d7e..00000000 --- a/src/test/scala/scala/async/run/ifelse0/WhileSpec.scala +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package ifelse0 - -import org.junit.Test -import scala.async.internal.AsyncId - -class WhileSpec { - - @Test - def whiling1(): Unit = { - import AsyncId._ - - val result = async { - var xxx: Int = 0 - var y = 0 - while (xxx < 3) { - y = await(xxx) - xxx = xxx + 1 - } - y - } - result mustBe (2) - } - - @Test - def whiling2(): Unit = { - import AsyncId._ - - val result = async { - var xxx: Int = 0 - var y = 0 - while (false) { - y = await(xxx) - xxx = xxx + 1 - } - y - } - result mustBe (0) - } - - @Test - def nestedWhile(): Unit = { - import AsyncId._ - - val result = async { - var sum = 0 - var i = 0 - while (i < 5) { - var j = 0 - while (j < 5) { - sum += await(i) * await(j) - j += 1 - } - i += 1 - } - sum - } - result mustBe (100) - } - - @Test - def whileExpr(): Unit = { - import AsyncId._ - - val result = async { - var cond = true - while (cond) { - cond = false - await { 22 } - } - } - result mustBe () - } - - @Test def doWhile(): Unit = { - import AsyncId._ - val result = async { - var b = 0 - var x = "" - await(do { - x += "1" - x += await("2") - x += "3" - b += await(1) - } while (b < 2)) - await(x) - } - result mustBe "123123" - } - - @Test def whileAwaitCondition(): Unit = { - import AsyncId._ - val result = async { - var b = true - while(await(b)) { - b = false - } - await(b) - } - result mustBe false - } - - @Test def doWhileAwaitCondition(): Unit = { - import AsyncId._ - val result = async { - var b = true - do { - b = false - } while(await(b)) - b - } - result mustBe false - } -} diff --git a/src/test/scala/scala/async/run/ifelse1/IfElse1.scala b/src/test/scala/scala/async/run/ifelse1/IfElse1.scala deleted file mode 100644 index 28b850b0..00000000 --- a/src/test/scala/scala/async/run/ifelse1/IfElse1.scala +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package ifelse1 - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test - - -class TestIfElse1Class { - - import ExecutionContext.Implicits.global - - def base(x: Int): Future[Int] = Future { - x + 2 - } - - def m1(y: Int): Future[Int] = async { - val f = base(y) - var z = 0 - if (y > 0) { - if (y > 100) - 5 - else { - val x1 = await(f) - z = x1 + 2 - } - } else { - val x2 = await(f) - z = x2 - 2 - } - z - } - - def m2(y: Int): Future[Int] = async { - val f = base(y) - var z = 0 - if (y > 0) { - if (y < 100) { - val x1 = await(f) - z = x1 + 2 - } - else - 5 - } else { - val x2 = await(f) - z = x2 - 2 - } - z - } - - def m3(y: Int): Future[Int] = async { - val f = base(y) - var z = 0 - if (y < 0) { - val x2 = await(f) - z = x2 - 2 - } else { - if (y > 100) - 5 - else { - val x1 = await(f) - z = x1 + 2 - } - } - z - } - - def m4(y: Int): Future[Int] = async { - val f = base(y) - var z = 0 - if (y < 0) { - val x2 = await(f) - z = x2 - 2 - } else { - if (y < 100) { - val x1 = await(f) - z = x1 + 2 - } else - 5 - } - z - } - - def pred: Future[Boolean] = async(true) - - def m5: Future[Boolean] = async { - if(if(if(if(if(if(if(if(if(if(if(if(if(if(if(if(if(if(if(if(if(await(pred)) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false) - await(pred) - else - false - } -} - -class IfElse1Spec { - - @Test - def `await in a nested if-else expression`(): Unit = { - val o = new TestIfElse1Class - val fut = o.m1(10) - val res = Await.result(fut, 2 seconds) - res mustBe (14) - } - - @Test - def `await in a nested if-else expression 2`(): Unit = { - val o = new TestIfElse1Class - val fut = o.m2(10) - val res = Await.result(fut, 2 seconds) - res mustBe (14) - } - - - @Test - def `await in a nested if-else expression 3`(): Unit = { - val o = new TestIfElse1Class - val fut = o.m3(10) - val res = Await.result(fut, 2 seconds) - res mustBe (14) - } - - - @Test - def `await in a nested if-else expression 4`(): Unit = { - val o = new TestIfElse1Class - val fut = o.m4(10) - val res = Await.result(fut, 2 seconds) - res mustBe (14) - } - - @Test - def `await in deeply-nested if-else conditions`(): Unit = { - val o = new TestIfElse1Class - val fut = o.m5 - val res = Await.result(fut, 2 seconds) - res mustBe true - } -} diff --git a/src/test/scala/scala/async/run/ifelse2/ifelse2.scala b/src/test/scala/scala/async/run/ifelse2/ifelse2.scala deleted file mode 100644 index 4527d0d2..00000000 --- a/src/test/scala/scala/async/run/ifelse2/ifelse2.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package ifelse2 - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test - - -class TestIfElse2Class { - - import ExecutionContext.Implicits.global - - def base(x: Int): Future[Int] = Future { - x + 2 - } - - def m(y: Int): Future[Int] = async { - val f = base(y) - var z = 0 - if (y > 0) { - val x = await(f) - z = x + 2 - } else { - val x = await(f) - z = x - 2 - } - z - } -} - -class IfElse2Spec { - - @Test - def `variables of the same name in different blocks`(): Unit = { - val o = new TestIfElse2Class - val fut = o.m(10) - val res = Await.result(fut, 2 seconds) - res mustBe (14) - } -} diff --git a/src/test/scala/scala/async/run/ifelse3/IfElse3.scala b/src/test/scala/scala/async/run/ifelse3/IfElse3.scala deleted file mode 100644 index 805d95d6..00000000 --- a/src/test/scala/scala/async/run/ifelse3/IfElse3.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package ifelse3 - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test - - -class TestIfElse3Class { - - import ExecutionContext.Implicits.global - - def base(x: Int): Future[Int] = Future { - x + 2 - } - - def m(y: Int): Future[Int] = async { - val f = base(y) - var z = 0 - if (y > 0) { - val x1 = await(f) - var w = x1 + 2 - z = w + 2 - } else { - val x2 = await(f) - var w = x2 + 2 - z = w - 2 - } - z - } -} - - -class IfElse3Spec { - - @Test - def `variables of the same name in different blocks`(): Unit = { - val o = new TestIfElse3Class - val fut = o.m(10) - val res = Await.result(fut, 2 seconds) - res mustBe (16) - } -} diff --git a/src/test/scala/scala/async/run/ifelse4/IfElse4.scala b/src/test/scala/scala/async/run/ifelse4/IfElse4.scala deleted file mode 100644 index a71b62eb..00000000 --- a/src/test/scala/scala/async/run/ifelse4/IfElse4.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package ifelse4 - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test - - -class TestIfElse4Class { - - import ExecutionContext.Implicits.global - - class F[A] - class S[A](val id: String) - trait P - - case class K(f: F[_]) - - def result[A](f: F[A]) = async { - new S[A with P]("foo") - } - - def run(k: K) = async { - val res = await(result(k.f)) - // these triggered a crash with mismatched existential skolems - // found : S#10272[_$1#10308 with String#137] where type _$1#10308 - // required: S#10272[_$1#10311 with String#137] forSome { type _$1#10311 } - - // This variation of the crash could be avoided by fixing the over-eager - // generation of states in `If` nodes, which was caused by a bug in label - // detection code. - if(true) { - identity(res) - } - - // This variation remained after the aforementioned fix, however. - // It was fixed by manually typing the `Assign(liftedField, rhs)` AST, - // which is how we avoid these problems through the rest of the ANF transform. - if(true) { - identity(res) - await(result(k.f)) - } - res - } -} - -class IfElse4Spec { - - @Test - def `await result with complex type containing skolem`(): Unit = { - val o = new TestIfElse4Class - val fut = o.run(o.K(null)) - val res = Await.result(fut, 2 seconds) - res.id mustBe ("foo") - } -} diff --git a/src/test/scala/scala/async/run/late/LateExpansion.scala b/src/test/scala/scala/async/run/late/LateExpansion.scala deleted file mode 100644 index 51dbdb28..00000000 --- a/src/test/scala/scala/async/run/late/LateExpansion.scala +++ /dev/null @@ -1,612 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async.run.late - -import java.io.File - -import junit.framework.Assert.assertEquals -import org.junit.{Assert, Ignore, Test} - -import scala.annotation.StaticAnnotation -import scala.annotation.meta.{field, getter} -import scala.async.internal.AsyncId -import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader -import scala.tools.nsc._ -import scala.tools.nsc.plugins.{Plugin, PluginComponent} -import scala.tools.nsc.reporters.StoreReporter -import scala.tools.nsc.transform.TypingTransformers - -// Tests for customized use of the async transform from a compiler plugin, which -// calls it from a new phase that runs after patmat. -class LateExpansion { - - @Test def testRewrittenApply(): Unit = { - val result = wrapAndRun( - """ - | object O { - | case class Foo(a: Any) - | } - | @autoawait def id(a: String) = a - | O.Foo - | id("foo") + id("bar") - | O.Foo(1) - | """.stripMargin) - assertEquals("Foo(1)", result.toString) - } - - @Ignore("Need to use adjustType more pervasively in AsyncTransform, but that exposes bugs in {Type, ... }Symbol's cache invalidation") - @Test def testIsInstanceOfType(): Unit = { - val result = wrapAndRun( - """ - | class Outer - | @autoawait def id(a: String) = a - | val o = new Outer - | id("foo") + id("bar") - | ("": Object).isInstanceOf[o.type] - | """.stripMargin) - assertEquals(false, result) - } - - @Test def testIsInstanceOfTerm(): Unit = { - val result = wrapAndRun( - """ - | class Outer - | @autoawait def id(a: String) = a - | val o = new Outer - | id("foo") + id("bar") - | o.isInstanceOf[Outer] - | """.stripMargin) - assertEquals(true, result) - } - - @Test def testArrayLocalModule(): Unit = { - val result = wrapAndRun( - """ - | class Outer - | @autoawait def id(a: String) = a - | val O = "" - | id("foo") + id("bar") - | new Array[O.type](0) - | """.stripMargin) - assertEquals(classOf[Array[String]], result.getClass) - } - - @Test def test0(): Unit = { - val result = wrapAndRun( - """ - | @autoawait def id(a: String) = a - | id("foo") + id("bar") - | """.stripMargin) - assertEquals("foobar", result) - } - - @Test def testGuard(): Unit = { - val result = wrapAndRun( - """ - | @autoawait def id[A](a: A) = a - | "" match { case _ if id(false) => ???; case _ => "okay" } - | """.stripMargin) - assertEquals("okay", result) - } - - @Test def testExtractor(): Unit = { - val result = wrapAndRun( - """ - | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) } - | "" match { case Extractor(a, b) if "".isEmpty => a == b } - | """.stripMargin) - assertEquals(true, result) - } - - @Test def testNestedMatchExtractor(): Unit = { - val result = wrapAndRun( - """ - | object Extractor { @autoawait def unapply(a: String) = Some((a, a)) } - | "" match { - | case _ if "".isEmpty => - | "" match { case Extractor(a, b) => a == b } - | } - | """.stripMargin) - assertEquals(true, result) - } - - @Test def testCombo(): Unit = { - val result = wrapAndRun( - """ - | object Extractor1 { @autoawait def unapply(a: String) = Some((a + 1, a + 2)) } - | object Extractor2 { @autoawait def unapply(a: String) = Some(a + 3) } - | @autoawait def id(a: String) = a - | println("Test.test") - | val r1 = Predef.identity("blerg") match { - | case x if " ".isEmpty => "case 2: " + x - | case Extractor1(Extractor2(x), y: String) if x == "xxx" => "case 1: " + x + ":" + y - | x match { - | case Extractor1(Extractor2(x), y: String) => - | case _ => - | } - | case Extractor2(x) => "case 3: " + x - | } - | r1 - | """.stripMargin) - assertEquals("case 3: blerg3", result) - } - - @Test def polymorphicMethod(): Unit = { - val result = run( - """ - |import scala.async.run.late.{autoawait,lateasync} - |object Test { - | class C { override def toString = "C" } - | @autoawait def foo[A <: C](a: A): A = a - | @lateasync - | def test1[CC <: C](c: CC): (CC, CC) = { - | val x: (CC, CC) = 0 match { case _ if false => ???; case _ => (foo(c), foo(c)) } - | x - | } - | def test(): (C, C) = test1(new C) - |} - | """.stripMargin) - assertEquals("(C,C)", result.toString) - } - - @Test def shadowing(): Unit = { - val result = run( - """ - |import scala.async.run.late.{autoawait,lateasync} - |object Test { - | trait Foo - | trait Bar extends Foo - | @autoawait def boundary = "" - | @lateasync - | def test: Unit = { - | (new Bar {}: Any) match { - | case foo: Bar => - | boundary - | 0 match { - | case _ => foo; () - | } - | () - | } - | () - | } - |} - | """.stripMargin) - } - - @Test def shadowing0(): Unit = { - val result = run( - """ - |import scala.async.run.late.{autoawait,lateasync} - |object Test { - | trait Foo - | trait Bar - | def test: Any = test(new C) - | @autoawait def asyncBoundary: String = "" - | @lateasync - | def test(foo: Foo): Foo = foo match { - | case foo: Bar => - | val foo2: Foo with Bar = new Foo with Bar {} - | asyncBoundary - | null match { - | case _ => foo2 - | } - | case other => foo - | } - | class C extends Foo with Bar - |} - | """.stripMargin) - } - - @Test def shadowing2(): Unit = { - val result = run( - """ - |import scala.async.run.late.{autoawait,lateasync} - |object Test { - | trait Base; trait Foo[T <: Base] { @autoawait def func: Option[Foo[T]] = None } - | class Sub extends Base - | trait Bar extends Foo[Sub] - | def test: Any = test(new Bar {}) - | @lateasync - | def test[T <: Base](foo: Foo[T]): Foo[T] = foo match { - | case foo: Bar => - | val res = foo.func - | res match { - | case _ => - | } - | foo - | case other => foo - | } - | test(new Bar {}) - |} - | """.stripMargin) - } - - @Test def patternAlternative(): Unit = { - val result = wrapAndRun( - """ - | @autoawait def one = 1 - | - | @lateasync def test = { - | Option(true) match { - | case null | None => false - | case Some(v) => one; v - | } - | } - | """.stripMargin) - } - - @Test def patternAlternativeBothAnnotations(): Unit = { - val result = wrapAndRun( - """ - |import scala.async.run.late.{autoawait,lateasync} - |object Test { - | @autoawait def func1() = "hello" - | @lateasync def func(a: Option[Boolean]) = a match { - | case null | None => func1 + " world" - | case _ => "okay" - | } - | def test: Any = func(None) - |} - | """.stripMargin) - } - - @Test def shadowingRefinedTypes(): Unit = { - val result = run( - s""" - |import scala.async.run.late.{autoawait,lateasync} - |trait Base - |class Sub extends Base - |trait Foo[T <: Base] { - | @autoawait def func: Option[Foo[T]] = None - |} - |trait Bar extends Foo[Sub] - |object Test { - | @lateasync def func[T <: Base](foo: Foo[T]): Foo[T] = foo match { // the whole pattern match will be wrapped with async{ } - | case foo: Bar => - | val res = foo.func // will be rewritten into: await(foo.func) - | res match { - | case Some(v) => v // this will report type mismtach - | case other => foo - | } - | case other => foo - | } - | def test: Any = { val b = new Bar{}; func(b) == b } - |}""".stripMargin) - assertEquals(true, result) - } - - @Test def testMatchEndIssue(): Unit = { - val result = run( - """ - |import scala.async.run.late.{autoawait,lateasync} - |sealed trait Subject - |final class Principal(val name: String) extends Subject - |object Principal { - | def unapply(p: Principal): Option[String] = Some(p.name) - |} - |object Test { - | @autoawait @lateasync - | def containsPrincipal(search: String, value: Subject): Boolean = value match { - | case Principal(name) if name == search => true - | case Principal(name) => containsPrincipal(search, value) - | case other => false - | } - | - | @lateasync - | def test = containsPrincipal("test", new Principal("test")) - |} - | """.stripMargin) - } - - @Test def testGenericTypeBoundaryIssue(): Unit = { - val result = run( - """ - - import scala.async.run.late.{autoawait,lateasync} - trait InstrumentOfValue - trait Security[T <: InstrumentOfValue] extends InstrumentOfValue - class Bound extends Security[Bound] - class Futures extends Security[Futures] - object TestGenericTypeBoundIssue { - @autoawait @lateasync def processBound(bound: Bound): Unit = { println("process Bound") } - @autoawait @lateasync def processFutures(futures: Futures): Unit = { println("process Futures") } - @autoawait @lateasync def doStuff(sec: Security[_]): Unit = { - sec match { - case bound: Bound => processBound(bound) - case futures: Futures => processFutures(futures) - case _ => throw new Exception("Unknown Security type: " + sec) - } - } - } - object Test { @lateasync def test: Unit = TestGenericTypeBoundIssue.doStuff(new Bound) } - """.stripMargin) - } - - @Test def testReturnTupleIssue(): Unit = { - val result = run( - """ - import scala.async.run.late.{autoawait,lateasync} - class TestReturnExprIssue(str: String) { - @autoawait @lateasync def getTestValue = Some(42) - @autoawait @lateasync def doStuff: Int = { - val opt: Option[Int] = getTestValue // here we have an async method invoke - opt match { - case Some(li) => li // use the result somehow - case None => - } - 42 // type mismatch; found : AnyVal required: Int - } - } - object Test { @lateasync def test: Unit = new TestReturnExprIssue("").doStuff } - """.stripMargin) - } - - - @Test def testAfterRefchecksIssue(): Unit = { - val result = run( - """ - import scala.async.run.late.{autoawait,lateasync} - trait Factory[T] { def create: T } - sealed trait TimePoint - class TimeLine[TP <: TimePoint](val tpInitial: Factory[TP]) { - @autoawait @lateasync private[TimeLine] val tp: TP = tpInitial.create - @autoawait @lateasync def timePoint: TP = tp - } - object Test { - def test: Unit = () - } - """) - } - - @Test def testArrayIndexOutOfBoundIssue(): Unit = { - val result = run( - """ - import scala.async.run.late.{autoawait,lateasync} - - sealed trait Result - case object A extends Result - case object B extends Result - case object C extends Result - - object Test { - protected def doStuff(res: Result) = { - class C { - @autoawait def needCheck = false - - @lateasync def m = { - if (needCheck) "NO" - else { - res match { - case A => 1 - case _ => 2 - } - } - } - } - } - - - @lateasync - def test() = doStuff(B) - } - """) - } - - def wrapAndRun(code: String): Any = { - run( - s""" - |import scala.async.run.late.{autoawait,lateasync} - |object Test { - | @lateasync - | def test: Any = { - | $code - | } - |} - | """.stripMargin) - } - - - @Test def testNegativeArraySizeException(): Unit = { - val result = run( - """ - import scala.async.run.late.{autoawait,lateasync} - - object Test { - def foo(foo: Any, bar: Any) = () - @autoawait def getValue = 4.2 - @lateasync def func(f: Any) = { - foo(f match { case _ if "".isEmpty => 2 }, getValue); - } - - @lateasync - def test() = func(4) - } - """) - } - - @Test def testNegativeArraySizeExceptionFine1(): Unit = { - val result = run( - """ - import scala.async.run.late.{autoawait,lateasync} - case class FixedFoo(foo: Int) - class Foobar(val foo: Int, val bar: Double) { - @autoawait @lateasync def getValue = 4.2 - @autoawait @lateasync def func(f: Any) = { - new Foobar(foo = f match { - case FixedFoo(x) => x - case _ => 2 - }, - bar = getValue) - } - } - object Test { - @lateasync def test() = new Foobar(0, 0).func(4) - } - """) - } - - @Test def testByNameOwner(): Unit = { - val result = run( - """ - import scala.async.run.late.{autoawait,lateasync} - object Bleh { - @autoawait @lateasync def asyncCall(): Int = 0 - def byName[T](fn: => T): T = fn - } - object Boffo { - @autoawait @lateasync def jerk(): Unit = { - val pointlessSymbolOwner = 1 match { - case _ => - Bleh.asyncCall() - Bleh.byName { - val whyDoHateMe = 1 - whyDoHateMe - } - } - } - } - object Test { - @lateasync def test() = Boffo.jerk() - } - """) - } - - @Test def testByNameOwner2(): Unit = { - val result = run( - """ - import scala.async.run.late.{autoawait,lateasync} - object Bleh { - @autoawait @lateasync def bleh = Bleh - def byName[T](fn: => T): T = fn - } - object Boffo { - @autoawait @lateasync def slob(): Unit = { - val pointlessSymbolOwner = { - Bleh.bleh.byName { - val whyDoHateMeToo = 1 - whyDoHateMeToo - } - } - } - } - object Test { - @lateasync def test() = Boffo.slob() - } - """) - } - - private def createTempDir(): File = { - val f = File.createTempFile("output", "") - f.delete() - f.mkdirs() - f - } - - def run(code: String): Any = { - val out = createTempDir() - try { - val reporter = new StoreReporter - val settings = new Settings(println(_)) - settings.outdir.value = out.getAbsolutePath - settings.embeddedDefaults(getClass.getClassLoader) - // settings.processArgumentString("-Xprint:patmat,postpatmat,jvm -nowarn") - val isInSBT = !settings.classpath.isSetByUser - if (isInSBT) settings.usejavacp.value = true - val global = new Global(settings, reporter) { - self => - - object late extends { - val global: self.type = self - } with LatePlugin - - override protected def loadPlugins(): List[Plugin] = late :: Nil - } - import global._ - - val run = new Run - val source = newSourceFile(code) - // TreeInterrogation.withDebug { - run.compileSources(source :: Nil) - // } - Assert.assertTrue(reporter.infos.mkString("\n"), !reporter.hasErrors) - val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader) - val cls = loader.loadClass("Test") - cls.getMethod("test").invoke(null) - } finally { - scala.reflect.io.Path.apply(out).deleteRecursively() - } - } -} - -abstract class LatePlugin extends Plugin { - - import global._ - - override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers { - val global: LatePlugin.this.global.type = LatePlugin.this.global - - lazy val asyncIdSym = symbolOf[AsyncId.type] - lazy val asyncSym = asyncIdSym.info.member(TermName("async")) - lazy val awaitSym = asyncIdSym.info.member(TermName("await")) - lazy val autoAwaitSym = symbolOf[autoawait] - lazy val lateAsyncSym = symbolOf[lateasync] - - def newTransformer(unit: CompilationUnit) = new TypingTransformer(unit) { - override def transform(tree: Tree): Tree = { - super.transform(tree) match { - case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) => - localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil)) - case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi]) => - localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(sel.tpe) :: Nil), sel :: Nil)) - case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) { - deriveDefDef(dd) { rhs: Tree => - val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs)) - localTyper.typed(atPos(dd.pos)(invoke)) - } - } - case vd: ValDef if vd.symbol.hasAnnotation(lateAsyncSym) => atOwner(vd.symbol) { - deriveValDef(vd) { rhs: Tree => - val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs)) - localTyper.typed(atPos(vd.pos)(invoke)) - } - } - case vd: ValDef => - vd - case x => x - } - } - } - - override def newPhase(prev: Phase): Phase = new StdPhase(prev) { - override def apply(unit: CompilationUnit): Unit = { - val translated = newTransformer(unit).transformUnit(unit) - //println(show(unit.body)) - translated - } - } - - override val runsAfter: List[String] = "refchecks" :: Nil - override val phaseName: String = "postpatmat" - - }) - override val description: String = "postpatmat" - override val name: String = "postpatmat" -} - -// Methods with this annotation are translated to having the RHS wrapped in `AsyncId.async { }` -@field -final class lateasync extends StaticAnnotation - -// Calls to methods with this annotation are translated to `AsyncId.await()` -@getter -final class autoawait extends StaticAnnotation diff --git a/src/test/scala/scala/async/run/lazyval/LazyValSpec.scala b/src/test/scala/scala/async/run/lazyval/LazyValSpec.scala deleted file mode 100644 index 6805d28c..00000000 --- a/src/test/scala/scala/async/run/lazyval/LazyValSpec.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package lazyval - -import org.junit.Test -import scala.async.internal.AsyncId._ - -class LazyValSpec { - @Test - def lazyValAllowed(): Unit = { - val result = async { - var x = 0 - lazy val y = { x += 1; 42 } - assert(x == 0, x) - val z = await(1) - val result = y + x - assert(x == 1, x) - identity(y) - assert(x == 1, x) - result - } - result mustBe 43 - } -} - diff --git a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala b/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala deleted file mode 100644 index f4268a73..00000000 --- a/src/test/scala/scala/async/run/live/LiveVariablesSpec.scala +++ /dev/null @@ -1,299 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package live - -import org.junit.Test - -import internal.AsyncTestLV -import AsyncTestLV._ - -case class Cell[T](v: T) - -class Meter(val len: Long) extends AnyVal - -case class MCell[T](var v: T) - - -class LiveVariablesSpec { - AsyncTestLV.clear() - - @Test - def `zero out fields of reference type`(): Unit = { - val f = async { Cell(1) } - - def m1(x: Cell[Int]): Cell[Int] = - async { Cell(x.v + 1) } - - def m2(x: Cell[Int]): String = - async { x.v.toString } - - def m3() = async { - val a: Cell[Int] = await(f) // await$1$1 - // a == Cell(1) - val b: Cell[Int] = await(m1(a)) // await$2$1 - // b == Cell(2) - assert(AsyncTestLV.log.exists(_._2 == Cell(1)), AsyncTestLV.log) - val res = await(m2(b)) // await$3$1 - assert(AsyncTestLV.log.exists(_._2 == Cell(2))) - res - } - - assert(m3() == "2") - } - - @Test - def `zero out fields of type Any`(): Unit = { - val f = async { Cell(1) } - - def m1(x: Cell[Int]): Cell[Int] = - async { Cell(x.v + 1) } - - def m2(x: Any): String = - async { x.toString } - - def m3() = async { - val a: Cell[Int] = await(f) // await$4$1 - // a == Cell(1) - val b: Any = await(m1(a)) // await$5$1 - // b == Cell(2) - assert(AsyncTestLV.log.exists(_._2 == Cell(1))) - val res = await(m2(b)) // await$6$1 - assert(AsyncTestLV.log.exists(_._2 == Cell(2))) - res - } - - assert(m3() == "Cell(2)") - } - - @Test - def `do not zero out fields of primitive type`(): Unit = { - val f = async { 1 } - - def m1(x: Int): Cell[Int] = - async { Cell(x + 1) } - - def m2(x: Any): String = - async { x.toString } - - def m3() = async { - val a: Int = await(f) // await$7$1 - // a == 1 - val b: Any = await(m1(a)) // await$8$1 - // b == Cell(2) - // assert(!AsyncTestLV.log.exists(p => p._1 == "await$7$1")) - val res = await(m2(b)) // await$9$1 - assert(AsyncTestLV.log.exists(_._2 == Cell(2))) - res - } - - assert(m3() == "Cell(2)") - } - - @Test - def `zero out fields of value class type`(): Unit = { - val f = async { Cell(1) } - - def m1(x: Cell[Int]): Meter = - async { new Meter(x.v + 1) } - - def m2(x: Any): String = - async { x.toString } - - def m3() = async { - val a: Cell[Int] = await(f) // await$10$1 - // a == Cell(1) - val b: Meter = await(m1(a)) // await$11$1 - // b == Meter(2) - assert(AsyncTestLV.log.exists(_._2 == Cell(1))) - val res = await(m2(b.len)) // await$12$1 - assert(AsyncTestLV.log.exists(_._2.asInstanceOf[Meter].len == 2L)) - res - } - - assert(m3() == "2") - } - - @Test - def `zero out fields after use in loop`(): Unit = { - val f = async { MCell(1) } - - def m1(x: MCell[Int], y: Int): Int = - async { x.v + y } - - def m3() = async { - // state #1 - val a: MCell[Int] = await(f) // await$13$1 - // state #2 - var y = MCell(0) - - while (a.v < 10) { - // state #4 - a.v = a.v + 1 - y = MCell(await(a).v + 1) // await$14$1 - // state #7 - } - - // state #3 - // assert(AsyncTestLV.log.exists(entry => entry._1 == "await$14$1")) - - val b = await(m1(a, y.v)) // await$15$1 - // state #8 - assert(AsyncTestLV.log.exists(_._2 == MCell(10)), AsyncTestLV.log) - assert(AsyncTestLV.log.exists(_._2 == MCell(11))) - b - } - - assert(m3() == 21, m3()) - } - - @Test - def `don't zero captured fields captured lambda`(): Unit = { - val f = async { - val x = "x" - val y = "y" - await(0) - y.reverse - val f = () => assert(x != null) - await(0) - f - } - AsyncTestLV.assertNotNulledOut("x") - AsyncTestLV.assertNulledOut("y") - f() - } - - @Test - def `don't zero captured fields captured by-name`(): Unit = { - def func0[A](a: => A): () => A = () => a - val f = async { - val x = "x" - val y = "y" - await(0) - y.reverse - val f = func0(assert(x != null)) - await(0) - f - } - AsyncTestLV.assertNotNulledOut("x") - AsyncTestLV.assertNulledOut("y") - f() - } - - @Test - def `don't zero captured fields nested class`(): Unit = { - def func0[A](a: => A): () => A = () => a - val f = async { - val x = "x" - val y = "y" - await(0) - y.reverse - val f = new Function0[Unit] { - def apply = assert(x != null) - } - await(0) - f - } - AsyncTestLV.assertNotNulledOut("x") - AsyncTestLV.assertNulledOut("y") - f() - } - - @Test - def `don't zero captured fields nested object`(): Unit = { - def func0[A](a: => A): () => A = () => a - val f = async { - val x = "x" - val y = "y" - await(0) - y.reverse - object f extends Function0[Unit] { - def apply = assert(x != null) - } - await(0) - f - } - AsyncTestLV.assertNotNulledOut("x") - AsyncTestLV.assertNulledOut("y") - f() - } - - @Test - def `don't zero captured fields nested def`(): Unit = { - val f = async { - val x = "x" - val y = "y" - await(0) - y.reverse - def xx = x - val f = xx _ - await(0) - f - } - AsyncTestLV.assertNotNulledOut("x") - AsyncTestLV.assertNulledOut("y") - f() - } - - @Test - def `capture bug`(): Unit = { - sealed trait Base - case class B1() extends Base - case class B2() extends Base - val outer = List[(Base, Int)]((B1(), 8)) - - def getMore(b: Base) = 4 - - def baz = async { - outer.head match { - case (a @ B1(), r) => { - val ents = await(getMore(a)) - - { () => - println(a) - assert(a ne null) - } - } - case (b @ B2(), x) => - () => ??? - } - } - baz() - } - - // https://github.com/scala/async/issues/104 - @Test def dontNullOutVarsOfTypeNothing_t104(): Unit = { - import scala.async.Async._ - import scala.concurrent.duration.Duration - import scala.concurrent.{Await, Future} - import scala.concurrent.ExecutionContext.Implicits.global - def errorGenerator(randomNum: Double) = { - Future { - if (randomNum < 0) { - throw new IllegalStateException("Random number was too low!") - } else { - throw new IllegalStateException("Random number was too high!") - } - } - } - def randomTimesTwo = async { - val num = _root_.scala.math.random - if (num < 0 || num > 1) { - await(errorGenerator(num)) - } - num * 2 - } - Await.result(randomTimesTwo, TestLatch.DefaultTimeout) // was: NotImplementedError - } -} diff --git a/src/test/scala/scala/async/run/match0/Match0.scala b/src/test/scala/scala/async/run/match0/Match0.scala deleted file mode 100644 index d8c136b9..00000000 --- a/src/test/scala/scala/async/run/match0/Match0.scala +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package match0 - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent.{Future, ExecutionContext, Await} -import scala.concurrent.duration._ -import scala.async.Async.{async, await} -import org.junit.Test -import scala.async.internal.AsyncId - - -class TestMatchClass { - - import ExecutionContext.Implicits.global - - def m1(x: Int): Future[Int] = Future { - x + 2 - } - - def m2(y: Int): Future[Int] = async { - val f = m1(y) - var z = 0 - y match { - case 10 => - val x1 = await(f) - z = x1 + 2 - case 20 => - val x2 = await(f) - z = x2 - 2 - } - z - } - - def m3(y: Int): Future[Int] = async { - val f = m1(y) - var z = 0 - y match { - case 0 => - val x2 = await(f) - z = x2 - 2 - case 1 => - val x1 = await(f) - z = x1 + 2 - } - z - } -} - - -class MatchSpec { - - @Test def `support await in a simple match expression`(): Unit = { - val o = new TestMatchClass - val fut = o.m2(10) // matches first case - val res = Await.result(fut, 2 seconds) - res mustBe (14) - } - - @Test def `support await in a simple match expression 2`(): Unit = { - val o = new TestMatchClass - val fut = o.m3(1) // matches second case - val res = Await.result(fut, 2 seconds) - res mustBe (5) - } - - @Test def `support await in a match expression with binds`(): Unit = { - val result = AsyncId.async { - val x = 1 - Option(x) match { - case op @ Some(x) => - assert(op.contains(1)) - x + AsyncId.await(x) - case None => AsyncId.await(0) - } - } - result mustBe (2) - } - - @Test def `support await referring to pattern matching vals`(): Unit = { - import AsyncId.{async, await} - val result = async { - val x = 1 - val opt = Some("") - await(0) - val o @ Some(y) = opt - - { - val o @ Some(y) = Some(".") - } - - await(0) - await((o, y.isEmpty)) - } - result mustBe ((Some(""), true)) - } - - @Test def `await in scrutinee`(): Unit = { - import AsyncId.{async, await} - val result = async { - await(if ("".isEmpty) await(1) else ???) match { - case x if x < 0 => ??? - case y: Int => y * await(3) - } - } - result mustBe (3) - } - - @Test def duplicateBindName(): Unit = { - import AsyncId.{async, await} - def m4(m: Any) = async { - m match { - case buf: String => - await(0) - case buf: Double => - await(2) - } - } - m4("") mustBe 0 - } - - @Test def bugCastBoxedUnitToStringMatch(): Unit = { - import scala.async.internal.AsyncId.{async, await} - def foo = async { - val p2 = await(5) - "foo" match { - case p3: String => - p2.toString - } - } - foo mustBe "5" - } - - @Test def bugCastBoxedUnitToStringIf(): Unit = { - import scala.async.internal.AsyncId.{async, await} - def foo = async { - val p2 = await(5) - if (true) p2.toString else p2.toString - } - foo mustBe "5" - } -} diff --git a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala b/src/test/scala/scala/async/run/nesteddef/NestedDef.scala deleted file mode 100644 index 9e2d3c83..00000000 --- a/src/test/scala/scala/async/run/nesteddef/NestedDef.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package nesteddef - -import org.junit.Test -import scala.async.internal.AsyncId - -class NestedDef { - - @Test - def nestedDef(): Unit = { - import AsyncId._ - val result = async { - val a = 0 - val x = await(a) - 1 - val local = 43 - def bar(d: Double) = -d + a + local - def foo(z: Any) = (a.toDouble, bar(x).toDouble, z) - foo(await(2)) - } - result mustBe ((0d, 44d, 2)) - } - - - @Test - def nestedFunction(): Unit = { - import AsyncId._ - val result = async { - val a = 0 - val x = await(a) - 1 - val local = 43 - val bar = (d: Double) => -d + a + local - val foo = (z: Any) => (a.toDouble, bar(x).toDouble, z) - foo(await(2)) - } - result mustBe ((0d, 44d, 2)) - } - - // We must lift `foo` and `bar` in the next two tests. - @Test - def nestedDefTransitive1(): Unit = { - import AsyncId._ - val result = async { - val a = 0 - val x = await(a) - 1 - def bar = a - def foo = bar - foo - } - result mustBe 0 - } - - @Test - def nestedDefTransitive2(): Unit = { - import AsyncId._ - val result = async { - val a = 0 - val x = await(a) - 1 - def bar = a - def foo = bar - 0 - } - result mustBe 0 - } - - - // checking that our use/definition analysis doesn't cycle. - @Test - def mutuallyRecursive1(): Unit = { - import AsyncId._ - val result = async { - val a = 0 - val x = await(a) - 1 - def foo: Int = if (true) 0 else bar - def bar: Int = if (true) 0 else foo - bar - } - result mustBe 0 - } - - // checking that our use/definition analysis doesn't cycle. - @Test - def mutuallyRecursive2(): Unit = { - import AsyncId._ - val result = async { - val a = 0 - def foo: Int = if (true) 0 else bar - def bar: Int = if (true) 0 else foo - val x = await(a) - 1 - bar - } - result mustBe 0 - } -} diff --git a/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala b/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala deleted file mode 100644 index f6f6afb0..00000000 --- a/src/test/scala/scala/async/run/noawait/NoAwaitSpec.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package noawait - -import scala.async.internal.AsyncId -import AsyncId._ -import org.junit.Test - -class NoAwaitSpec { - @Test - def `async block without await`(): Unit = { - def foo = 1 - async { - foo - foo - } mustBe (foo) - } - - @Test - def `async block without await 2`(): Unit = { - async { - def x = 0 - if (x > 0) 0 else 1 - } mustBe (1) - } - - @Test - def `async expr without await`(): Unit = { - def foo = 1 - async(foo) mustBe (foo) - } -} diff --git a/src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala b/src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala deleted file mode 100644 index 8e3127a0..00000000 --- a/src/test/scala/scala/async/run/stackoverflow/StackOverflowSpec.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package stackoverflow - -import org.junit.Test -import scala.async.internal.AsyncId - - -class StackOverflowSpec { - - @Test - def stackSafety(): Unit = { - import AsyncId._ - async { - var i = 100000000 - while (i > 0) { - if (false) { - await(()) - } - i -= 1 - } - } - } -} diff --git a/src/test/scala/scala/async/run/toughtype/ToughType.scala b/src/test/scala/scala/async/run/toughtype/ToughType.scala deleted file mode 100644 index f7002b57..00000000 --- a/src/test/scala/scala/async/run/toughtype/ToughType.scala +++ /dev/null @@ -1,362 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package toughtype - -import language.{reflectiveCalls, postfixOps} -import scala.concurrent._ -import scala.concurrent.duration._ -import scala.async.Async._ -import org.junit.{Assert, Test} -import scala.async.internal.AsyncId - - -object ToughTypeObject { - - import ExecutionContext.Implicits.global - - class Inner - - def m2 = async[(List[_], ToughTypeObject.Inner)] { - val y = await(Future[List[_]](Nil)) - val z = await(Future[Inner](new Inner)) - (y, z) - } -} - -class ToughTypeSpec { - - @Test def `propogates tough types`(): Unit = { - val fut = ToughTypeObject.m2 - val res: (List[_], scala.async.run.toughtype.ToughTypeObject.Inner) = Await.result(fut, 2 seconds) - res._1 mustBe (Nil) - } - - @Test def patternMatchingPartialFunction(): Unit = { - import AsyncId.{await, async} - async { - await(1) - val a = await(1) - val f = { case x => x + a }: PartialFunction[Int, Int] - await(f(2)) - } mustBe 3 - } - - @Test def patternMatchingPartialFunctionNested(): Unit = { - import AsyncId.{await, async} - async { - await(1) - val neg1 = -1 - val a = await(1) - val f = { case x => ({case x => neg1 * x}: PartialFunction[Int, Int])(x + a) }: PartialFunction[Int, Int] - await(f(2)) - } mustBe -3 - } - - @Test def patternMatchingFunction(): Unit = { - import AsyncId.{await, async} - async { - await(1) - val a = await(1) - val f = { case x => x + a }: Function[Int, Int] - await(f(2)) - } mustBe 3 - } - - @Test def existentialBindIssue19(): Unit = { - import AsyncId.{await, async} - def m7(a: Any) = async { - a match { - case s: Seq[_] => - val x = s.size - var ss = s - ss = s - await(x) - } - } - m7(Nil) mustBe 0 - } - - @Test def existentialBind2Issue19(): Unit = { - import scala.async.Async._, scala.concurrent.ExecutionContext.Implicits.global - def conjure[T]: T = null.asInstanceOf[T] - - def m3 = async { - val p: List[Option[_]] = conjure[List[Option[_]]] - await(Future(1)) - } - - def m4 = async { - await(Future[List[_]](Nil)) - } - } - - @Test def singletonTypeIssue17(): Unit = { - import AsyncId.{async, await} - class A { class B } - async { - val a = new A - def foo(b: a.B) = 0 - await(foo(new a.B)) - } - } - - @Test def existentialMatch(): Unit = { - import AsyncId.{async, await} - trait Container[+A] - case class ContainerImpl[A](value: A) extends Container[A] - def foo: Container[_] = async { - val a: Any = List(1) - a match { - case buf: Seq[_] => - val foo = await(5) - val e0 = buf(0) - ContainerImpl(e0) - } - } - foo - } - - @Test def existentialIfElse0(): Unit = { - import AsyncId.{async, await} - trait Container[+A] - case class ContainerImpl[A](value: A) extends Container[A] - def foo: Container[_] = async { - val a: Any = List(1) - if (true) { - val buf: Seq[_] = List(1) - val foo = await(5) - val e0 = buf(0) - ContainerImpl(e0) - } else ??? - } - foo - } - - // This test was failing when lifting `def r` with: - // symbol value m#10864 does not exist in r$1 - // - // We generated: - // - // private[this] def r$1#5727[A#5728 >: Nothing#157 <: Any#156](m#5731: Foo#2349[A#5728]): Unit#208 = Bippy#2352.this.bar#5532({ - // m#5730; - // () - // }); - // - // Notice the incorrect reference to `m`. - // - // We compensated in `Lifter` by copying `ValDef` parameter symbols directly across. - // - // Turns out the behaviour stems from `thisMethodType` in `Namers`, which treats type parameter skolem symbols. - @Test def nestedMethodWithInconsistencyTreeAndInfoParamSymbols(): Unit = { - import language.{reflectiveCalls, postfixOps} - import scala.concurrent.{Future, ExecutionContext, Await} - import scala.concurrent.duration._ - import scala.async.Async.{async, await} - import scala.async.internal.AsyncId - - class Foo[A] - - object Bippy { - - import ExecutionContext.Implicits.global - - def bar(f: => Unit): Unit = f - - def quux: Future[String] = ??? - - def foo = async { - def r[A](m: Foo[A])(n: A) = { - bar { - locally(m) - locally(n) - identity[A] _ - } - } - - await(quux) - - r(new Foo[String])("") - } - } - Bippy - } - - @Test - def ticket63(): Unit = { - import scala.async.Async._ - import scala.concurrent.{ ExecutionContext, Future } - - object SomeExecutionContext extends ExecutionContext { - def reportFailure(t: Throwable): Unit = ??? - def execute(runnable: Runnable): Unit = ??? - } - - trait FunDep[W, S, R] { - def method(w: W, s: S): Future[R] - } - - object FunDep { - implicit def `Something to do with List`[W, S, R](implicit funDep: FunDep[W, S, R]) = - new FunDep[W, List[S], W] { - def method(w: W, l: List[S]) = async { - val it = l.iterator - while (it.hasNext) { - await(funDep.method(w, it.next())) - } - w - }(SomeExecutionContext) - } - } - - } - - @Test def ticket66Nothing(): Unit = { - import scala.concurrent.Future - import scala.concurrent.ExecutionContext.Implicits.global - val e = new Exception() - val f: Future[Nothing] = Future.failed(e) - val f1 = async { - await(f) - } - try { - Await.result(f1, 5.seconds) - } catch { - case `e` => - } - } - - @Test def ticket83ValueClass(): Unit = { - import scala.async.Async._ - import scala.concurrent._, duration._, ExecutionContext.Implicits.global - val f = async { - val uid = new IntWrapper("foo") - await(Future(uid)) - } - val result = Await.result(f, 5.seconds) - result mustEqual (new IntWrapper("foo")) - } - - @Test def ticket86NestedValueClass(): Unit = { - import ExecutionContext.Implicits.global - - val f = async { - val a = Future.successful(new IntWrapper("42")) - await(await(a).plusStr) - } - val result = Await.result(f, 5.seconds) - result mustEqual "42!" - } - - @Test def ticket86MatchedValueClass(): Unit = { - import ExecutionContext.Implicits.global - - def doAThing(param: IntWrapper) = Future(None) - - val fut = async { - Option(new IntWrapper("value!")) match { - case Some(valueHolder) => - await(doAThing(valueHolder)) - case None => - None - } - } - - val result = Await.result(fut, 5.seconds) - result mustBe None - } - - @Test def ticket86MatchedParameterizedValueClass(): Unit = { - import ExecutionContext.Implicits.global - - def doAThing(param: ParamWrapper[String]) = Future(None) - - val fut = async { - Option(new ParamWrapper("value!")) match { - case Some(valueHolder) => - await(doAThing(valueHolder)) - case None => - None - } - } - - val result = Await.result(fut, 5.seconds) - result mustBe None - } - - @Test def ticket86PrivateValueClass(): Unit = { - import ExecutionContext.Implicits.global - - def doAThing(param: PrivateWrapper) = Future(None) - - val fut = async { - Option(PrivateWrapper.Instance) match { - case Some(valueHolder) => - await(doAThing(valueHolder)) - case None => - None - } - } - - val result = Await.result(fut, 5.seconds) - result mustBe None - } - - @Test def awaitOfAbstractType(): Unit = { - import ExecutionContext.Implicits.global - - def combine[A](a1: A, a2: A): A = a1 - - def combineAsync[A](a1: Future[A], a2: Future[A]) = async { - combine(await(a1), await(a2)) - } - - val fut = combineAsync(Future(1), Future(2)) - - val result = Await.result(fut, 5.seconds) - result mustEqual 1 - } - - // https://github.com/scala/async/issues/106 - @Test def valueClassT106(): Unit = { - import scala.async.internal.AsyncId._ - async { - "whatever value" match { - case _ => - await("whatever return type") - new IntWrapper("value class matters") - } - "whatever return type" - } - } -} - -class IntWrapper(val value: String) extends AnyVal { - def plusStr = Future.successful(value + "!") -} -class ParamWrapper[T](val value: T) extends AnyVal - -class PrivateWrapper private (private val value: String) extends AnyVal -object PrivateWrapper { - def Instance = new PrivateWrapper("") -} - - -trait A - -trait B - -trait L[A2, B2 <: A2] { - def bar(a: Any, b: Any) = 0 -} diff --git a/src/test/scala/scala/async/run/uncheckedBounds/UncheckedBoundsSpec.scala b/src/test/scala/scala/async/run/uncheckedBounds/UncheckedBoundsSpec.scala deleted file mode 100644 index 435a14be..00000000 --- a/src/test/scala/scala/async/run/uncheckedBounds/UncheckedBoundsSpec.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Scala (https://www.scala-lang.org) - * - * Copyright EPFL and Lightbend, Inc. - * - * Licensed under Apache License 2.0 - * (http://www.apache.org/licenses/LICENSE-2.0). - * - * See the NOTICE file distributed with this work for - * additional information regarding copyright ownership. - */ - -package scala.async -package run -package uncheckedBounds - -import org.junit.{Test, Assert} -import scala.async.TreeInterrogation - -class UncheckedBoundsSpec { - @Test def insufficientLub_SI_7694(): Unit = { - eval( s""" - object Test { - import _root_.scala.async.run.toughtype._ - import _root_.scala.async.internal.AsyncId.{async, await} - async { - (if (true) await(null: L[A, A]) else await(null: L[B, B])) - } - } - """, compileOptions = s"-cp ${toolboxClasspath} ") - } - - @Test def insufficientLub_SI_7694_ScalaConcurrent(): Unit = { - eval( s""" - object Test { - import _root_.scala.async.run.toughtype._ - import _root_.scala.async.Async.{async, await} - import scala.concurrent._ - import scala.concurrent.ExecutionContext.Implicits.global - async { - (if (true) await(null: Future[L[A, A]]) else await(null: Future[L[B, B]])) - } - } - """, compileOptions = s"-cp ${toolboxClasspath} ") - } - -}