From 8e0210ef312419d05968fbd33e7df7f8bebe8176 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Mon, 8 Jul 2024 13:57:51 +0200 Subject: [PATCH] Allow overlap between static routes and wildcards This changes the DispatchTrie to allow overlapping wildcard segments in paths with static ones, with a preference for the latter. For example, consider the following routes: ``` @cask.get("/settings") def settings() = "settings" @cask.get("/:user") def user(id: String) = id ``` This is currently not allowed. With these changes, it would be allowed, and the static route `settings` would be preffered, with a fallback to the dynamic route `user`: ``` GET /settings => settings GET /foo => foo GET /bar => bar ``` --- The reason I'm proposing this change is mostly for use in HTML applications (i.e. not RPC-style JSON APIs). In this scenario, short URLs are useful, since users may type them directly and associate meaning to them. Consider for example the way GitHub structures URLs. If github were written with cask's current routing logic, it would not be possible to have URLs such as `/settings` and `/com-lihaoyi`, and instead some namespacing would need to be introduced (e.g. /orgs/com-lihaoyi/) to separate these, which might not actually be relevant for users. --- cask/src/cask/internal/DispatchTrie.scala | 59 +++++++++------- cask/src/cask/main/Main.scala | 2 +- .../src/test/cask/DispatchTrieTests.scala | 68 ++++++------------- example/queryParams/app/src/QueryParams.scala | 17 ++++- .../app/test/src/ExampleTests.scala | 10 +++ 5 files changed, 81 insertions(+), 75 deletions(-) diff --git a/cask/src/cask/internal/DispatchTrie.scala b/cask/src/cask/internal/DispatchTrie.scala index cc9f7e7252..351ce26217 100644 --- a/cask/src/cask/internal/DispatchTrie.scala +++ b/cask/src/cask/internal/DispatchTrie.scala @@ -38,17 +38,25 @@ object DispatchTrie{ validateGroup(groupTerminals, groupContinuations) } + val dynamicChildren = continuations.filter(_._1.startsWith(":")) + .flatMap(_._2).toIndexedSeq + DispatchTrie[T]( - current = terminals.headOption.map(x => x._2 -> x._3), - children = continuations + current = terminals.headOption + .map{ case (path, value, capturesSubpath) => + val argNames = path.filter(_.startsWith(":")).map(_.drop(1)).toVector + (value, capturesSubpath, argNames) + }, + staticChildren = continuations + .filter(!_._1.startsWith(":")) .map{ case (k, vs) => (k, construct(index + 1, vs)(validationGroups))} - .toMap + .toMap, + dynamicChildren = if (dynamicChildren.isEmpty) None else Some(construct(index + 1, dynamicChildren)(validationGroups)) ) } def validateGroup[T, V](terminals: collection.Seq[(collection.Seq[String], T, Boolean, V)], continuations: mutable.Map[String, mutable.Buffer[(collection.IndexedSeq[String], T, Boolean, V)]]) = { - val wildcards = continuations.filter(_._1(0) == ':') def renderTerminals = terminals .map{case (path, v, allowSubpath, group) => s"$group${renderPath(path)}"} @@ -65,12 +73,6 @@ object DispatchTrie{ ) } - if (wildcards.size >= 1 && continuations.size > 1) { - throw new Exception( - s"Routes overlap with wildcards: $renderContinuations" - ) - } - if (terminals.headOption.exists(_._3) && continuations.size == 1) { throw new Exception( s"Routes overlap with subpath capture: $renderTerminals, $renderContinuations" @@ -88,32 +90,37 @@ object DispatchTrie{ * segments starting with `:`) and any remaining un-used path segments * (only when `current._2 == true`, indicating this route allows trailing * segments) + * current = (value, captures subpaths, argument names) */ -case class DispatchTrie[T](current: Option[(T, Boolean)], - children: Map[String, DispatchTrie[T]]){ +case class DispatchTrie[T]( + current: Option[(T, Boolean, Vector[String])], + staticChildren: Map[String, DispatchTrie[T]], + dynamicChildren: Option[DispatchTrie[T]] +) { + final def lookup(remainingInput: List[String], - bindings: Map[String, String]) + bindings: Vector[String]) : Option[(T, Map[String, String], Seq[String])] = { - remainingInput match{ + remainingInput match { case Nil => - current.map(x => (x._1, bindings, Nil)) + current.map(x => (x._1, x._3.zip(bindings).toMap, Nil)) case head :: rest if current.exists(_._2) => - current.map(x => (x._1, bindings, head :: rest)) + current.map(x => (x._1, x._3.zip(bindings).toMap, head :: rest)) case head :: rest => - if (children.size == 1 && children.keys.head.startsWith(":")){ - children.values.head.lookup(rest, bindings + (children.keys.head.drop(1) -> head)) - }else{ - children.get(head) match{ - case None => None - case Some(continuation) => continuation.lookup(rest, bindings) - } + staticChildren.get(head) match { + case Some(continuation) => continuation.lookup(rest, bindings) + case None => + dynamicChildren match { + case Some(continuation) => continuation.lookup(rest, bindings :+ head) + case None => None + } } - } } def map[V](f: T => V): DispatchTrie[V] = DispatchTrie( - current.map{case (t, v) => (f(t), v)}, - children.map { case (k, v) => (k, v.map(f))} + current.map{case (t, v, a) => (f(t), v, a)}, + staticChildren.map { case (k, v) => (k, v.map(f))}, + dynamicChildren.map { case v => v.map(f)}, ) } diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index 6ac6e80e57..15c365666b 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -106,7 +106,7 @@ object Main{ .map(java.net.URLDecoder.decode(_, "UTF-8")) .toList - dispatchTrie.lookup(decodedSegments, Map()) match { + dispatchTrie.lookup(decodedSegments, Vector()) match { case None => Main.writeResponse(exchange, handleNotFound(Request(exchange, decodedSegments))) case Some((methodMap, routeBindings, remaining)) => methodMap.get(effectiveMethod) match { diff --git a/cask/test/src/test/cask/DispatchTrieTests.scala b/cask/test/src/test/cask/DispatchTrieTests.scala index 1da0df3286..b7b2ce04ff 100644 --- a/cask/test/src/test/cask/DispatchTrieTests.scala +++ b/cask/test/src/test/cask/DispatchTrieTests.scala @@ -11,9 +11,9 @@ object DispatchTrieTests extends TestSuite { )(Seq(_)) assert( - x.lookup(List("hello"), Map()) == Some((1, Map(), Nil)), - x.lookup(List("hello", "world"), Map()) == None, - x.lookup(List("world"), Map()) == None + x.lookup(List("hello"), Vector()) == Some((1, Map(), Nil)), + x.lookup(List("hello", "world"), Vector()) == None, + x.lookup(List("world"), Vector()) == None ) } "nested" - { @@ -24,11 +24,11 @@ object DispatchTrieTests extends TestSuite { ) )(Seq(_)) assert( - x.lookup(List("hello", "world"), Map()) == Some((1, Map(), Nil)), - x.lookup(List("hello", "cow"), Map()) == Some((2, Map(), Nil)), - x.lookup(List("hello"), Map()) == None, - x.lookup(List("hello", "moo"), Map()) == None, - x.lookup(List("hello", "world", "moo"), Map()) == None + x.lookup(List("hello", "world"), Vector()) == Some((1, Map(), Nil)), + x.lookup(List("hello", "cow"), Vector()) == Some((2, Map(), Nil)), + x.lookup(List("hello"), Vector()) == None, + x.lookup(List("hello", "moo"), Vector()) == None, + x.lookup(List("hello", "world", "moo"), Vector()) == None ) } "bindings" - { @@ -36,11 +36,11 @@ object DispatchTrieTests extends TestSuite { Seq((Vector(":hello", ":world"), 1, false)) )(Seq(_)) assert( - x.lookup(List("hello", "world"), Map()) == Some((1, Map("hello" -> "hello", "world" -> "world"), Nil)), - x.lookup(List("world", "hello"), Map()) == Some((1, Map("hello" -> "world", "world" -> "hello"), Nil)), + x.lookup(List("hello", "world"), Vector()) == Some((1, Map("hello" -> "hello", "world" -> "world"), Nil)), + x.lookup(List("world", "hello"), Vector()) == Some((1, Map("hello" -> "world", "world" -> "hello"), Nil)), - x.lookup(List("hello", "world", "cow"), Map()) == None, - x.lookup(List("hello"), Map()) == None + x.lookup(List("hello", "world", "cow"), Vector()) == None, + x.lookup(List("hello"), Vector()) == None ) } @@ -50,35 +50,21 @@ object DispatchTrieTests extends TestSuite { )(Seq(_)) assert( - x.lookup(List("hello", "world"), Map()) == Some((1,Map(), Seq("world"))), - x.lookup(List("hello", "world", "cow"), Map()) == Some((1,Map(), Seq("world", "cow"))), - x.lookup(List("hello"), Map()) == Some((1,Map(), Seq())), - x.lookup(List(), Map()) == None + x.lookup(List("hello", "world"), Vector()) == Some((1,Map(), Seq("world"))), + x.lookup(List("hello", "world", "cow"), Vector()) == Some((1,Map(), Seq("world", "cow"))), + x.lookup(List("hello"), Vector()) == Some((1,Map(), Seq())), + x.lookup(List(), Vector()) == None ) } - "errors" - { + "wildcards" - { test - { DispatchTrie.construct(0, Seq( (Vector("hello", ":world"), 1, false), - (Vector("hello", "world"), 2, false) + (Vector("hello", "world"), 1, false) ) )(Seq(_)) - - val ex = intercept[Exception]{ - DispatchTrie.construct(0, - Seq( - (Vector("hello", ":world"), 1, false), - (Vector("hello", "world"), 1, false) - ) - )(Seq(_)) - } - - assert( - ex.getMessage == - "Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world" - ) } test - { DispatchTrie.construct(0, @@ -87,21 +73,9 @@ object DispatchTrieTests extends TestSuite { (Vector("hello", "world", "omg"), 2, false) ) )(Seq(_)) - - val ex = intercept[Exception]{ - DispatchTrie.construct(0, - Seq( - (Vector("hello", ":world"), 1, false), - (Vector("hello", "world", "omg"), 1, false) - ) - )(Seq(_)) - } - - assert( - ex.getMessage == - "Routes overlap with wildcards: 1 /hello/:world, 1 /hello/world/omg" - ) } + } + "errors" - { test - { DispatchTrie.construct(0, Seq( @@ -143,7 +117,7 @@ object DispatchTrieTests extends TestSuite { assert( ex.getMessage == - "Routes overlap with wildcards: 1 /hello/:world, 1 /hello/:cow" + "More than one endpoint has the same path: 1 /hello/:world, 1 /hello/:cow" ) } test - { diff --git a/example/queryParams/app/src/QueryParams.scala b/example/queryParams/app/src/QueryParams.scala index 5e73574e82..5e79409101 100644 --- a/example/queryParams/app/src/QueryParams.scala +++ b/example/queryParams/app/src/QueryParams.scala @@ -2,7 +2,7 @@ package app object QueryParams extends cask.MainRoutes{ @cask.get("/article/:articleId") // Mandatory query param, e.g. HOST/article/foo?param=bar - def getArticle(articleId: Int, param: String) = { + def getArticle(articleId: Int, param: String) = { s"Article $articleId $param" } @@ -31,5 +31,20 @@ object QueryParams extends cask.MainRoutes{ s"User $userName " + params.value } + @cask.get("/statics/foo") + def getStatic() = { + "static route takes precedence" + } + + @cask.get("/statics/:foo") + def getDynamics(foo: String) = { + s"dynamic route $foo" + } + + @cask.get("/statics/bar") + def getStatic2() = { + "another static route" + } + initialize() } diff --git a/example/queryParams/app/test/src/ExampleTests.scala b/example/queryParams/app/test/src/ExampleTests.scala index 03ae03371f..ca9b4b0752 100644 --- a/example/queryParams/app/test/src/ExampleTests.scala +++ b/example/queryParams/app/test/src/ExampleTests.scala @@ -90,6 +90,16 @@ object ExampleTests extends TestSuite{ res3 == "User lihaoyi Map(unknown1 -> WrappedArray(123), unknown2 -> WrappedArray(abc))" || res3 == "User lihaoyi Map(unknown1 -> ArraySeq(123), unknown2 -> ArraySeq(abc))" ) + + assert( + requests.get(s"$host/statics/foo").text() == "static route takes precedence" + ) + assert( + requests.get(s"$host/statics/hello").text() == "dynamic route hello" + ) + assert( + requests.get(s"$host/statics/bar").text() == "another static route" + ) } } }