From a3e12b3cb7772c93b5c00abca18d76aa1f81f733 Mon Sep 17 00:00:00 2001 From: Martin Date: Mon, 17 May 2021 13:08:16 +0200 Subject: [PATCH] Prevent cask from responding with 405 for an undefined route. Fixes #51. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prior to this commit, `prepareRouteTries` created a mapping from method-name to DispatchTrie (Map[String, DispatchTrie[…]]). This commit instead creates a DispatchTrie[Map[String, …]], basically an inversion of the previous result. The updated tests in minimalApplication and minimalApplication2 have been updated to cover the differences. --- cask/src/cask/main/Main.scala | 45 +++++++++---------- .../app/test/src/ExampleTests.scala | 2 +- .../app/test/src/ExampleTests.scala | 2 +- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index b76a2e9768..4c8e865152 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -14,6 +14,7 @@ import io.undertow.server.handlers.BlockingHandler import io.undertow.util.HttpString import scala.concurrent.ExecutionContext +import geny.Generator /** * A combination of [[cask.Main]] and [[cask.Routes]], ideal for small @@ -74,7 +75,7 @@ abstract class Main{ } object Main{ - class DefaultHandler(routeTries: Map[String, DispatchTrie[(Routes, EndpointMetadata[_])]], + class DefaultHandler(routeTries: DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]], mainDecorators: Seq[Decorator[_, _, _]], debugMode: Boolean, handleNotFound: () => Response.Raw, @@ -101,16 +102,10 @@ object Main{ (r: Any) => Main.writeResponse(exchange, r.asInstanceOf[Response.Raw]) ) - val dispatchTrie: DispatchTrie[(Routes, EndpointMetadata[_])] = routeTries.get(effectiveMethod) match { - case None => - Main.writeResponse(exchange, handleMethodNotAllowed()) - return - case Some(trie) => trie - } - - dispatchTrie.lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match { + routeTries.lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match{ case None => Main.writeResponse(exchange, handleNotFound()) - case Some(((routes, metadata), routeBindings, remaining)) => + case Some((methodMap, routeBindings, remaining)) if methodMap.contains(effectiveMethod) => + val (routes, metadata) = methodMap(effectiveMethod) Decorator.invoke( Request(exchange, remaining), metadata.endpoint, @@ -128,8 +123,8 @@ object Main{ ) None } + case _ => Main.writeResponse(exchange, handleMethodNotAllowed()) } - // println("Completed Request: " + exchange.getRequestPath) }catch{case e: Throwable => e.printStackTrace() } @@ -149,23 +144,25 @@ object Main{ ) } - def prepareRouteTries(allRoutes: Seq[Routes]): Map[String, DispatchTrie[(Routes, EndpointMetadata[_])]] = { - val routeList = for{ + def prepareRouteTries(allRoutes: Seq[Routes]): DispatchTrie[Map[String, (Routes, EndpointMetadata[_])]] = { + val flattenedRoutes = for { routes <- allRoutes - route <- routes.caskMetadata.value.map(x => x: EndpointMetadata[_]) - } yield (routes, route) + metadata <- routes.caskMetadata.value + } yield { + val segments = Util.splitPath(metadata.endpoint.path) + val methodMap = metadata.endpoint.methods.map(_ -> (routes, metadata: EndpointMetadata[_])).toMap + (segments, methodMap, metadata.endpoint.subpath) + } - val allMethods: Set[String] = - routeList.flatMap(_._2.endpoint.methods).map(_.toLowerCase).toSet + val dispatchInputs = flattenedRoutes.groupBy(_._1).map { case (segments, values) => + val methodMap = values.map(_._2).flatten.toMap + val hasSubpath = values.map(_._3).contains(true) + (segments, methodMap, hasSubpath) + }.toSeq - allMethods - .map { method => - method -> DispatchTrie.construct[(Routes, EndpointMetadata[_])](0, - for ((route, metadata) <- routeList if metadata.endpoint.methods.contains(method)) - yield (Util.splitPath(metadata.endpoint.path): collection.IndexedSeq[String], (route, metadata), metadata.endpoint.subpath) - ) - }.toMap + DispatchTrie.construct(0, dispatchInputs) } + def writeResponse(exchange: HttpServerExchange, response: Response.Raw) = { response.data.headers.foreach{case (k, v) => exchange.getResponseHeaders.put(new HttpString(k), v) diff --git a/example/minimalApplication/app/test/src/ExampleTests.scala b/example/minimalApplication/app/test/src/ExampleTests.scala index 986fc617b3..cd04898c05 100644 --- a/example/minimalApplication/app/test/src/ExampleTests.scala +++ b/example/minimalApplication/app/test/src/ExampleTests.scala @@ -27,7 +27,7 @@ object ExampleTests extends TestSuite{ requests.post(s"$host/do-thing", data = "hello").text() ==> "olleh" - requests.get(s"$host/do-thing", check = false).statusCode ==> 404 + requests.delete(s"$host/do-thing", check = false).statusCode ==> 405 } } } diff --git a/example/minimalApplication2/app/test/src/ExampleTests.scala b/example/minimalApplication2/app/test/src/ExampleTests.scala index 5bbee09e9f..7d5ad2873e 100644 --- a/example/minimalApplication2/app/test/src/ExampleTests.scala +++ b/example/minimalApplication2/app/test/src/ExampleTests.scala @@ -27,7 +27,7 @@ object ExampleTests extends TestSuite{ requests.post(s"$host/do-thing", data = "hello").text() ==> "olleh" - requests.get(s"$host/do-thing", check = false).statusCode ==> 404 + requests.delete(s"$host/do-thing", check = false).statusCode ==> 405 } } }