Skip to content

Commit

Permalink
Add support for passing custom contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
jodersky committed Jul 23, 2024
1 parent bc70eac commit 6cf9c8c
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 13 deletions.
4 changes: 4 additions & 0 deletions build.sc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import $file.example.compress3.build
import $file.example.cookies.build
import $file.example.decorated.build
import $file.example.decorated2.build
import $file.example.decoratedContext.build
import $file.example.endpoints.build
import $file.example.formJsonPost.build
import $file.example.httpMethods.build
Expand Down Expand Up @@ -131,6 +132,9 @@ object example extends Module{
trait Decorated2Module extends millbuild.example.decorated2.build.AppModule with LocalModule
object decorated2 extends Cross[Decorated2Module](scalaVersions)

trait DecoratedContextModule extends millbuild.example.decoratedContext.build.AppModule with LocalModule
object decoratedContext extends Cross[DecoratedContextModule](scalaVersions)

trait EndpointsModule extends millbuild.example.endpoints.build.AppModule with LocalModule
object endpoints extends Cross[EndpointsModule](scalaVersions)

Expand Down
14 changes: 7 additions & 7 deletions cask/src-3/cask/router/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ object Macros {
method: quotes.reflect.Symbol,
decorators: List[Expr[Decorator[_, _, _, _]]], // these must also include the endpoint
endpoint: Expr[Endpoint[_, _, _, _]]
): Expr[EntryPoint[Cls, cask.Request]] = {
): Expr[EntryPoint[Cls, Any]] = {
import quotes.reflect._

val defaults = getDefaultParams(method)
Expand Down Expand Up @@ -235,31 +235,31 @@ object Macros {
}

'{
ArgSig[Any, Cls, Any, cask.Request](
ArgSig[Any, Cls, Any, Any](
${Expr(param.name)},
${Expr(paramTpeName)},
doc = None, // TODO
default = ${defaultGetter}
)(using ${reader}.asInstanceOf[ArgReader[Any, Any, cask.Request]])
)(using ${reader}.asInstanceOf[ArgReader[Any, Any, Any]])
}
}
Expr.ofList(exprs1)
}
val sigExprs = Expr.ofList(exprs0)

