Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] Parameterize context type of decorators #137

Merged
merged 10 commits into from
Nov 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.mill
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def zippedExamples = T {
build.example.cookies.millSourcePath,
build.example.decorated.millSourcePath,
build.example.decorated2.millSourcePath,
build.example.decoratedContext.millSourcePath,
build.example.endpoints.millSourcePath,
build.example.formJsonPost.millSourcePath,
build.example.httpMethods.millSourcePath,
Expand Down Expand Up @@ -143,6 +144,7 @@ def zippedExamples = T {
.replaceFirst(
"object app extends.*\ntrait AppModule extends CrossScalaModule(.*)\\{",
s"object app extends ScalaModule $$1\\{\n def scalaVersion = \"${scala213}\"")
.replaceAll("build.scala3", s"\"${scala3}\"")
.replaceFirst(
"def ivyDeps = Agg\\[Dep\\]\\(",
"def ivyDeps = Agg(\n ivy\"com.lihaoyi::cask:" + releaseTag + "\","
Expand Down
2 changes: 1 addition & 1 deletion cask/src-2/cask/main/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import language.experimental.macros

trait Routes{

def decorators = Seq.empty[cask.router.Decorator[_, _, _]]
def decorators = Seq.empty[cask.router.Decorator[_, _, _, _]]
private[this] var metadata0: RoutesEndpointsMetadata[this.type] = null
def caskMetadata =
if (metadata0 != null) metadata0
Expand Down
15 changes: 7 additions & 8 deletions cask/src-2/cask/router/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class Macros[C <: blackbox.Context](val c: C) {
def extractMethod(method: MethodSymbol,
curCls: c.universe.Type,
convertToResultType: c.Tree,
ctx: c.Tree,
argReaders: Seq[c.Tree],
annotDeserializeTypes: Seq[c.Tree]): c.universe.Tree = {
val baseArgSym = TermName(c.freshName())
Expand All @@ -64,7 +63,7 @@ class Macros[C <: blackbox.Context](val c: C) {
val ctxSymbol = q"${c.fresh[TermName](TermName("ctx"))}"
val argData = for(argListIndex <- method.paramLists.indices) yield{
val annotDeserializeType = annotDeserializeTypes.lift(argListIndex).getOrElse(tq"scala.Any")
val argReader = argReaders.lift(argListIndex).getOrElse(q"cask.router.NoOpParser.instanceAny")
val argReader = argReaders.lift(argListIndex).getOrElse(q"cask.router.NoOpParser.instanceAnyRequest")
val flattenedArgLists = method.paramss(argListIndex)
def hasDefault(i: Int) = {
// defaults are numbered globally on a class-level, this means that we
Expand Down Expand Up @@ -108,18 +107,18 @@ class Macros[C <: blackbox.Context](val c: C) {

val argSig =
q"""
cask.router.ArgSig[$annotDeserializeType, $curCls, $docUnwrappedType, $ctx](
cask.router.ArgSig[$annotDeserializeType, $curCls, $docUnwrappedType, Any](
${arg.name.toString},
${docUnwrappedType.toString},
$docTree,
$defaultOpt
)($argReader[$docUnwrappedType])
)($argReader[$docUnwrappedType].asInstanceOf[cask.router.ArgReader[$annotDeserializeType, $docUnwrappedType, Any]])
"""

val reader = q"""
cask.router.Runtime.makeReadCall(
$argValuesSymbol($argListIndex),
$ctxSymbol,
$ctxSymbol($argListIndex),
$default,
$argSigsSymbol($argListIndex)($i)
)
Expand Down Expand Up @@ -151,7 +150,7 @@ class Macros[C <: blackbox.Context](val c: C) {
for(argNameCast <- argNameCasts) methodCall = q"$methodCall(..$argNameCast)"

val res = q"""
cask.router.EntryPoint[$curCls, $ctx](
cask.router.EntryPoint[$curCls, Any](
${method.name.toString},
${argSigs.toList},
${methodDoc match{
Expand All @@ -160,9 +159,9 @@ class Macros[C <: blackbox.Context](val c: C) {
}},
(
$baseArgSym: $curCls,
$ctxSymbol: $ctx,
$ctxSymbol: Seq[_],
$argValuesSymbol: Seq[Map[String, Any]],
$argSigsSymbol: scala.Seq[scala.Seq[cask.router.ArgSig[Any, _, _, $ctx]]]
$argSigsSymbol: scala.Seq[scala.Seq[cask.router.ArgSig[Any, _, _, Any]]]
) =>
cask.router.Runtime.validate(Seq(..${readArgs.flatten.toList})).map{
case Seq(..${argNames.flatten.toList}) => $convertToResultType($methodCall)
Expand Down
7 changes: 3 additions & 4 deletions cask/src-2/cask/router/RoutesEndpointMetadata.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ object RoutesEndpointsMetadata{

val routeParts = for{
m <- c.weakTypeOf[T].members
annotations = m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[Decorator[_, _, _]])
annotations = m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[Decorator[_, _, _, _]])
if annotations.nonEmpty
} yield {
if(!(annotations.last.tree.tpe <:< weakTypeOf[Endpoint[_, _, _]])) c.abort(
if(!(annotations.last.tree.tpe <:< weakTypeOf[Endpoint[_, _, _, _]])) c.abort(
annotations.head.tree.pos,
s"Last annotation applied to a function must be an instance of Endpoint, " +
s"not ${annotations.last.tree.tpe}"
)
val allEndpoints = annotations.filter(_.tree.tpe <:< weakTypeOf[Endpoint[_, _, _]])
val allEndpoints = annotations.filter(_.tree.tpe <:< weakTypeOf[Endpoint[_, _, _, _]])
if(allEndpoints.length > 1) c.abort(
annotations.last.tree.pos,
s"You can only apply one Endpoint annotation to a function, not " +
Expand All @@ -49,7 +49,6 @@ object RoutesEndpointsMetadata{
m.asInstanceOf[MethodSymbol],
weakTypeOf[T],
q"${annotObjectSyms.last}.convertToResultType",
tq"cask.Request",
annotObjectSyms.reverse.map(annotObjectSym => q"$annotObjectSym.getParamParser"),
annotObjectSyms.reverse.map(annotObjectSym => tq"$annotObjectSym.InputTypeAlias")
)
Expand Down
2 changes: 1 addition & 1 deletion cask/src-3/cask/main/Routes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import language.experimental.macros

trait Routes{

def decorators = Seq.empty[cask.router.Decorator[_, _, _]]
def decorators = Seq.empty[cask.router.Decorator[_, _, _, _]]
private[this] var metadata0: RoutesEndpointsMetadata[this.type] = null
def caskMetadata =
if (metadata0 != null) metadata0
Expand Down
36 changes: 18 additions & 18 deletions cask/src-3/cask/router/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@ object Macros {
* This replicates EndpointMetadata.seqify, but in a macro where error
* positions can be controlled.
*/
def checkDecorators(using Quotes)(decorators: List[Expr[Decorator[_, _, _]]]): Boolean = {
def checkDecorators(using Quotes)(decorators: List[Expr[Decorator[_, _, _, _]]]): Boolean = {
import quotes.reflect._

var hasErrors = false

def check(prevOuter: TypeRepr, decorators: List[Expr[Decorator[_, _, _]]]): Unit =
def check(prevOuter: TypeRepr, decorators: List[Expr[Decorator[_, _, _, _]]]): Unit =
decorators match {
case Nil =>
case '{ $d: Decorator[outer, inner, _] } :: tail =>
case '{ $d: Decorator[outer, inner, _, _] } :: tail =>
if (TypeRepr.of[inner] <:< prevOuter) {
check(TypeRepr.of[outer], tail)
} else {
hasErrors = true
report.error(
s"required: cask.router.Decorator[_, ${prevOuter.show}, _]",
s"required: cask.router.Decorator[_, ${prevOuter.show}, _, _]",
d
)
}
Expand Down Expand Up @@ -56,7 +56,7 @@ object Macros {

/** Summon the reader for a parameter. */
def summonReader(using Quotes)(
decorator: Expr[Decorator[_,_,_]],
decorator: Expr[Decorator[_,_,_,_]],
param: quotes.reflect.Symbol
): Expr[ArgReader[_, _, _]] = {
import quotes.reflect._
Expand Down Expand Up @@ -143,13 +143,13 @@ object Macros {
*/
def convertToResponse(using Quotes)(
method: quotes.reflect.Symbol,
endpoint: Expr[Endpoint[_, _, _]],
endpoint: Expr[Endpoint[_, _, _, _]],
result: Expr[Any]
): Expr[Any] = {
import quotes.reflect._

val innerReturnedTpt = endpoint.asTerm.tpe.asType match {
case '[Endpoint[_, innerReturned, _]] => TypeRepr.of[innerReturned]
case '[Endpoint[_, innerReturned, _, _]] => TypeRepr.of[innerReturned]
case _ => ???
}

Expand Down Expand Up @@ -186,9 +186,9 @@ object Macros {

def extractMethod[Cls: Type](using q: Quotes)(
method: quotes.reflect.Symbol,
decorators: List[Expr[Decorator[_, _, _]]], // these must also include the endpoint
endpoint: Expr[Endpoint[_, _, _]]
): Expr[EntryPoint[Cls, cask.Request]] = {
decorators: List[Expr[Decorator[_, _, _, _]]], // these must also include the endpoint
endpoint: Expr[Endpoint[_, _, _, _]]
): Expr[EntryPoint[Cls, Any]] = {
import quotes.reflect._

val defaults = getDefaultParams(method)
Expand All @@ -198,7 +198,7 @@ object Macros {

// sometimes we have more params than annotated decorators, for example if
// there are global decorators
val decorator: Option[Expr[Decorator[_, _, _]]] = decorators.lift(idx)
val decorator: Option[Expr[Decorator[_, _, _, _]]] = decorators.lift(idx)

val exprs1 = for (param <- params) yield {
val paramTree = param.tree.asInstanceOf[ValDef]
Expand Down Expand Up @@ -231,35 +231,35 @@ object Macros {
case Some(deco) => summonReader(deco, param)
case None =>
decoTpe match
case '[t] => '{ NoOpParser.instanceAny[t] }
case '[t] => '{ NoOpParser.instanceAnyRequest[t] } // TODO
}

'{
ArgSig[Any, Cls, Any, cask.Request](
ArgSig[Any, Cls, Any, Any](
${Expr(param.name)},
${Expr(paramTpeName)},
doc = None, // TODO
default = ${defaultGetter}
)(using ${reader}.asInstanceOf[ArgReader[Any, Any, cask.Request]])
)(using ${reader}.asInstanceOf[ArgReader[Any, Any, Any]])
}
}
Expr.ofList(exprs1)
}
val sigExprs = Expr.ofList(exprs0)

'{
EntryPoint[Cls, cask.Request](
EntryPoint[Cls, Any](
name = ${Expr(method.name)},
argSignatures = $sigExprs,
doc = None, // TODO
invoke0 = (
clazz: Cls,
ctx: cask.Request,
ctxs: Seq[Any],
argss: Seq[Map[String, Any]],
sigss: Seq[Seq[ArgSig[Any, _, _, cask.Request]]]
sigss: Seq[Seq[ArgSig[Any, _, _, Any]]]
) => {
val parsedArgss: Seq[Seq[Either[Seq[cask.router.Result.ParamError], Any]]] =
sigss.zip(argss).map{ case (sigs, args) =>
(sigss, argss, ctxs).zipped.map { case (sigs, args, ctx) =>
sigs.map{ case sig =>
Runtime.makeReadCall(
args,
Expand Down
12 changes: 6 additions & 6 deletions cask/src-3/cask/router/RoutesEndpointMetadata.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ object RoutesEndpointsMetadata{

val routeParts: List[Expr[EndpointMetadata[T]]] = for {
m <- TypeRepr.of[T].typeSymbol.memberMethods
annotations = m.annotations.filter(_.tpe <:< TypeRepr.of[Decorator[_, _, _]])
annotations = m.annotations.filter(_.tpe <:< TypeRepr.of[Decorator[_, _, _, _]])
if (annotations.nonEmpty)
} yield {

if(!(annotations.head.tpe <:< TypeRepr.of[Endpoint[_, _, _]])) {
if(!(annotations.head.tpe <:< TypeRepr.of[Endpoint[_, _, _, _]])) {
report.error(s"Last annotation applied to a function must be an instance of Endpoint, " +
s"not ${annotations.head.tpe.show}",
annotations.head.pos
)
return '{???} // in this case, we can't continue expansion of this macro
}
val allEndpoints = annotations.filter(_.tpe <:< TypeRepr.of[Endpoint[_, _, _]])
val allEndpoints = annotations.filter(_.tpe <:< TypeRepr.of[Endpoint[_, _, _, _]])
if(allEndpoints.length > 1) {
report.error(
s"You can only apply one Endpoint annotation to a function, not " +
Expand All @@ -41,16 +41,16 @@ object RoutesEndpointsMetadata{
return '{???}
}

val decorators = annotations.map(_.asExprOf[Decorator[_, _, _]])
val decorators = annotations.map(_.asExprOf[Decorator[_, _, _, _]])

if (!Macros.checkDecorators(decorators))
return '{???} // there was a type mismatch in the decorator chain

val endpointExpr = decorators.head.asExprOf[Endpoint[_, _, _]]
val endpointExpr = decorators.head.asExprOf[Endpoint[_, _, _, _]]
val entrypointExpr = Macros.extractMethod[T](m, decorators, endpointExpr)

'{
val entrypoint: EntryPoint[T, cask.Request] = ${entrypointExpr}
val entrypoint: EntryPoint[T, Any] = ${entrypointExpr}

EndpointMetadata[T](
// the Scala 2 version and non-macro code expects decorators to be reversed
Expand Down
2 changes: 1 addition & 1 deletion cask/src/cask/decorators/compress.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class compress extends cask.RawDecorator{
.toSeq
.flatMap(_.asScala)
.flatMap(_.split(", "))
val finalResult = delegate(Map()).transform{ case v: cask.Response.Raw =>
val finalResult = delegate(ctx, Map()).transform{ case v: cask.Response.Raw =>
val (newData, newHeaders) = if (acceptEncodings.exists(_.toLowerCase == "gzip")) {
new Response.Data {
def write(out: OutputStream): Unit = {
Expand Down
1 change: 1 addition & 0 deletions cask/src/cask/endpoints/FormEndpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class postForm(val path: String, override val subpath: Boolean = false)
.createParser(ctx.exchange)
.parseBlocking()
delegate(
ctx,
formData
.iterator()
.asScala
Expand Down
4 changes: 2 additions & 2 deletions cask/src/cask/endpoints/JsonEndpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ abstract class postJsonBase(val path: String, override val subpath: Boolean = fa
} yield obj.toMap
obj match{
case Left(r) => Result.Success(r.map(Response.Data.WritableData(_)))
case Right(params) => delegate(params)
case Right(params) => delegate(ctx, params)
}
}
def wrapPathSegment(s: String): ujson.Value = ujson.Str(s)
Expand All @@ -78,7 +78,7 @@ class getJson(val path: String, override val subpath: Boolean = false)
type InputParser[T] = QueryParamReader[T]

def wrapFunction(ctx: Request, delegate: Delegate): Result[Response.Raw] = {
delegate(WebEndpoint.buildMapFromQueryParams(ctx))
delegate(ctx, WebEndpoint.buildMapFromQueryParams(ctx))
}
def wrapPathSegment(s: String) = Seq(s)
}
4 changes: 2 additions & 2 deletions cask/src/cask/endpoints/StaticEndpoints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class staticFiles(val path: String, headers: Seq[(String, String)] = Nil) extend
type InputParser[T] = QueryParamReader[T]
override def subpath = true
def wrapFunction(ctx: Request, delegate: Delegate) = {
delegate(Map()).map{t =>
delegate(ctx, Map()).map{t =>
val (path, contentTypeOpt) = StaticUtil.makePathAndContentType(t, ctx)
cask.model.StaticFile(path, headers ++ contentTypeOpt.map("Content-Type" -> _))
}
Expand All @@ -36,7 +36,7 @@ class staticResources(val path: String,
type InputParser[T] = QueryParamReader[T]
override def subpath = true
def wrapFunction(ctx: Request, delegate: Delegate) = {
delegate(Map()).map { t =>
delegate(ctx, Map()).map { t =>
val (path, contentTypeOpt) = StaticUtil.makePathAndContentType(t, ctx)
cask.model.StaticResource(path, resourceRoot, headers ++ contentTypeOpt.map("Content-Type" -> _))
}
Expand Down
2 changes: 1 addition & 1 deletion cask/src/cask/endpoints/WebEndpoints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ trait WebEndpoint extends HttpEndpoint[Response.Raw, Seq[String]]{
type InputParser[T] = QueryParamReader[T]
def wrapFunction(ctx: Request,
delegate: Delegate): Result[Response.Raw] = {
delegate(WebEndpoint.buildMapFromQueryParams(ctx))
delegate(ctx, WebEndpoint.buildMapFromQueryParams(ctx))
}
def wrapPathSegment(s: String) = Seq(s)
}
Expand Down
4 changes: 2 additions & 2 deletions cask/src/cask/endpoints/WebSocketEndpoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ object WebsocketResult{
}

class websocket(val path: String, override val subpath: Boolean = false)
extends cask.router.Endpoint[WebsocketResult, WebsocketResult, Seq[String]]{
extends cask.router.Endpoint[WebsocketResult, WebsocketResult, Seq[String], Request]{
val methods = Seq("websocket")
type InputParser[T] = QueryParamReader[T]
type OuterReturned = Result[WebsocketResult]
def wrapFunction(ctx: Request, delegate: Delegate) = {
delegate(WebEndpoint.buildMapFromQueryParams(ctx))
delegate(ctx, WebEndpoint.buildMapFromQueryParams(ctx))
}

def wrapPathSegment(s: String): Seq[String] = Seq(s)
Expand Down
5 changes: 3 additions & 2 deletions cask/src/cask/main/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class MainRoutes extends Main with Routes{
* application-wide properties.
*/
abstract class Main{
def mainDecorators: Seq[Decorator[_, _, _]] = Nil
def mainDecorators: Seq[Decorator[_, _, _, _]] = Nil
def allRoutes: Seq[Routes]
def port: Int = 8080
def host: String = "localhost"
Expand Down Expand Up @@ -74,7 +74,7 @@ abstract class Main{

object Main{
class DefaultHandler(dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]],
mainDecorators: Seq[Decorator[_, _, _]],
mainDecorators: Seq[Decorator[_, _, _, _]],
debugMode: Boolean,
handleNotFound: Request => Response.Raw,
handleMethodNotAllowed: Request => Response.Raw,
Expand Down Expand Up @@ -120,6 +120,7 @@ object Main{
routes,
routeBindings,
(mainDecorators ++ routes.decorators ++ metadata.decorators).toList,
Nil,
Nil
) match {
case Result.Success(res) => runner(res)
Expand Down
Loading
Loading