Skip to content

Commit

Permalink
NestedFuture: new wart to avoid nested futures
Browse files Browse the repository at this point in the history
Co-authored-by: Iván Molina Rebolledo <[email protected]>
  • Loading branch information
fabianhjr-dealengine and IvanAtDealEngine committed Jan 15, 2024
1 parent d7cbeb3 commit b1de43e
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 1 deletion.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ lazy val sbtPlug: Project = Project(
.map(_.getName.replaceAll("""\.scala$""", ""))
.filterNot(deprecatedWarts)
.sorted
val expectCount = 12
val expectCount = 13
assert(
warts.size == expectCount,
s"${warts.size} != ${expectCount}. please update build.sbt when add or remove wart"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package org.wartremover
package contrib.warts

object NestedFuture extends WartTraverser {
val message: String =
"""`Future[Future[A]]` will not wait for and discard/cancel the inner future.
|To chain the result of Future to other Future, use flatMap or a for comprehension.
|""".stripMargin

private val futureSymbols: Set[String] = Set(
"scala.concurrent.Future",
"com.twitter.util.Future"
)

def apply(u: WartUniverse): u.Traverser = {
import u.universe._

new Traverser {
override def traverse(tree: Tree): Unit = {
tree match {
// Ignore trees marked by SuppressWarnings
case t if hasWartAnnotation(u)(t) =>
case t: TermTree if futureSymbols.contains(t.tpe.typeSymbol.fullName) =>
t.tpe.typeArgs match {
case Seq(singleArg) if singleArg.typeSymbol.fullName == t.tpe.typeSymbol.fullName =>
warning(u)(tree.pos, message)
super.traverse(tree)
case _ => super.traverse(tree)
}
case _ =>
super.traverse(tree)
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.wartremover
package contrib.warts

import scala.concurrent.Future

object NestedFuture extends WartTraverser {
val message: String =
"""`Future[Future[A]]` will not wait for and discard/cancel the inner future.
|To chain the result of Future to other Future, use flatMap or a for comprehension.
|""".stripMargin

def apply(u: WartUniverse): u.Traverser = {
new u.Traverser(this) {
import q.reflect.*

override def traverseTree(tree: Tree)(owner: Symbol): Unit =
tree match {
case _ if tree.isExpr =>
tree.asExpr match {
case '{ scala.Predef.??? } => super.traverseTree(tree)(owner)
case '{
type a
$f: Future[Future[`a`]]
} =>
warning(tree.pos, message)
super.traverseTree(tree)(owner)
case _ => super.traverseTree(tree)(owner)
}
case _ => super.traverseTree(tree)(owner)
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package org.wartremover.contrib.test

import org.scalatest.funsuite.AnyFunSuite
import org.wartremover.contrib.warts.NestedFuture
import org.wartremover.test.WartTestTraverser
import scala.concurrent.Future

class NestedFutureTest extends AnyFunSuite with ResultAssertions {
implicit val ec: scala.concurrent.ExecutionContext =
scala.concurrent.ExecutionContext.global

test("single future doesn't warn") {
val result = WartTestTraverser(NestedFuture) {
val f: Future[Unit] = Future.successful(())
}
assertEmpty(result)
}

test("nested Future[Future[Unit]] warns") {
val result = WartTestTraverser(NestedFuture) {
val f: Future[Future[Unit]] = Future.successful(Future.successful(()))
}
assertWarnings(result)(NestedFuture.message, 1)
}

test("func causes nested futures") {
val result = WartTestTraverser(NestedFuture) {
val futureFunc: String => Future[String] = arg => Future.successful(arg)
futureFunc("hello world").map(futureFunc)
}

// NOTE: in scala 2 it emits 2 times, in scala 3 it emits 3 times
assertWarningAnyTimes(result)(NestedFuture.message)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ trait ResultAssertions extends Assertions {
assertResult(List.empty, "result.warnings")(result.warnings.map(skipTraverserPrefix))
}

def assertWarningAnyTimes(result: WartTestTraverser.Result)(message: String) = {
assertResult(List.empty, "result.errors")(result.errors.map(skipTraverserPrefix))
assertResult(Set(message), "result.warnings")(result.warnings.map(skipTraverserPrefix).toSet)
}

def assertWarnings(result: WartTestTraverser.Result)(message: String, times: Int) = {
assertResult(List.empty, "result.errors")(result.errors.map(skipTraverserPrefix))
assertResult(List.fill(times)(message), "result.warnings")(result.warnings.map(skipTraverserPrefix))
Expand Down

0 comments on commit b1de43e

Please sign in to comment.