Skip to content

Commit

Permalink
First draft of message attributes #1500
Browse files Browse the repository at this point in the history
This approach follows the Play API (https://www.playframework.com/documentation/2.8.x/Highlights26)
in that you can add attributes of any user type. Otherwise it follows the
existing conventions from Akka HTTP Headers, so you can have multiple
attributes of the same type. If you want to distinguish different attributes of
the same type, like you could in Play with different keys, you would have to
introduce a wrapper type (either holding the key or creating a separate wrapper
for each key).

The main downside of this approach is that it increases the memory usage of a
message with one field.
  • Loading branch information
raboof committed Jan 7, 2020
1 parent 5c467cb commit 024c144
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ public interface HttpMessage {

/**
* Try to find the first header with the given name (case-insensitive) and return
* Some(header), otherwise this method returns None.
* Optional.of(header), otherwise this method returns an empty Optional.
*/
Optional<HttpHeader> getHeader(String headerName);

/**
* Try to find the first header of the given class and return
* Some(header), otherwise this method returns None.
* Optional.of(header), otherwise this method returns an empty Optional.
*/
<T extends HttpHeader> Optional<T> getHeader(Class<T> headerClass);

Expand All @@ -65,6 +65,22 @@ public interface HttpMessage {
*/
<T extends HttpHeader> Iterable<T> getHeaders(Class<T> headerClass);

/**
* An iterable containing the attributes for this message.
*/
Iterable<Object> getAttributes();

/**
* Try to find the first attribute of the given class and return
* Optional.of(attribute), otherwise this method returns an empty Optional
*/
<T> Optional<T> getAttribute(Class<T> attributeClass);

/**
* An iterable containing all attributes of the given class of this message
*/
<T> Iterable<T> getAttributes(Class<T> attributeClass);

/**
* The entity of this message.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Adding a method to an interface marked DoNotInherit is OK
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.javadsl.model.HttpMessage.getAttributes")
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.javadsl.model.HttpMessage.getAttribute")

# Adding a method to a sealed trait is safe
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.model.HttpMessage.attributes")
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.model.HttpMessage.withAttributes")
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.model.HttpMessage.mapAttributes")
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.model.HttpMessage.addAttribute")
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.model.HttpMessage.getAttributes")
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.model.HttpMessage.getAttribute")
ProblemFilters.exclude[ReversedMissingMethodProblem]("akka.http.scaladsl.model.HttpMessage.getAttributes")
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

package akka.http.scaladsl.model

import akka.stream.scaladsl.Flow
import akka.stream.{ FlowShape, Graph }
import java.io.File
import java.nio.file.Path
import java.lang.{ Iterable => JIterable }
Expand All @@ -17,9 +15,14 @@ import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ ExecutionContext, Future }
import scala.collection.immutable
import scala.reflect.{ ClassTag, classTag }

import akka.Done
import akka.annotation.InternalApi
import akka.parboiled2.CharUtils
import akka.stream.FlowShape
import akka.stream.Graph
import akka.stream.Materializer
import akka.stream.scaladsl.Flow
import akka.util.{ ByteString, HashCode, OptionVal }
import akka.http.ccompat.{ pre213, since213 }
import akka.http.impl.util._
Expand All @@ -42,6 +45,7 @@ sealed trait HttpMessage extends jm.HttpMessage {
def isResponse: Boolean

def headers: immutable.Seq[HttpHeader]
def attributes: immutable.Seq[AnyRef]
def entity: ResponseEntity
def protocol: HttpProtocol

Expand Down Expand Up @@ -79,6 +83,13 @@ sealed trait HttpMessage extends jm.HttpMessage {
def withHeaders(firstHeader: HttpHeader, otherHeaders: HttpHeader*): Self =
withHeaders(firstHeader +: otherHeaders.toList)

/** Returns a copy of this message with the list of headers set to the given ones. */
def withAttributes(attributes: immutable.Seq[AnyRef]): Self

/** Returns a copy of this message with the list of headers set to the given ones. */
def withAttributes(firstAttribute: AnyRef, otherAttributes: AnyRef*): Self =
withAttributes(firstAttribute +: otherAttributes.toList)

/**
* Returns a new message that contains all of the given default headers which didn't already
* exist (by case-insensitive header name) in this message.
Expand Down Expand Up @@ -117,6 +128,9 @@ sealed trait HttpMessage extends jm.HttpMessage {
/** Returns a copy of this message with the list of headers transformed by the given function */
def mapHeaders(f: immutable.Seq[HttpHeader] => immutable.Seq[HttpHeader]): Self = withHeaders(f(headers))

/** Returns a copy of this message with the list of attributes transformed by the given function */
def mapAttributes(f: immutable.Seq[AnyRef] => immutable.Seq[AnyRef]): Self = withAttributes(f(headers))

/**
* The content encoding as specified by the Content-Encoding header. If no Content-Encoding header is present the
* default value 'identity' is returned.
Expand All @@ -140,6 +154,10 @@ sealed trait HttpMessage extends jm.HttpMessage {
def headers[T <: jm.HttpHeader: ClassTag]: immutable.Seq[T] = headers.collect {
case h: T => h
}
/** Returns all the attributes of the given type **/
def attributes[T: ClassTag]: immutable.Seq[T] = attributes.collect {
case a: T => a
}

/**
* Returns true if this message is an:
Expand All @@ -149,6 +167,7 @@ sealed trait HttpMessage extends jm.HttpMessage {
def connectionCloseExpected: Boolean = HttpMessage.connectionCloseExpected(protocol, header[Connection])

def addHeader(header: jm.HttpHeader): Self = mapHeaders(_ :+ header.asInstanceOf[HttpHeader])
def addAttribute(attribute: AnyRef): Self = mapAttributes(_ :+ header.asInstanceOf[HttpHeader])

def addCredentials(credentials: jm.headers.HttpCredentials): Self = addHeader(jm.headers.Authorization.create(credentials))

Expand Down Expand Up @@ -196,6 +215,21 @@ sealed trait HttpMessage extends jm.HttpMessage {
import JavaMapping.Implicits._
withHeaders(headers.asScala.toVector.map(_.asScala))
}
/** Java API */
def getAttributes: JIterable[AnyRef] = (attributes: immutable.Seq[AnyRef]).asJava
/** Java API */
def getAttribute[T](attributeClass: Class[T]): Optional[T] = {
fastFindAttribute[T](attributeClass) match {
case OptionVal.Some(h) => Optional.of(h.asInstanceOf[T])
case _ => Optional.empty()
}
val attrs = attributes[T](ClassTag[T](attributeClass))
if (attrs.isEmpty) Optional.empty()
else Optional.of(attrs.head)
}
def getAttributes[T](attributeClass: Class[T]): JIterable[T] =
attributes[T](ClassTag[T](attributeClass)).asJava

/** Java API */
def toStrict(timeoutMillis: Long, ec: Executor, materializer: Materializer): CompletionStage[Self] = {
val ex = ExecutionContext.fromExecutor(ec)
Expand All @@ -206,6 +240,14 @@ sealed trait HttpMessage extends jm.HttpMessage {
val ex = ExecutionContext.fromExecutor(ec)
toStrict(timeoutMillis.millis, maxBytes)(ex, materializer).toJava
}
private def fastFindAttribute[T](clazz: Class[T]): OptionVal[T] = {
val it = attributes.iterator
while (it.hasNext) it.next() match {
case h if clazz.isInstance(h) => return OptionVal.Some[T](h.asInstanceOf[T])
case _ => // continue ...
}
OptionVal.none[T]
}
}

object HttpMessage {
Expand Down Expand Up @@ -265,11 +307,12 @@ object HttpMessage {
* The immutable model HTTP request model.
*/
final class HttpRequest(
val method: HttpMethod,
val uri: Uri,
val headers: immutable.Seq[HttpHeader],
val entity: RequestEntity,
val protocol: HttpProtocol)
val method: HttpMethod,
val uri: Uri,
val headers: immutable.Seq[HttpHeader],
val attributes: immutable.Seq[AnyRef],
val entity: RequestEntity,
val protocol: HttpProtocol)
extends jm.HttpRequest with HttpMessage {

HttpRequest.verifyUri(uri)
Expand All @@ -284,6 +327,10 @@ final class HttpRequest(
override def isRequest = true
override def isResponse = false

@deprecated("for backwards compatibility", "10.2.0")
def this(method: HttpMethod, uri: Uri, headers: immutable.Seq[HttpHeader], entity: RequestEntity, protocol: HttpProtocol) =
this(method, uri, headers, Nil, entity, protocol)

/**
* Resolve this request's URI according to the logic defined at
* http://tools.ietf.org/html/rfc7230#section-5.5
Expand Down Expand Up @@ -313,6 +360,8 @@ final class HttpRequest(

override def withHeaders(headers: immutable.Seq[HttpHeader]): HttpRequest =
if (headers eq this.headers) this else copy(headers = headers)
override def withAttributes(attributes: immutable.Seq[AnyRef]): HttpRequest =
if (attributes eq this.attributes) this else copy(attributes = attributes)

override def withHeadersAndEntity(headers: immutable.Seq[HttpHeader], entity: RequestEntity): HttpRequest = copy(headers = headers, entity = entity)
override def withEntity(entity: jm.RequestEntity): HttpRequest = copy(entity = entity.asInstanceOf[RequestEntity])
Expand All @@ -335,12 +384,20 @@ final class HttpRequest(

/* Manual Case Class things, to easen bin-compat */

@deprecated("Kept for binary compatibility", "10.2.0")
def copy(
method: HttpMethod,
uri: Uri,
headers: immutable.Seq[HttpHeader],
entity: RequestEntity,
protocol: HttpProtocol) = new HttpRequest(method, uri, headers, attributes, entity, protocol)
def copy(
method: HttpMethod = method,
uri: Uri = uri,
headers: immutable.Seq[HttpHeader] = headers,
entity: RequestEntity = entity,
protocol: HttpProtocol = protocol) = new HttpRequest(method, uri, headers, entity, protocol)
method: HttpMethod = method,
uri: Uri = uri,
headers: immutable.Seq[HttpHeader] = headers,
entity: RequestEntity = entity,
protocol: HttpProtocol = protocol,
attributes: immutable.Seq[AnyRef] = attributes) = new HttpRequest(method, uri, headers, attributes, entity, protocol)

override def hashCode(): Int = {
var result = HashCode.SEED
Expand Down Expand Up @@ -439,7 +496,7 @@ object HttpRequest {
uri: Uri = Uri./,
headers: immutable.Seq[HttpHeader] = Nil,
entity: RequestEntity = HttpEntity.Empty,
protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) = new HttpRequest(method, uri, headers, entity, protocol)
protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) = new HttpRequest(method, uri, headers, Nil, entity, protocol)

def unapply(any: HttpRequest) = new OptHttpRequest(any)
}
Expand All @@ -448,10 +505,11 @@ object HttpRequest {
* The immutable HTTP response model.
*/
final class HttpResponse(
val status: StatusCode,
val headers: immutable.Seq[HttpHeader],
val entity: ResponseEntity,
val protocol: HttpProtocol)
val status: StatusCode,
val headers: immutable.Seq[HttpHeader],
val attributes: immutable.Seq[AnyRef],
val entity: ResponseEntity,
val protocol: HttpProtocol)
extends jm.HttpResponse with HttpMessage {

require(entity.isKnownEmpty || status.allowsEntity, "Responses with this status code must have an empty entity")
Expand All @@ -465,8 +523,14 @@ final class HttpResponse(
override def isRequest = false
override def isResponse = true

@deprecated("for backwards compatibility", "10.2.0")
def this(status: StatusCode, headers: immutable.Seq[HttpHeader], entity: ResponseEntity, protocol: HttpProtocol) =
this(status, headers, Nil, entity, protocol)

override def withHeaders(headers: immutable.Seq[HttpHeader]): HttpResponse =
if (headers eq this.headers) this else copy(headers = headers)
override def withAttributes(attributes: immutable.Seq[AnyRef]): HttpResponse =
if (attributes eq this.attributes) this else copy(attributes = attributes)

override def withProtocol(protocol: akka.http.javadsl.model.HttpProtocol): akka.http.javadsl.model.HttpResponse = withProtocol(protocol.asInstanceOf[HttpProtocol])
def withProtocol(protocol: HttpProtocol): HttpResponse = copy(protocol = protocol)
Expand All @@ -485,11 +549,18 @@ final class HttpResponse(

/* Manual Case Class things, to ease bin-compat */

@deprecated("Kept for binary compatibility", "10.2.0")
def copy(
status: StatusCode,
headers: immutable.Seq[HttpHeader],
entity: ResponseEntity,
protocol: HttpProtocol) = new HttpResponse(status, headers, attributes, entity, protocol)
def copy(
status: StatusCode = status,
headers: immutable.Seq[HttpHeader] = headers,
entity: ResponseEntity = entity,
protocol: HttpProtocol = protocol) = new HttpResponse(status, headers, entity, protocol)
status: StatusCode = status,
headers: immutable.Seq[HttpHeader] = headers,
entity: ResponseEntity = entity,
protocol: HttpProtocol = protocol,
attributes: immutable.Seq[AnyRef] = attributes) = new HttpResponse(status, headers, attributes, entity, protocol)

override def equals(obj: scala.Any): Boolean = obj match {
case HttpResponse(_status, _headers, _entity, _protocol) =>
Expand Down Expand Up @@ -526,7 +597,7 @@ object HttpResponse {
status: StatusCode = StatusCodes.OK,
headers: immutable.Seq[HttpHeader] = Nil,
entity: ResponseEntity = HttpEntity.Empty,
protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) = new HttpResponse(status, headers, entity, protocol)
protocol: HttpProtocol = HttpProtocols.`HTTP/1.1`) = new HttpResponse(status, headers, Nil, entity, protocol)

def unapply(any: HttpResponse): OptHttpResponse = new OptHttpResponse(any)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ class HttpMessageSpec extends WordSpec with Matchers {
val request = HttpRequest().withHeaders(oneCookieHeader, anotherCookieHeader, hostHeader)
request.headers[`Set-Cookie`] should ===(Seq(oneCookieHeader, anotherCookieHeader))
}
"retrieve all attributes of a given class when calling attributes[...]" in {
val oneStringAttribute: String = "A string attribute!"
val anotherStringAttribute: String = "And another"
val intAttribute: Integer = 42
val request = HttpRequest().withAttributes(oneStringAttribute, anotherStringAttribute, intAttribute)
println(request.attributes)
request.attributes[String] should ===(Seq(oneStringAttribute, anotherStringAttribute))
}
}

}

0 comments on commit 024c144

Please sign in to comment.