Skip to content

Commit

Permalink
Correct handling of ensure header
Browse files Browse the repository at this point in the history
- Test for this case
- Correct implementation
- Better parse failure messages
  • Loading branch information
noelwelsh committed Jul 11, 2024
1 parent 4220293 commit 2f4b6c7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 12 deletions.
4 changes: 3 additions & 1 deletion core/jvm/src/test/scala/krop/route/RequestHeaderSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class RequestHeaderSuite extends CatsEffectSuite {
test("Ensure header succeeds if header does exist") {
val req = Request.get(Path.root).ensureHeader(jsonContentType)

req.parse(jsonRequest)(using Raise.toOption).map(_.isDefined).assert
req
.parse(jsonRequest)(using Raise.toOption)
.map(opt => assertEquals(opt, Some(EmptyTuple)))
}

test("Extract header extracts desired header (by type version)") {
Expand Down
48 changes: 37 additions & 11 deletions core/shared/src/main/scala/krop/route/Request.scala
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ object Request {
type Result = Types.TupleConcat[P, I]

import RequestHeaders.Process
import RequestHeaders.failure

export path.pathTo
export path.pathAndQueryTo
Expand All @@ -229,27 +230,34 @@ object Request {
extension [A](opt: Option[A]) {
def orFail(header: Header[?, ?]): Raise[ParseFailure] ?=> A =
opt.getOrElse(
Raise.raise(
ParseFailure(
ParseStage.Header,
"Could not extract the header ${header.name}",
""
)
)
Raise.raise(failure.headerNotFound(header.name.toString))
)
}

val extracted: F[ParseFailure, List[?]] =
Raise.handle { (r: Raise[ParseFailure]) ?=>
val reqHeaders = req.headers
headers.map(p =>
headers.foldLeft(List.empty)((accum, p) =>
p match {
case Process.Extract(value, header, select) =>
reqHeaders.get(using select).orFail(header)
accum :+ reqHeaders.get(using select).orFail(header)
case Process.ExtractFromName(header, select) =>
reqHeaders.get(using select).orFail(header)
accum :+ reqHeaders.get(using select).orFail(header)
case Process.Ensure(value, header, select) =>
reqHeaders.get(using select).orFail(header)
reqHeaders.get(using select) match {
case None =>
Raise.raise(failure.headerNotFound(header.name.toString))
case s @ Some(actual) =>
if actual == value then accum
else
Raise.raise(
failure.headerDidntMatch(
header.name.toString,
actual,
value
)
)
}
}
)
}
Expand Down Expand Up @@ -326,6 +334,24 @@ object Request {
case Ensure[A](value: A, header: Header[A, ?], select: Header.Select[A])
}

object failure {
def headerNotFound(name: String) =
ParseFailure(
ParseStage.Header,
s"Could not extract the header ${name}",
s"""The header named ${name} did not exist in the request's headers,
|or the value could not be correctly parsed.""".stripMargin
)

def headerDidntMatch[A](name: String, actual: A, expected: A) =
ParseFailure(
ParseStage.Header,
s"The header $name} did not have the expected value",
s"""The header with name ${name} and value ${actual} was found in the request,
|but we expected to find the value ${expected}.""".stripMargin
)
}

def empty[P <: Tuple, Q <: Tuple](
path: RequestMethodPath[P, Q]
): RequestHeaders[P, Q, EmptyTuple, EmptyTuple] =
Expand Down

0 comments on commit 2f4b6c7

Please sign in to comment.