'{
EntryPoint[Cls, cask.Request](
EntryPoint[Cls, Any](
name = ${Expr(method.name)},
argSignatures = $sigExprs,
doc = None, // TODO
invoke0 = (
clazz: Cls,
ctx: cask.Request,
ctxs: Seq[Any],
argss: Seq[Map[String, Any]],
sigss: Seq[Seq[ArgSig[Any, _, _, cask.Request]]]
sigss: Seq[Seq[ArgSig[Any, _, _, Any]]]
) => {
val parsedArgss: Seq[Seq[Either[Seq[cask.router.Result.ParamError], Any]]] =
sigss.zip(argss).map{ case (sigs, args) =>
(sigss, argss, ctxs).zipped.map { case (sigs, args, ctx) =>
sigs.map{ case sig =>
Runtime.makeReadCall(
args,
Expand Down
2 changes: 1 addition & 1 deletion cask/src-3/cask/router/RoutesEndpointMetadata.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ object RoutesEndpointsMetadata{
val entrypointExpr = Macros.extractMethod[T](m, decorators, endpointExpr)

'{
val entrypoint: EntryPoint[T, cask.Request] = ${entrypointExpr}
val entrypoint: EntryPoint[T, Any] = ${entrypointExpr}

EndpointMetadata[T](
// the Scala 2 version and non-macro code expects decorators to be reversed
Expand Down
1 change: 1 addition & 0 deletions cask/src/cask/main/Main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ object Main{
routes,
routeBindings,
(mainDecorators ++ routes.decorators ++ metadata.decorators).toList,
Nil,
Nil
) match {
case Result.Success(res) => runner(res)
Expand Down
6 changes: 4 additions & 2 deletions cask/src/cask/router/Decorators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,25 @@ object Decorator{
routes: T,
routeBindings: Map[String, String],
remainingDecorators: List[Decorator[_, _, _, _]],
inputContexts: List[Any],
bindings: List[Map[String, Any]]): Result[Any] = try {
remainingDecorators match {
case head :: rest =>
head.asInstanceOf[Decorator[Any, Any, Any, Any]].wrapFunction(
ctx,
(_, args) => invoke(ctx, endpoint, entryPoint, routes, routeBindings, rest, args :: bindings)
(ictx, args) => invoke(ctx, endpoint, entryPoint, routes, routeBindings, rest, ictx :: inputContexts, args :: bindings)
.asInstanceOf[Result[Nothing]]
)

case Nil =>
endpoint.wrapFunction(ctx, { (ictx: Any, endpointBindings: Map[String, Any]) =>

val mergedEndpointBindings = endpointBindings ++ routeBindings.mapValues(endpoint.wrapPathSegment)
val finalBindings = mergedEndpointBindings :: bindings

entryPoint
.asInstanceOf[EntryPoint[T, Any]]
.invoke(routes, ictx, finalBindings)
.invoke(routes, ictx :: inputContexts, finalBindings)
.asInstanceOf[Result[Nothing]]
})
}
Expand Down
6 changes: 3 additions & 3 deletions cask/src/cask/router/EntryPoint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ import scala.collection.mutable
case class EntryPoint[T, C](name: String,
argSignatures: Seq[Seq[ArgSig[_, T, _, C]]],
doc: Option[String],
invoke0: (T, C, Seq[Map[String, Any]], Seq[Seq[ArgSig[Any, _, _, C]]]) => Result[Any]){
invoke0: (T, Seq[C], Seq[Map[String, Any]], Seq[Seq[ArgSig[Any, _, _, C]]]) => Result[Any]){

val firstArgs = argSignatures.head
.map(x => x.name -> x)
.toMap[String, ArgSig[_, T, _, C]]

def invoke(target: T,
ctx: C,
ctxs: Seq[C],
paramLists: Seq[Map[String, Any]]): Result[Any] = {

val missing = mutable.Buffer.empty[ArgSig[_, T, _, C]]
Expand All @@ -42,7 +42,7 @@ case class EntryPoint[T, C](name: String,
} else {
try invoke0(
target,
ctx,
ctxs,
paramLists,
argSignatures.asInstanceOf[Seq[Seq[ArgSig[Any, _, _, C]]]]
)
Expand Down
54 changes: 54 additions & 0 deletions example/decoratedContext/app/src/DecoratedContext.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package app

case class Context(
session: Session
)

case class Session(data: collection.mutable.Map[String, String])

trait CustomParser[T] extends cask.router.ArgReader[Any, T, Context]
object CustomParser:
given CustomParser[Context] with
def arity = 0
def read(ctx: Context, label: String, input: Any): Context = ctx
given CustomParser[Session] with
def arity = 0
def read(ctx: Context, label: String, input: Any): Session = ctx.session
given literal[Literal]: CustomParser[Literal] with
def arity = 1
def read(ctx: Context, label: String, input: Any): Literal = input.asInstanceOf[Literal]

object DecoratedContext extends cask.MainRoutes{

class custom extends cask.router.Decorator[cask.Response.Raw, cask.Response.Raw, Any, Context]{

override type InputParser[T] = CustomParser[T]

def wrapFunction(req: cask.Request, delegate: Delegate) = {
// Create a custom context out of the request. Custom contexts are useful
// to group an expensive operation that may be used by multiple
// parameter readers.
val ctx = Context(Session(collection.mutable.Map.empty)) // this would typically be populated from a signed cookie

delegate(ctx, Map("user" -> 1337)).map{ response =>
val extraCookies = ctx.session.data.map(
(k, v) => cask.Cookie(k, v)
)

response.copy(
cookies = response.cookies ++ extraCookies
)
}

}
}

@custom()
@cask.get("/hello/:world")
def hello(world: String, req: cask.Request)(session: Session, user: Int) = {
session.data("hello") = "world"
world + user
}

initialize()
}
26 changes: 26 additions & 0 deletions example/decoratedContext/app/test/src/ExampleTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package app
import io.undertow.Undertow

import utest._

object ExampleTests extends TestSuite{
def withServer[T](example: cask.main.Main)(f: String => T): T = {
val server = Undertow.builder
.addHttpListener(8081, "localhost")
.setHandler(example.defaultHandler)
.build
server.start()
val res =
try f("http://localhost:8081")
finally server.stop()
res
}

val tests = Tests{
test("DecoratedContext") - withServer(DecoratedContext){ host =>
val response = requests.get(s"$host/hello/woo")
response.text() ==> "woo1337"
response.cookies("hello").getValue ==> "world"
}
}
}
14 changes: 14 additions & 0 deletions example/decoratedContext/build.sc
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import mill._, scalalib._

trait AppModule extends CrossScalaModule{

def ivyDeps = Agg[Dep](
)
object test extends ScalaTests with TestModule.Utest{

def ivyDeps = Agg(
ivy"com.lihaoyi::utest::0.8.1",
ivy"com.lihaoyi::requests::0.8.0",
)
}
}

0 comments on commit 6cf9c8c

Please sign in to comment.