From bc70eac9ec356eb650110255abdb4119b9b4c76f Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Tue, 23 Jul 2024 15:06:44 +0200 Subject: [PATCH 1/2] Add InputContext parameter --- cask/src-2/cask/main/Routes.scala | 2 +- cask/src-2/cask/router/Macros.scala | 2 +- .../cask/router/RoutesEndpointMetadata.scala | 6 +-- cask/src-3/cask/main/Routes.scala | 2 +- cask/src-3/cask/router/Macros.scala | 20 ++++---- .../cask/router/RoutesEndpointMetadata.scala | 10 ++-- cask/src/cask/decorators/compress.scala | 2 +- cask/src/cask/endpoints/FormEndpoint.scala | 1 + cask/src/cask/endpoints/JsonEndpoint.scala | 4 +- cask/src/cask/endpoints/StaticEndpoints.scala | 4 +- cask/src/cask/endpoints/WebEndpoints.scala | 2 +- .../cask/endpoints/WebSocketEndpoint.scala | 4 +- cask/src/cask/main/Main.scala | 4 +- cask/src/cask/router/Decorators.scala | 40 ++++++++-------- cask/src/cask/router/EndpointMetadata.scala | 46 +++++++++---------- cask/test/src/test/cask/FailureTests.scala | 2 +- example/decorated/app/src/Decorated.scala | 6 +-- example/decorated2/app/src/Decorated2.scala | 4 +- example/endpoints/app/src/Endpoints.scala | 2 +- example/todo/app/src/TodoServer.scala | 2 +- example/todoDb/app/src/TodoMvcDb.scala | 2 +- 21 files changed, 85 insertions(+), 82 deletions(-) diff --git a/cask/src-2/cask/main/Routes.scala b/cask/src-2/cask/main/Routes.scala index 18aff11a74..247cd8b6ae 100644 --- a/cask/src-2/cask/main/Routes.scala +++ b/cask/src-2/cask/main/Routes.scala @@ -6,7 +6,7 @@ import language.experimental.macros trait Routes{ - def decorators = Seq.empty[cask.router.Decorator[_, _, _]] + def decorators = Seq.empty[cask.router.Decorator[_, _, _, _]] private[this] var metadata0: RoutesEndpointsMetadata[this.type] = null def caskMetadata = if (metadata0 != null) metadata0 diff --git a/cask/src-2/cask/router/Macros.scala b/cask/src-2/cask/router/Macros.scala index e265b1a377..6a8e14294a 100644 --- a/cask/src-2/cask/router/Macros.scala +++ b/cask/src-2/cask/router/Macros.scala @@ -64,7 +64,7 @@ class Macros[C <: blackbox.Context](val c: C) { val ctxSymbol = q"${c.fresh[TermName](TermName("ctx"))}" val argData = for(argListIndex <- method.paramLists.indices) yield{ val annotDeserializeType = annotDeserializeTypes.lift(argListIndex).getOrElse(tq"scala.Any") - val argReader = argReaders.lift(argListIndex).getOrElse(q"cask.router.NoOpParser.instanceAny") + val argReader = argReaders.lift(argListIndex).getOrElse(q"cask.router.NoOpParser.instanceAnyRequest") val flattenedArgLists = method.paramss(argListIndex) def hasDefault(i: Int) = { // defaults are numbered globally on a class-level, this means that we diff --git a/cask/src-2/cask/router/RoutesEndpointMetadata.scala b/cask/src-2/cask/router/RoutesEndpointMetadata.scala index 5ce25ef4f9..8169781e22 100644 --- a/cask/src-2/cask/router/RoutesEndpointMetadata.scala +++ b/cask/src-2/cask/router/RoutesEndpointMetadata.scala @@ -15,15 +15,15 @@ object RoutesEndpointsMetadata{ val routeParts = for{ m <- c.weakTypeOf[T].members - annotations = m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[Decorator[_, _, _]]) + annotations = m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[Decorator[_, _, _, _]]) if annotations.nonEmpty } yield { - if(!(annotations.last.tree.tpe <:< weakTypeOf[Endpoint[_, _, _]])) c.abort( + if(!(annotations.last.tree.tpe <:< weakTypeOf[Endpoint[_, _, _, _]])) c.abort( annotations.head.tree.pos, s"Last annotation applied to a function must be an instance of Endpoint, " + s"not ${annotations.last.tree.tpe}" ) - val allEndpoints = annotations.filter(_.tree.tpe <:< weakTypeOf[Endpoint[_, _, _]]) + val allEndpoints = annotations.filter(_.tree.tpe <:< weakTypeOf[Endpoint[_, _, _, _]]) if(allEndpoints.length > 1) c.abort( annotations.last.tree.pos, s"You can only apply one Endpoint annotation to a function, not " + diff --git a/cask/src-3/cask/main/Routes.scala b/cask/src-3/cask/main/Routes.scala index dec32d257b..bce3552176 100644 --- a/cask/src-3/cask/main/Routes.scala +++ b/cask/src-3/cask/main/Routes.scala @@ -6,7 +6,7 @@ import language.experimental.macros trait Routes{ - def decorators = Seq.empty[cask.router.Decorator[_, _, _]] + def decorators = Seq.empty[cask.router.Decorator[_, _, _, _]] private[this] var metadata0: RoutesEndpointsMetadata[this.type] = null def caskMetadata = if (metadata0 != null) metadata0 diff --git a/cask/src-3/cask/router/Macros.scala b/cask/src-3/cask/router/Macros.scala index 78561ad1bc..7ceb70b339 100644 --- a/cask/src-3/cask/router/Macros.scala +++ b/cask/src-3/cask/router/Macros.scala @@ -9,15 +9,15 @@ object Macros { * This replicates EndpointMetadata.seqify, but in a macro where error * positions can be controlled. */ - def checkDecorators(using Quotes)(decorators: List[Expr[Decorator[_, _, _]]]): Boolean = { + def checkDecorators(using Quotes)(decorators: List[Expr[Decorator[_, _, _, _]]]): Boolean = { import quotes.reflect._ var hasErrors = false - def check(prevOuter: TypeRepr, decorators: List[Expr[Decorator[_, _, _]]]): Unit = + def check(prevOuter: TypeRepr, decorators: List[Expr[Decorator[_, _, _, _]]]): Unit = decorators match { case Nil => - case '{ $d: Decorator[outer, inner, _] } :: tail => + case '{ $d: Decorator[outer, inner, _, _] } :: tail => if (TypeRepr.of[inner] <:< prevOuter) { check(TypeRepr.of[outer], tail) } else { @@ -56,7 +56,7 @@ object Macros { /** Summon the reader for a parameter. */ def summonReader(using Quotes)( - decorator: Expr[Decorator[_,_,_]], + decorator: Expr[Decorator[_,_,_,_]], param: quotes.reflect.Symbol ): Expr[ArgReader[_, _, _]] = { import quotes.reflect._ @@ -143,13 +143,13 @@ object Macros { */ def convertToResponse(using Quotes)( method: quotes.reflect.Symbol, - endpoint: Expr[Endpoint[_, _, _]], + endpoint: Expr[Endpoint[_, _, _, _]], result: Expr[Any] ): Expr[Any] = { import quotes.reflect._ val innerReturnedTpt = endpoint.asTerm.tpe.asType match { - case '[Endpoint[_, innerReturned, _]] => TypeRepr.of[innerReturned] + case '[Endpoint[_, innerReturned, _, _]] => TypeRepr.of[innerReturned] case _ => ??? } @@ -186,8 +186,8 @@ object Macros { def extractMethod[Cls: Type](using q: Quotes)( method: quotes.reflect.Symbol, - decorators: List[Expr[Decorator[_, _, _]]], // these must also include the endpoint - endpoint: Expr[Endpoint[_, _, _]] + decorators: List[Expr[Decorator[_, _, _, _]]], // these must also include the endpoint + endpoint: Expr[Endpoint[_, _, _, _]] ): Expr[EntryPoint[Cls, cask.Request]] = { import quotes.reflect._ @@ -198,7 +198,7 @@ object Macros { // sometimes we have more params than annotated decorators, for example if // there are global decorators - val decorator: Option[Expr[Decorator[_, _, _]]] = decorators.lift(idx) + val decorator: Option[Expr[Decorator[_, _, _, _]]] = decorators.lift(idx) val exprs1 = for (param <- params) yield { val paramTree = param.tree.asInstanceOf[ValDef] @@ -231,7 +231,7 @@ object Macros { case Some(deco) => summonReader(deco, param) case None => decoTpe match - case '[t] => '{ NoOpParser.instanceAny[t] } + case '[t] => '{ NoOpParser.instanceAnyRequest[t] } // TODO } '{ diff --git a/cask/src-3/cask/router/RoutesEndpointMetadata.scala b/cask/src-3/cask/router/RoutesEndpointMetadata.scala index 794f539bb5..b676b9e436 100644 --- a/cask/src-3/cask/router/RoutesEndpointMetadata.scala +++ b/cask/src-3/cask/router/RoutesEndpointMetadata.scala @@ -20,18 +20,18 @@ object RoutesEndpointsMetadata{ val routeParts: List[Expr[EndpointMetadata[T]]] = for { m <- TypeRepr.of[T].typeSymbol.memberMethods - annotations = m.annotations.filter(_.tpe <:< TypeRepr.of[Decorator[_, _, _]]) + annotations = m.annotations.filter(_.tpe <:< TypeRepr.of[Decorator[_, _, _, _]]) if (annotations.nonEmpty) } yield { - if(!(annotations.head.tpe <:< TypeRepr.of[Endpoint[_, _, _]])) { + if(!(annotations.head.tpe <:< TypeRepr.of[Endpoint[_, _, _, _]])) { report.error(s"Last annotation applied to a function must be an instance of Endpoint, " + s"not ${annotations.head.tpe.show}", annotations.head.pos ) return '{???} // in this case, we can't continue expansion of this macro } - val allEndpoints = annotations.filter(_.tpe <:< TypeRepr.of[Endpoint[_, _, _]]) + val allEndpoints = annotations.filter(_.tpe <:< TypeRepr.of[Endpoint[_, _, _, _]]) if(allEndpoints.length > 1) { report.error( s"You can only apply one Endpoint annotation to a function, not " + @@ -41,12 +41,12 @@ object RoutesEndpointsMetadata{ return '{???} } - val decorators = annotations.map(_.asExprOf[Decorator[_, _, _]]) + val decorators = annotations.map(_.asExprOf[Decorator[_, _, _, _]]) if (!Macros.checkDecorators(decorators)) return '{???} // there was a type mismatch in the decorator chain - val endpointExpr = decorators.head.asExprOf[Endpoint[_, _, _]] + val endpointExpr = decorators.head.asExprOf[Endpoint[_, _, _, _]] val entrypointExpr = Macros.extractMethod[T](m, decorators, endpointExpr) '{ diff --git a/cask/src/cask/decorators/compress.scala b/cask/src/cask/decorators/compress.scala index 71f4bb4b77..3a28dd020f 100644 --- a/cask/src/cask/decorators/compress.scala +++ b/cask/src/cask/decorators/compress.scala @@ -11,7 +11,7 @@ class compress extends cask.RawDecorator{ .toSeq .flatMap(_.asScala) .flatMap(_.split(", ")) - val finalResult = delegate(Map()).transform{ case v: cask.Response.Raw => + val finalResult = delegate(ctx, Map()).transform{ case v: cask.Response.Raw => val (newData, newHeaders) = if (acceptEncodings.exists(_.toLowerCase == "gzip")) { new Response.Data { def write(out: OutputStream): Unit = { diff --git a/cask/src/cask/endpoints/FormEndpoint.scala b/cask/src/cask/endpoints/FormEndpoint.scala index 956e09130b..a5e36f17ee 100644 --- a/cask/src/cask/endpoints/FormEndpoint.scala +++ b/cask/src/cask/endpoints/FormEndpoint.scala @@ -57,6 +57,7 @@ class postForm(val path: String, override val subpath: Boolean = false) try { val formData = FormParserFactory.builder().build().createParser(ctx.exchange).parseBlocking() delegate( + ctx, formData .iterator() .asScala diff --git a/cask/src/cask/endpoints/JsonEndpoint.scala b/cask/src/cask/endpoints/JsonEndpoint.scala index 50591a9007..72f8277cb4 100644 --- a/cask/src/cask/endpoints/JsonEndpoint.scala +++ b/cask/src/cask/endpoints/JsonEndpoint.scala @@ -66,7 +66,7 @@ abstract class postJsonBase(val path: String, override val subpath: Boolean = fa } yield obj.toMap obj match{ case Left(r) => Result.Success(r.map(Response.Data.WritableData(_))) - case Right(params) => delegate(params) + case Right(params) => delegate(ctx, params) } } def wrapPathSegment(s: String): ujson.Value = ujson.Str(s) @@ -78,7 +78,7 @@ class getJson(val path: String, override val subpath: Boolean = false) type InputParser[T] = QueryParamReader[T] def wrapFunction(ctx: Request, delegate: Delegate): Result[Response.Raw] = { - delegate(WebEndpoint.buildMapFromQueryParams(ctx)) + delegate(ctx, WebEndpoint.buildMapFromQueryParams(ctx)) } def wrapPathSegment(s: String) = Seq(s) } diff --git a/cask/src/cask/endpoints/StaticEndpoints.scala b/cask/src/cask/endpoints/StaticEndpoints.scala index 3482451cb6..68681dc49d 100644 --- a/cask/src/cask/endpoints/StaticEndpoints.scala +++ b/cask/src/cask/endpoints/StaticEndpoints.scala @@ -19,7 +19,7 @@ class staticFiles(val path: String, headers: Seq[(String, String)] = Nil) extend type InputParser[T] = QueryParamReader[T] override def subpath = true def wrapFunction(ctx: Request, delegate: Delegate) = { - delegate(Map()).map{t => + delegate(ctx, Map()).map{t => val (path, contentTypeOpt) = StaticUtil.makePathAndContentType(t, ctx) cask.model.StaticFile(path, headers ++ contentTypeOpt.map("Content-Type" -> _)) } @@ -36,7 +36,7 @@ class staticResources(val path: String, type InputParser[T] = QueryParamReader[T] override def subpath = true def wrapFunction(ctx: Request, delegate: Delegate) = { - delegate(Map()).map { t => + delegate(ctx, Map()).map { t => val (path, contentTypeOpt) = StaticUtil.makePathAndContentType(t, ctx) cask.model.StaticResource(path, resourceRoot, headers ++ contentTypeOpt.map("Content-Type" -> _)) } diff --git a/cask/src/cask/endpoints/WebEndpoints.scala b/cask/src/cask/endpoints/WebEndpoints.scala index b2b9731a62..c6a61f2db1 100644 --- a/cask/src/cask/endpoints/WebEndpoints.scala +++ b/cask/src/cask/endpoints/WebEndpoints.scala @@ -11,7 +11,7 @@ trait WebEndpoint extends HttpEndpoint[Response.Raw, Seq[String]]{ type InputParser[T] = QueryParamReader[T] def wrapFunction(ctx: Request, delegate: Delegate): Result[Response.Raw] = { - delegate(WebEndpoint.buildMapFromQueryParams(ctx)) + delegate(ctx, WebEndpoint.buildMapFromQueryParams(ctx)) } def wrapPathSegment(s: String) = Seq(s) } diff --git a/cask/src/cask/endpoints/WebSocketEndpoint.scala b/cask/src/cask/endpoints/WebSocketEndpoint.scala index ea5bbd6a41..a5032de7dc 100644 --- a/cask/src/cask/endpoints/WebSocketEndpoint.scala +++ b/cask/src/cask/endpoints/WebSocketEndpoint.scala @@ -21,12 +21,12 @@ object WebsocketResult{ } class websocket(val path: String, override val subpath: Boolean = false) - extends cask.router.Endpoint[WebsocketResult, WebsocketResult, Seq[String]]{ + extends cask.router.Endpoint[WebsocketResult, WebsocketResult, Seq[String], Request]{ val methods = Seq("websocket") type InputParser[T] = QueryParamReader[T] type OuterReturned = Result[WebsocketResult] def wrapFunction(ctx: Request, delegate: Delegate) = { - delegate(WebEndpoint.buildMapFromQueryParams(ctx)) + delegate(ctx, WebEndpoint.buildMapFromQueryParams(ctx)) } def wrapPathSegment(s: String): Seq[String] = Seq(s) diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index 15c365666b..b71b6b561d 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -29,7 +29,7 @@ class MainRoutes extends Main with Routes{ * application-wide properties. */ abstract class Main{ - def mainDecorators: Seq[Decorator[_, _, _]] = Nil + def mainDecorators: Seq[Decorator[_, _, _, _]] = Nil def allRoutes: Seq[Routes] def port: Int = 8080 def host: String = "localhost" @@ -74,7 +74,7 @@ abstract class Main{ object Main{ class DefaultHandler(dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]], - mainDecorators: Seq[Decorator[_, _, _]], + mainDecorators: Seq[Decorator[_, _, _, _]], debugMode: Boolean, handleNotFound: Request => Response.Raw, handleMethodNotAllowed: Request => Response.Raw, diff --git a/cask/src/cask/router/Decorators.scala b/cask/src/cask/router/Decorators.scala index 20824bc21c..a04d030776 100644 --- a/cask/src/cask/router/Decorators.scala +++ b/cask/src/cask/router/Decorators.scala @@ -2,6 +2,7 @@ package cask.router import cask.internal.Conversion import cask.model.{Request, Response} +import java.awt.im.InputContext /** * A [[Decorator]] allows you to annotate a function to wrap it, via @@ -14,10 +15,10 @@ import cask.model.{Request, Response} * to `wrapFunction`, which takes a `Map` representing any additional argument * lists (if any). */ -trait Decorator[OuterReturned, InnerReturned, Input] extends scala.annotation.Annotation { +trait Decorator[OuterReturned, InnerReturned, Input, InputContext] extends scala.annotation.Annotation { final type InputTypeAlias = Input - type InputParser[T] <: ArgReader[Input, T, Request] - final type Delegate = Map[String, Input] => Result[InnerReturned] + type InputParser[T] <: ArgReader[Input, T, InputContext] + final type Delegate = (InputContext, Map[String, Input]) => Result[InnerReturned] def wrapFunction(ctx: Request, delegate: Delegate): Result[OuterReturned] def getParamParser[T](implicit p: InputParser[T]) = p } @@ -34,28 +35,28 @@ object Decorator{ * used as the first argument list. */ def invoke[T](ctx: Request, - endpoint: Endpoint[_, _, _], + endpoint: Endpoint[_, _, _, _], entryPoint: EntryPoint[T, _], routes: T, routeBindings: Map[String, String], - remainingDecorators: List[Decorator[_, _, _]], + remainingDecorators: List[Decorator[_, _, _, _]], bindings: List[Map[String, Any]]): Result[Any] = try { remainingDecorators match { case head :: rest => - head.asInstanceOf[Decorator[Any, Any, Any]].wrapFunction( + head.asInstanceOf[Decorator[Any, Any, Any, Any]].wrapFunction( ctx, - args => invoke(ctx, endpoint, entryPoint, routes, routeBindings, rest, args :: bindings) + (_, args) => invoke(ctx, endpoint, entryPoint, routes, routeBindings, rest, args :: bindings) .asInstanceOf[Result[Nothing]] ) case Nil => - endpoint.wrapFunction(ctx, { (endpointBindings: Map[String, Any]) => + endpoint.wrapFunction(ctx, { (ictx: Any, endpointBindings: Map[String, Any]) => val mergedEndpointBindings = endpointBindings ++ routeBindings.mapValues(endpoint.wrapPathSegment) val finalBindings = mergedEndpointBindings :: bindings entryPoint - .asInstanceOf[EntryPoint[T, cask.model.Request]] - .invoke(routes, ctx, finalBindings) + .asInstanceOf[EntryPoint[T, Any]] + .invoke(routes, ictx, finalBindings) .asInstanceOf[Result[Nothing]] }) } @@ -69,8 +70,8 @@ object Decorator{ * A [[RawDecorator]] is a decorator that operates on the raw request and * response stream, before and after the primary [[Endpoint]] does it's job. */ -trait RawDecorator extends Decorator[Response.Raw, Response.Raw, Any]{ - type InputParser[T] = NoOpParser[Any, T] +trait RawDecorator extends Decorator[Response.Raw, Response.Raw, Any, Request]{ + type InputParser[T] = NoOpParser[Any, T, Request] } @@ -78,8 +79,8 @@ trait RawDecorator extends Decorator[Response.Raw, Response.Raw, Any]{ * An [[HttpEndpoint]] that may return something else than a HTTP response, e.g. * a websocket endpoint which may instead return a websocket event handler */ -trait Endpoint[OuterReturned, InnerReturned, Input] - extends Decorator[OuterReturned, InnerReturned, Input]{ +trait Endpoint[OuterReturned, InnerReturned, Input, InputContext] + extends Decorator[OuterReturned, InnerReturned, Input, InputContext]{ /** * What is the path that this particular endpoint matches? @@ -119,15 +120,16 @@ trait Endpoint[OuterReturned, InnerReturned, Input] * Annotates a Cask endpoint that returns a HTTP [[Response]]; similar to a * [[RawDecorator]] but with additional metadata and capabilities. */ -trait HttpEndpoint[InnerReturned, Input] extends Endpoint[Response.Raw, InnerReturned, Input] +trait HttpEndpoint[InnerReturned, Input] extends Endpoint[Response.Raw, InnerReturned, Input, cask.Request] -class NoOpParser[Input, T] extends ArgReader[Input, T, Request] { +class NoOpParser[Input, T, InputContext] extends ArgReader[Input, T, InputContext] { def arity = 1 - def read(ctx: Request, label: String, input: Input) = input.asInstanceOf[T] + def read(ctx: InputContext, label: String, input: Input) = input.asInstanceOf[T] } object NoOpParser{ - implicit def instance[Input, T]: NoOpParser[Input, T] = new NoOpParser[Input, T] - implicit def instanceAny[T]: NoOpParser[Any, T] = new NoOpParser[Any, T] + implicit def instance[Input, T, InputContext]: NoOpParser[Input, T, InputContext] = new NoOpParser[Input, T, InputContext] + implicit def instanceAny[T, InputContext]: NoOpParser[Any, T, InputContext] = new NoOpParser[Any, T, InputContext] + implicit def instanceAnyRequest[T]: NoOpParser[Any, T, Request] = new NoOpParser[Any, T, Request] } diff --git a/cask/src/cask/router/EndpointMetadata.scala b/cask/src/cask/router/EndpointMetadata.scala index cdc12207ff..650b7a91aa 100644 --- a/cask/src/cask/router/EndpointMetadata.scala +++ b/cask/src/cask/router/EndpointMetadata.scala @@ -1,7 +1,7 @@ package cask.router -case class EndpointMetadata[T](decorators: Seq[Decorator[_, _, _]], - endpoint: Endpoint[_, _, _], +case class EndpointMetadata[T](decorators: Seq[Decorator[_, _, _, _]], + endpoint: Endpoint[_, _, _, _], entryPoint: EntryPoint[T, _]) object EndpointMetadata{ // `seqify` is used to statically check that the decorators applied to each @@ -10,30 +10,30 @@ object EndpointMetadata{ // checking decorators defined as part of cask.Main or cask.Routes, since those // are both more dynamic (and hard to check) and also less often used and thus // less error prone - def seqify1(d: Decorator[_, _, _]) = Seq(d) + def seqify1(d: Decorator[_, _, _, _]) = Seq(d) def seqify2[T1] - (d1: Decorator[T1, _, _]) - (d2: Decorator[_, T1, _]) = Seq(d1, d2) + (d1: Decorator[T1, _, _, _]) + (d2: Decorator[_, T1, _, _]) = Seq(d1, d2) def seqify3[T1, T2] - (d1: Decorator[T1, _, _]) - (d2: Decorator[T2, T1, _]) - (d3: Decorator[_, T2, _]) = Seq(d1, d2, d3) + (d1: Decorator[T1, _, _, _]) + (d2: Decorator[T2, T1, _, _]) + (d3: Decorator[_, T2, _, _]) = Seq(d1, d2, d3) def seqify4[T1, T2, T3] - (d1: Decorator[T1, _, _]) - (d2: Decorator[T2, T1, _]) - (d3: Decorator[T3, T2, _]) - (d4: Decorator[_, T3, _]) = Seq(d1, d2, d3, d4) + (d1: Decorator[T1, _, _, _]) + (d2: Decorator[T2, T1, _, _]) + (d3: Decorator[T3, T2, _, _]) + (d4: Decorator[_, T3, _, _]) = Seq(d1, d2, d3, d4) def seqify5[T1, T2, T3, T4] - (d1: Decorator[T1, _, _]) - (d2: Decorator[T2, T1, _]) - (d3: Decorator[T3, T2, _]) - (d4: Decorator[T4, T3, _]) - (d5: Decorator[_, T4, _]) = Seq(d1, d2, d3, d4, d5) + (d1: Decorator[T1, _, _, _]) + (d2: Decorator[T2, T1, _, _]) + (d3: Decorator[T3, T2, _, _]) + (d4: Decorator[T4, T3, _, _]) + (d5: Decorator[_, T4, _, _]) = Seq(d1, d2, d3, d4, d5) def seqify6[T1, T2, T3, T4, T5] - (d1: Decorator[T1, _, _]) - (d2: Decorator[T2, T1, _]) - (d3: Decorator[T3, T2, _]) - (d4: Decorator[T4, T3, _]) - (d5: Decorator[T5, T4, _]) - (d6: Decorator[_, T5, _]) = Seq(d1, d2, d3, d4) + (d1: Decorator[T1, _, _, _]) + (d2: Decorator[T2, T1, _, _]) + (d3: Decorator[T3, T2, _, _]) + (d4: Decorator[T4, T3, _, _]) + (d5: Decorator[T5, T4, _, _]) + (d6: Decorator[_, T5, _, _]) = Seq(d1, d2, d3, d4) } diff --git a/cask/test/src/test/cask/FailureTests.scala b/cask/test/src/test/cask/FailureTests.scala index 94dd9846d2..3f78e5e11c 100644 --- a/cask/test/src/test/cask/FailureTests.scala +++ b/cask/test/src/test/cask/FailureTests.scala @@ -6,7 +6,7 @@ import utest._ object FailureTests extends TestSuite { class myDecorator extends cask.RawDecorator { def wrapFunction(ctx: Request, delegate: Delegate) = { - delegate(Map("extra" -> 31337)) + delegate(ctx, Map("extra" -> 31337)) } } diff --git a/example/decorated/app/src/Decorated.scala b/example/decorated/app/src/Decorated.scala index 5347629bb0..3a964e9aa6 100644 --- a/example/decorated/app/src/Decorated.scala +++ b/example/decorated/app/src/Decorated.scala @@ -5,19 +5,19 @@ object Decorated extends cask.MainRoutes { } class loggedIn extends cask.RawDecorator { def wrapFunction(ctx: cask.Request, delegate: Delegate) = { - delegate(Map("user" -> new User())) + delegate(ctx, Map("user" -> new User())) } } class withExtra extends cask.RawDecorator { def wrapFunction(ctx: cask.Request, delegate: Delegate) = { - delegate(Map("extra" -> 31337)) + delegate(ctx, Map("extra" -> 31337)) } } class withCustomHeader extends cask.RawDecorator { def wrapFunction(request: cask.Request, delegate: Delegate) = { request.headers.get("x-custom-header").map(_.head) match { - case Some(header) => delegate(Map("customHeader" -> header)) + case Some(header) => delegate(request, Map("customHeader" -> header)) case None => cask.router.Result.Success( cask.model.Response( diff --git a/example/decorated2/app/src/Decorated2.scala b/example/decorated2/app/src/Decorated2.scala index 4a3b9ccc57..90a6a3b195 100644 --- a/example/decorated2/app/src/Decorated2.scala +++ b/example/decorated2/app/src/Decorated2.scala @@ -5,12 +5,12 @@ object Decorated2 extends cask.MainRoutes{ } class loggedIn extends cask.RawDecorator { def wrapFunction(ctx: cask.Request, delegate: Delegate) = { - delegate(Map("user" -> new User())) + delegate(ctx, Map("user" -> new User())) } } class withExtra extends cask.RawDecorator { def wrapFunction(ctx: cask.Request, delegate: Delegate) = { - delegate(Map("extra" -> 31337)) + delegate(ctx, Map("extra" -> 31337)) } } diff --git a/example/endpoints/app/src/Endpoints.scala b/example/endpoints/app/src/Endpoints.scala index f59983c68e..a71b5bc2e1 100644 --- a/example/endpoints/app/src/Endpoints.scala +++ b/example/endpoints/app/src/Endpoints.scala @@ -3,7 +3,7 @@ package app class custom(val path: String, val methods: Seq[String]) extends cask.HttpEndpoint[Int, Seq[String]]{ def wrapFunction(ctx: cask.Request, delegate: Delegate) = { - delegate(Map()).map{num => + delegate(ctx, Map()).map{num => cask.Response("Echo " + num, statusCode = num) } } diff --git a/example/todo/app/src/TodoServer.scala b/example/todo/app/src/TodoServer.scala index 6e80cd9650..44081229c1 100644 --- a/example/todo/app/src/TodoServer.scala +++ b/example/todo/app/src/TodoServer.scala @@ -18,7 +18,7 @@ object TodoServer extends cask.MainRoutes{ class transactional extends cask.RawDecorator{ def wrapFunction(pctx: cask.Request, delegate: Delegate) = { sqliteClient.transaction { txn => - val res = delegate(Map("txn" -> txn)) + val res = delegate(ctx, Map("txn" -> txn)) if (res.isInstanceOf[cask.router.Result.Error]) txn.rollback() res } diff --git a/example/todoDb/app/src/TodoMvcDb.scala b/example/todoDb/app/src/TodoMvcDb.scala index eb3245f883..0b2c72d6bf 100644 --- a/example/todoDb/app/src/TodoMvcDb.scala +++ b/example/todoDb/app/src/TodoMvcDb.scala @@ -15,7 +15,7 @@ object TodoMvcDb extends cask.MainRoutes{ class transactional extends cask.RawDecorator{ def wrapFunction(pctx: cask.Request, delegate: Delegate) = { sqliteClient.transaction { txn => - val res = delegate(Map("txn" -> txn)) + val res = delegate(ctx, Map("txn" -> txn)) if (res.isInstanceOf[cask.router.Result.Error]) txn.rollback() res } From 6cf9c8c4fab38575d8183fce0c2c37bc3fde6008 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Tue, 23 Jul 2024 16:35:24 +0200 Subject: [PATCH 2/2] Add support for passing custom contexts --- build.sc | 4 ++ cask/src-3/cask/router/Macros.scala | 14 ++--- .../cask/router/RoutesEndpointMetadata.scala | 2 +- cask/src/cask/main/Main.scala | 1 + cask/src/cask/router/Decorators.scala | 6 ++- cask/src/cask/router/EntryPoint.scala | 6 +-- .../app/src/DecoratedContext.scala | 54 +++++++++++++++++++ .../app/test/src/ExampleTests.scala | 26 +++++++++ example/decoratedContext/build.sc | 14 +++++ 9 files changed, 114 insertions(+), 13 deletions(-) create mode 100644 example/decoratedContext/app/src/DecoratedContext.scala create mode 100644 example/decoratedContext/app/test/src/ExampleTests.scala create mode 100644 example/decoratedContext/build.sc diff --git a/build.sc b/build.sc index c836fcbdf8..cd845d8a31 100644 --- a/build.sc +++ b/build.sc @@ -7,6 +7,7 @@ import $file.example.compress3.build import $file.example.cookies.build import $file.example.decorated.build import $file.example.decorated2.build +import $file.example.decoratedContext.build import $file.example.endpoints.build import $file.example.formJsonPost.build import $file.example.httpMethods.build @@ -131,6 +132,9 @@ object example extends Module{ trait Decorated2Module extends millbuild.example.decorated2.build.AppModule with LocalModule object decorated2 extends Cross[Decorated2Module](scalaVersions) + trait DecoratedContextModule extends millbuild.example.decoratedContext.build.AppModule with LocalModule + object decoratedContext extends Cross[DecoratedContextModule](scalaVersions) + trait EndpointsModule extends millbuild.example.endpoints.build.AppModule with LocalModule object endpoints extends Cross[EndpointsModule](scalaVersions) diff --git a/cask/src-3/cask/router/Macros.scala b/cask/src-3/cask/router/Macros.scala index 7ceb70b339..4d8429f883 100644 --- a/cask/src-3/cask/router/Macros.scala +++ b/cask/src-3/cask/router/Macros.scala @@ -188,7 +188,7 @@ object Macros { method: quotes.reflect.Symbol, decorators: List[Expr[Decorator[_, _, _, _]]], // these must also include the endpoint endpoint: Expr[Endpoint[_, _, _, _]] - ): Expr[EntryPoint[Cls, cask.Request]] = { + ): Expr[EntryPoint[Cls, Any]] = { import quotes.reflect._ val defaults = getDefaultParams(method) @@ -235,12 +235,12 @@ object Macros { } '{ - ArgSig[Any, Cls, Any, cask.Request]( + ArgSig[Any, Cls, Any, Any]( ${Expr(param.name)}, ${Expr(paramTpeName)}, doc = None, // TODO default = ${defaultGetter} - )(using ${reader}.asInstanceOf[ArgReader[Any, Any, cask.Request]]) + )(using ${reader}.asInstanceOf[ArgReader[Any, Any, Any]]) } } Expr.ofList(exprs1) @@ -248,18 +248,18 @@ object Macros { val sigExprs = Expr.ofList(exprs0) '{ - EntryPoint[Cls, cask.Request]( + EntryPoint[Cls, Any]( name = ${Expr(method.name)}, argSignatures = $sigExprs, doc = None, // TODO invoke0 = ( clazz: Cls, - ctx: cask.Request, + ctxs: Seq[Any], argss: Seq[Map[String, Any]], - sigss: Seq[Seq[ArgSig[Any, _, _, cask.Request]]] + sigss: Seq[Seq[ArgSig[Any, _, _, Any]]] ) => { val parsedArgss: Seq[Seq[Either[Seq[cask.router.Result.ParamError], Any]]] = - sigss.zip(argss).map{ case (sigs, args) => + (sigss, argss, ctxs).zipped.map { case (sigs, args, ctx) => sigs.map{ case sig => Runtime.makeReadCall( args, diff --git a/cask/src-3/cask/router/RoutesEndpointMetadata.scala b/cask/src-3/cask/router/RoutesEndpointMetadata.scala index b676b9e436..188e4e24ab 100644 --- a/cask/src-3/cask/router/RoutesEndpointMetadata.scala +++ b/cask/src-3/cask/router/RoutesEndpointMetadata.scala @@ -50,7 +50,7 @@ object RoutesEndpointsMetadata{ val entrypointExpr = Macros.extractMethod[T](m, decorators, endpointExpr) '{ - val entrypoint: EntryPoint[T, cask.Request] = ${entrypointExpr} + val entrypoint: EntryPoint[T, Any] = ${entrypointExpr} EndpointMetadata[T]( // the Scala 2 version and non-macro code expects decorators to be reversed diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index b71b6b561d..e1986d77ef 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -120,6 +120,7 @@ object Main{ routes, routeBindings, (mainDecorators ++ routes.decorators ++ metadata.decorators).toList, + Nil, Nil ) match { case Result.Success(res) => runner(res) diff --git a/cask/src/cask/router/Decorators.scala b/cask/src/cask/router/Decorators.scala index a04d030776..9a596613a2 100644 --- a/cask/src/cask/router/Decorators.scala +++ b/cask/src/cask/router/Decorators.scala @@ -40,23 +40,25 @@ object Decorator{ routes: T, routeBindings: Map[String, String], remainingDecorators: List[Decorator[_, _, _, _]], + inputContexts: List[Any], bindings: List[Map[String, Any]]): Result[Any] = try { remainingDecorators match { case head :: rest => head.asInstanceOf[Decorator[Any, Any, Any, Any]].wrapFunction( ctx, - (_, args) => invoke(ctx, endpoint, entryPoint, routes, routeBindings, rest, args :: bindings) + (ictx, args) => invoke(ctx, endpoint, entryPoint, routes, routeBindings, rest, ictx :: inputContexts, args :: bindings) .asInstanceOf[Result[Nothing]] ) case Nil => endpoint.wrapFunction(ctx, { (ictx: Any, endpointBindings: Map[String, Any]) => + val mergedEndpointBindings = endpointBindings ++ routeBindings.mapValues(endpoint.wrapPathSegment) val finalBindings = mergedEndpointBindings :: bindings entryPoint .asInstanceOf[EntryPoint[T, Any]] - .invoke(routes, ictx, finalBindings) + .invoke(routes, ictx :: inputContexts, finalBindings) .asInstanceOf[Result[Nothing]] }) } diff --git a/cask/src/cask/router/EntryPoint.scala b/cask/src/cask/router/EntryPoint.scala index 2d2b7e5d62..d63575150b 100644 --- a/cask/src/cask/router/EntryPoint.scala +++ b/cask/src/cask/router/EntryPoint.scala @@ -16,14 +16,14 @@ import scala.collection.mutable case class EntryPoint[T, C](name: String, argSignatures: Seq[Seq[ArgSig[_, T, _, C]]], doc: Option[String], - invoke0: (T, C, Seq[Map[String, Any]], Seq[Seq[ArgSig[Any, _, _, C]]]) => Result[Any]){ + invoke0: (T, Seq[C], Seq[Map[String, Any]], Seq[Seq[ArgSig[Any, _, _, C]]]) => Result[Any]){ val firstArgs = argSignatures.head .map(x => x.name -> x) .toMap[String, ArgSig[_, T, _, C]] def invoke(target: T, - ctx: C, + ctxs: Seq[C], paramLists: Seq[Map[String, Any]]): Result[Any] = { val missing = mutable.Buffer.empty[ArgSig[_, T, _, C]] @@ -42,7 +42,7 @@ case class EntryPoint[T, C](name: String, } else { try invoke0( target, - ctx, + ctxs, paramLists, argSignatures.asInstanceOf[Seq[Seq[ArgSig[Any, _, _, C]]]] ) diff --git a/example/decoratedContext/app/src/DecoratedContext.scala b/example/decoratedContext/app/src/DecoratedContext.scala new file mode 100644 index 0000000000..2ac006eaa2 --- /dev/null +++ b/example/decoratedContext/app/src/DecoratedContext.scala @@ -0,0 +1,54 @@ +package app + +case class Context( + session: Session +) + +case class Session(data: collection.mutable.Map[String, String]) + +trait CustomParser[T] extends cask.router.ArgReader[Any, T, Context] +object CustomParser: + given CustomParser[Context] with + def arity = 0 + def read(ctx: Context, label: String, input: Any): Context = ctx + given CustomParser[Session] with + def arity = 0 + def read(ctx: Context, label: String, input: Any): Session = ctx.session + given literal[Literal]: CustomParser[Literal] with + def arity = 1 + def read(ctx: Context, label: String, input: Any): Literal = input.asInstanceOf[Literal] + +object DecoratedContext extends cask.MainRoutes{ + + class custom extends cask.router.Decorator[cask.Response.Raw, cask.Response.Raw, Any, Context]{ + + override type InputParser[T] = CustomParser[T] + + def wrapFunction(req: cask.Request, delegate: Delegate) = { + // Create a custom context out of the request. Custom contexts are useful + // to group an expensive operation that may be used by multiple + // parameter readers. + val ctx = Context(Session(collection.mutable.Map.empty)) // this would typically be populated from a signed cookie + + delegate(ctx, Map("user" -> 1337)).map{ response => + val extraCookies = ctx.session.data.map( + (k, v) => cask.Cookie(k, v) + ) + + response.copy( + cookies = response.cookies ++ extraCookies + ) + } + + } + } + + @custom() + @cask.get("/hello/:world") + def hello(world: String, req: cask.Request)(session: Session, user: Int) = { + session.data("hello") = "world" + world + user + } + + initialize() +} diff --git a/example/decoratedContext/app/test/src/ExampleTests.scala b/example/decoratedContext/app/test/src/ExampleTests.scala new file mode 100644 index 0000000000..630f1ceb77 --- /dev/null +++ b/example/decoratedContext/app/test/src/ExampleTests.scala @@ -0,0 +1,26 @@ +package app +import io.undertow.Undertow + +import utest._ + +object ExampleTests extends TestSuite{ + def withServer[T](example: cask.main.Main)(f: String => T): T = { + val server = Undertow.builder + .addHttpListener(8081, "localhost") + .setHandler(example.defaultHandler) + .build + server.start() + val res = + try f("http://localhost:8081") + finally server.stop() + res + } + + val tests = Tests{ + test("DecoratedContext") - withServer(DecoratedContext){ host => + val response = requests.get(s"$host/hello/woo") + response.text() ==> "woo1337" + response.cookies("hello").getValue ==> "world" + } + } +} diff --git a/example/decoratedContext/build.sc b/example/decoratedContext/build.sc new file mode 100644 index 0000000000..75de91ef2f --- /dev/null +++ b/example/decoratedContext/build.sc @@ -0,0 +1,14 @@ +import mill._, scalalib._ + +trait AppModule extends CrossScalaModule{ + + def ivyDeps = Agg[Dep]( + ) + object test extends ScalaTests with TestModule.Utest{ + + def ivyDeps = Agg( + ivy"com.lihaoyi::utest::0.8.1", + ivy"com.lihaoyi::requests::0.8.0", + ) + } +}