Skip to content

Commit

Permalink
[JSON-API] Small refactoring to start using akka http routes internal…
Browse files Browse the repository at this point in the history
…ly (#10252)

* [JSON-API] Small refactoring to start using akka http routes internally

changelog_begin
changelog_end

* Fix build error

* Explicitly use the immutable seq via additional import

* Fix tests

* Fix scala 2.12 build
  • Loading branch information
realvictorprm authored Jul 13, 2021
1 parent 6e8ec1d commit 0a7f2b1
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ import akka.http.scaladsl.model.headers.{
ModeledCustomHeader,
ModeledCustomHeaderCompanion,
OAuth2BearerToken,
`X-Forwarded-For`,
`X-Forwarded-Proto`,
`X-Real-Ip`,
}
import akka.http.scaladsl.server.Directives.extractClientIP
import akka.http.scaladsl.server.{Rejection, RequestContext, Route, RouteResult}
import akka.http.scaladsl.server.RouteResult._
import akka.stream.Materializer
import akka.stream.scaladsl.{Flow, Source}
import akka.util.ByteString
Expand Down Expand Up @@ -47,6 +48,8 @@ import scala.util.control.NonFatal
import com.daml.logging.{ContextualizedLogger, LoggingContextOf}
import com.daml.metrics.{Metrics, Timed}

import scala.collection.immutable

class Endpoints(
allowNonHttps: Boolean,
decodeJwt: EndpointsCompanion.ValidateJwt,
Expand All @@ -68,23 +71,10 @@ class Endpoints(
import util.ErrorOps._
import Uri.Path._

// Inspired by
// https://github.com/akka/akka-http/blob/master/akka-http/src/main/scala/akka/http/scaladsl/server/directives/MiscDirectives.scala#L110-L116
// Because the Remote-Address header is deprecated we don't match for it here.
def requestSource(req: HttpRequest): RemoteAddress =
req
.header[`X-Forwarded-For`]
.flatMap(_.addresses.headOption)
.orElse(req.header[`X-Real-Ip`].map(_.address))
.orElse(req.attribute(AttributeKeys.remoteAddress))
.getOrElse(RemoteAddress.Unknown)

// Parenthesis in the case matches below are required because otherwise scalafmt breaks.
//noinspection ScalaUnnecessaryParentheses
def all(implicit
lc: LoggingContextOf[InstanceUUID],
metrics: Metrics,
): PartialFunction[HttpRequest, Future[HttpResponse]] = {
): Route = extractClientIP { remoteAddress => (ctx: RequestContext) =>
val apiMetrics = metrics.daml.HttpJsonApi
type DispatchFun =
PartialFunction[HttpRequest, LoggingContextOf[InstanceUUID with RequestID] => Future[
Expand Down Expand Up @@ -151,27 +141,34 @@ class Endpoints(
_ => healthService.ready().map(_.toHttpResponse)
}
import scalaz.std.partialFunction._, scalaz.syntax.arrow._
((commandDispatch orElse
val dispatch = commandDispatch orElse
queryAllDispatch orElse
queryMatchingDispatch orElse
fetchDispatch orElse
getPartyDispatch orElse
allocatePartyDispatch orElse
packageManagementDispatch orElse
liveOrHealthDispatch) &&& { case r => r }) andThen { case (lcFhr, req) =>
extendWithRequestIdLogCtx(implicit lc => {
val t0 = System.nanoTime
logger.info(s"Incoming request on ${req.uri} from ${requestSource(req)}")
metrics.daml.HttpJsonApi.httpRequestThroughput.mark()
for {
res <- lcFhr(lc)
_ = {
logger.trace(s"Processed request after ${System.nanoTime() - t0}ns")
logger.info(s"Responding to client with HTTP ${res.status}")
}
} yield res
})
}
liveOrHealthDispatch
dispatch
.&&& { case r => r }
.andThen { case (lcFhr, req) =>
extendWithRequestIdLogCtx(implicit lc => {
val t0 = System.nanoTime
logger.info(s"Incoming request on ${req.uri} from $remoteAddress")
metrics.daml.HttpJsonApi.httpRequestThroughput.mark()
for {
res <- lcFhr(lc)
_ = {
logger.trace(s"Processed request after ${System.nanoTime() - t0}ns")
logger.info(s"Responding to client with HTTP ${res.status}")
}
} yield Complete(res)
})
}
.applyOrElse[HttpRequest, Future[RouteResult]](
ctx.request,
_ => Future(Rejected(immutable.Seq.empty[Rejection])),
)
}

def getParseAndDecodeTimerCtx()(implicit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package com.daml.http

import akka.http.scaladsl.model._
import akka.http.scaladsl.server.RouteResult.Complete
import akka.http.scaladsl.server.{RequestContext, Route}
import akka.util.ByteString
import com.daml.http.domain.{JwtPayload, JwtWritePayload}
import com.daml.http.json.SprayJson
Expand Down Expand Up @@ -104,10 +106,13 @@ object EndpointsCompanion {
}
}

lazy val notFound: PartialFunction[HttpRequest, Future[HttpResponse]] = {
case HttpRequest(method, uri, _, _, _) =>
Future.successful(httpResponseError(NotFound(s"${method: HttpMethod}, uri: ${uri: Uri}")))
}
lazy val notFound: Route = (ctx: RequestContext) =>
ctx.request match {
case HttpRequest(method, uri, _, _, _) =>
Future.successful(
Complete(httpResponseError(NotFound(s"${method: HttpMethod}, uri: ${uri: Uri}")))
)
}

private[http] def httpResponseError(error: Error): HttpResponse = {
import com.daml.http.json.JsonProtocol._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ object HttpService {
maxInboundMessageSize = maxInboundMessageSize,
)

import akka.http.scaladsl.server.Directives._
val bindingEt: EitherT[Future, Error, ServerBinding] = for {
client <- eitherT(
ledgerClient(
Expand Down Expand Up @@ -194,15 +195,17 @@ object HttpService {
)

defaultEndpoints =
jsonEndpoints.all orElse
websocketEndpoints.transactionWebSocket orElse
EndpointsCompanion.notFound

allEndpoints = staticContentConfig.cata(
c =>
StaticContentEndpoints.all(c) orElse
defaultEndpoints,
defaultEndpoints,
concat(
jsonEndpoints.all,
websocketEndpoints.transactionWebSocket,
)

allEndpoints = concat(
staticContentConfig.cata(
c => concat(StaticContentEndpoints.all(c), defaultEndpoints),
defaultEndpoints,
),
EndpointsCompanion.notFound,
)

binding <- liftET[Error](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,30 @@ package com.daml.http
import akka.actor.ActorSystem
import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.RouteResult.{Complete, Rejected}
import akka.http.scaladsl.server.directives.ContentTypeResolver.Default
import akka.http.scaladsl.server.Directives
import akka.http.scaladsl.server.{Directives, Rejection, RequestContext, Route, RouteResult}
import com.daml.http.util.Logging.InstanceUUID
import com.daml.logging.{ContextualizedLogger, LoggingContextOf}
import scalaz.syntax.show._

import scala.concurrent.Future
import scala.concurrent.{ExecutionContext, Future}
import scala.collection.immutable.Seq

object StaticContentEndpoints {
def all(config: StaticContentConfig)(implicit
asys: ActorSystem,
lc: LoggingContextOf[InstanceUUID],
): PartialFunction[HttpRequest, Future[HttpResponse]] =
ec: ExecutionContext,
): Route = (ctx: RequestContext) =>
new StaticContentRouter(config)
.andThen(
_ map Complete
)
.applyOrElse[HttpRequest, Future[RouteResult]](
ctx.request,
_ => Future(Rejected(Seq.empty[Rejection])),
)
}

private class StaticContentRouter(config: StaticContentConfig)(implicit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ import scalaz.syntax.std.boolean._
import scalaz.syntax.std.option._
import scalaz.\/

import scala.concurrent.Future
import scala.concurrent.{ExecutionContext, Future}
import EndpointsCompanion._
import akka.http.scaladsl.server.{Rejection, RequestContext, Route, RouteResult}
import akka.http.scaladsl.server.RouteResult.{Complete, Rejected}
import com.daml.http.domain.JwtPayload
import com.daml.http.util.Logging.{InstanceUUID, RequestID, extendWithRequestIdLogCtx}
import com.daml.logging.{ContextualizedLogger, LoggingContextOf}
import com.daml.metrics.Metrics

import scala.collection.immutable.Seq

object WebsocketEndpoints {
private[http] val tokenPrefix: String = "jwt.token."
private[http] val wsProtocol: String = "daml.ws.auth"
Expand Down Expand Up @@ -49,7 +53,7 @@ object WebsocketEndpoints {
class WebsocketEndpoints(
decodeJwt: ValidateJwt,
webSocketService: WebSocketService,
) {
)(implicit ec: ExecutionContext) {

import WebsocketEndpoints._

Expand All @@ -58,7 +62,7 @@ class WebsocketEndpoints(
def transactionWebSocket(implicit
lc: LoggingContextOf[InstanceUUID],
metrics: Metrics,
) = {
): Route = { (ctx: RequestContext) =>
val dispatch: PartialFunction[HttpRequest, LoggingContextOf[
InstanceUUID with RequestID
] => Future[HttpResponse]] = {
Expand Down Expand Up @@ -105,12 +109,18 @@ class WebsocketEndpoints(
)
}
import scalaz.std.partialFunction._, scalaz.syntax.arrow._
(dispatch &&& { case r => r }) andThen { case (lcFhr, req) =>
extendWithRequestIdLogCtx(implicit lc => {
logger.trace(s"Incoming request on ${req.uri}")
lcFhr(lc)
})
}
dispatch
.&&& { case r => r }
.andThen { case (lcFhr, req) =>
extendWithRequestIdLogCtx(implicit lc => {
logger.trace(s"Incoming request on ${req.uri}")
lcFhr(lc) map Complete
})
}
.applyOrElse[HttpRequest, Future[RouteResult]](
ctx.request,
_ => Future(Rejected(Seq.empty[Rejection])),
)
}

def handleWebsocketRequest[A: WebSocketService.StreamQueryReader](
Expand Down

0 comments on commit 0a7f2b1

Please sign in to comment.