diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index b76a2e9768..dd94ed0613 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -46,10 +46,10 @@ abstract class Main{ implicit def log: cask.util.Logger = new cask.util.Logger.Console() - def routeTries = Main.prepareRouteTries(allRoutes) + def dispatchTrie = Main.prepareDispatchTrie(allRoutes) def defaultHandler = new BlockingHandler( - new Main.DefaultHandler(routeTries, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError) + new Main.DefaultHandler(dispatchTrie, mainDecorators, debugMode, handleNotFound, handleMethodNotAllowed, handleEndpointError) ) def handleNotFound() = Main.defaultHandleNotFound() @@ -74,7 +74,7 @@ abstract class Main{ } object Main{ - class DefaultHandler(routeTries: Map[String, DispatchTrie[(Routes, EndpointMetadata[_])]], + class DefaultHandler(dispatchTrie: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]], mainDecorators: Seq[Decorator[_, _, _]], debugMode: Boolean, handleNotFound: () => Response.Raw, @@ -101,16 +101,10 @@ object Main{ (r: Any) => Main.writeResponse(exchange, r.asInstanceOf[Response.Raw]) ) - val dispatchTrie: DispatchTrie[(Routes, EndpointMetadata[_])] = routeTries.get(effectiveMethod) match { - case None => - Main.writeResponse(exchange, handleMethodNotAllowed()) - return - case Some(trie) => trie - } - - dispatchTrie.lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match { + dispatchTrie.lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match{ case None => Main.writeResponse(exchange, handleNotFound()) - case Some(((routes, metadata), routeBindings, remaining)) => + case Some((methodMap, routeBindings, remaining)) if methodMap.contains(effectiveMethod) => + val (routes, metadata) = methodMap(effectiveMethod) Decorator.invoke( Request(exchange, remaining), metadata.endpoint, @@ -128,8 +122,8 @@ object Main{ ) None } + case _ => Main.writeResponse(exchange, handleMethodNotAllowed()) } - // println("Completed Request: " + exchange.getRequestPath) }catch{case e: Throwable => e.printStackTrace() } @@ -149,23 +143,25 @@ object Main{ ) } - def prepareRouteTries(allRoutes: Seq[Routes]): Map[String, DispatchTrie[(Routes, EndpointMetadata[_])]] = { - val routeList = for{ + def prepareDispatchTrie(allRoutes: Seq[Routes]): DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]] = { + val flattenedRoutes = for { routes <- allRoutes - route <- routes.caskMetadata.value.map(x => x: EndpointMetadata[_]) - } yield (routes, route) + metadata <- routes.caskMetadata.value + } yield { + val segments = Util.splitPath(metadata.endpoint.path) + val methodMap = metadata.endpoint.methods.map(_ -> (routes, metadata: EndpointMetadata[_])).toMap + (segments, methodMap, metadata.endpoint.subpath) + } - val allMethods: Set[String] = - routeList.flatMap(_._2.endpoint.methods).map(_.toLowerCase).toSet + val dispatchInputs = flattenedRoutes.groupBy(_._1).map { case (segments, values) => + val methodMap = values.map(_._2).flatten.toMap + val hasSubpath = values.map(_._3).contains(true) + (segments, methodMap, hasSubpath) + }.toSeq - allMethods - .map { method => - method -> DispatchTrie.construct[(Routes, EndpointMetadata[_])](0, - for ((route, metadata) <- routeList if metadata.endpoint.methods.contains(method)) - yield (Util.splitPath(metadata.endpoint.path): collection.IndexedSeq[String], (route, metadata), metadata.endpoint.subpath) - ) - }.toMap + DispatchTrie.construct(0, dispatchInputs) } + def writeResponse(exchange: HttpServerExchange, response: Response.Raw) = { response.data.headers.foreach{case (k, v) => exchange.getResponseHeaders.put(new HttpString(k), v) diff --git a/example/minimalApplication/app/test/src/ExampleTests.scala b/example/minimalApplication/app/test/src/ExampleTests.scala index 986fc617b3..cd04898c05 100644 --- a/example/minimalApplication/app/test/src/ExampleTests.scala +++ b/example/minimalApplication/app/test/src/ExampleTests.scala @@ -27,7 +27,7 @@ object ExampleTests extends TestSuite{ requests.post(s"$host/do-thing", data = "hello").text() ==> "olleh" - requests.get(s"$host/do-thing", check = false).statusCode ==> 404 + requests.delete(s"$host/do-thing", check = false).statusCode ==> 405 } } } diff --git a/example/minimalApplication2/app/test/src/ExampleTests.scala b/example/minimalApplication2/app/test/src/ExampleTests.scala index 5bbee09e9f..7d5ad2873e 100644 --- a/example/minimalApplication2/app/test/src/ExampleTests.scala +++ b/example/minimalApplication2/app/test/src/ExampleTests.scala @@ -27,7 +27,7 @@ object ExampleTests extends TestSuite{ requests.post(s"$host/do-thing", data = "hello").text() ==> "olleh" - requests.get(s"$host/do-thing", check = false).statusCode ==> 404 + requests.delete(s"$host/do-thing", check = false).statusCode ==> 405 } } }