Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle edge cases with NaN/Infinity #1624

Merged
merged 5 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import smithy4s.schema.Schema.unit
trait DummyServiceGen[F[_, _, _, _, _]] {
self =>

def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, slm: Option[Map[String, String]] = None): F[Queries, Nothing, Unit, Nothing, Nothing]
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, dbl: Option[Double] = None, slm: Option[Map[String, String]] = None): F[Queries, Nothing, Unit, Nothing, Nothing]
def dummyHostPrefix(label1: String, label2: String, label3: HostLabelEnum): F[HostLabelInput, Nothing, Unit, Nothing, Nothing]
def dummyPath(str: String, int: Int, ts1: Timestamp, ts2: Timestamp, ts3: Timestamp, ts4: Timestamp, b: Boolean, ie: Numbers): F[PathParams, Nothing, Unit, Nothing, Nothing]

Expand Down Expand Up @@ -69,12 +69,12 @@ sealed trait DummyServiceOperation[Input, Err, Output, StreamedInput, StreamedOu
object DummyServiceOperation {

object reified extends DummyServiceGen[DummyServiceOperation] {
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, slm: Option[Map[String, String]] = None): Dummy = Dummy(Queries(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, slm))
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, dbl: Option[Double] = None, slm: Option[Map[String, String]] = None): Dummy = Dummy(Queries(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, dbl, slm))
def dummyHostPrefix(label1: String, label2: String, label3: HostLabelEnum): DummyHostPrefix = DummyHostPrefix(HostLabelInput(label1, label2, label3))
def dummyPath(str: String, int: Int, ts1: Timestamp, ts2: Timestamp, ts3: Timestamp, ts4: Timestamp, b: Boolean, ie: Numbers): DummyPath = DummyPath(PathParams(str, int, ts1, ts2, ts3, ts4, b, ie))
}
class Transformed[P[_, _, _, _, _], P1[_ ,_ ,_ ,_ ,_]](alg: DummyServiceGen[P], f: PolyFunction5[P, P1]) extends DummyServiceGen[P1] {
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, slm: Option[Map[String, String]] = None): P1[Queries, Nothing, Unit, Nothing, Nothing] = f[Queries, Nothing, Unit, Nothing, Nothing](alg.dummy(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, slm))
def dummy(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, dbl: Option[Double] = None, slm: Option[Map[String, String]] = None): P1[Queries, Nothing, Unit, Nothing, Nothing] = f[Queries, Nothing, Unit, Nothing, Nothing](alg.dummy(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, dbl, slm))
def dummyHostPrefix(label1: String, label2: String, label3: HostLabelEnum): P1[HostLabelInput, Nothing, Unit, Nothing, Nothing] = f[HostLabelInput, Nothing, Unit, Nothing, Nothing](alg.dummyHostPrefix(label1, label2, label3))
def dummyPath(str: String, int: Int, ts1: Timestamp, ts2: Timestamp, ts3: Timestamp, ts4: Timestamp, b: Boolean, ie: Numbers): P1[PathParams, Nothing, Unit, Nothing, Nothing] = f[PathParams, Nothing, Unit, Nothing, Nothing](alg.dummyPath(str, int, ts1, ts2, ts3, ts4, b, ie))
}
Expand All @@ -83,7 +83,7 @@ object DummyServiceOperation {
def apply[I, E, O, SI, SO](op: DummyServiceOperation[I, E, O, SI, SO]): P[I, E, O, SI, SO] = op.run(impl)
}
final case class Dummy(input: Queries) extends DummyServiceOperation[Queries, Nothing, Unit, Nothing, Nothing] {
def run[F[_, _, _, _, _]](impl: DummyServiceGen[F]): F[Queries, Nothing, Unit, Nothing, Nothing] = impl.dummy(input.str, input.int, input.ts1, input.ts2, input.ts3, input.ts4, input.b, input.sl, input.ie, input.on, input.ons, input.slm)
def run[F[_, _, _, _, _]](impl: DummyServiceGen[F]): F[Queries, Nothing, Unit, Nothing, Nothing] = impl.dummy(input.str, input.int, input.ts1, input.ts2, input.ts3, input.ts4, input.b, input.sl, input.ie, input.on, input.ons, input.dbl, input.slm)
def ordinal: Int = 0
def endpoint: smithy4s.Endpoint[DummyServiceOperation,Queries, Nothing, Unit, Nothing, Nothing] = Dummy
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@ import smithy4s.ShapeId
import smithy4s.ShapeTag
import smithy4s.Timestamp
import smithy4s.schema.Schema.boolean
import smithy4s.schema.Schema.double
import smithy4s.schema.Schema.int
import smithy4s.schema.Schema.string
import smithy4s.schema.Schema.struct
import smithy4s.schema.Schema.timestamp

final case class Queries(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, slm: Option[Map[String, String]] = None)
final case class Queries(str: Option[String] = None, int: Option[Int] = None, ts1: Option[Timestamp] = None, ts2: Option[Timestamp] = None, ts3: Option[Timestamp] = None, ts4: Option[Timestamp] = None, b: Option[Boolean] = None, sl: Option[List[String]] = None, ie: Option[Numbers] = None, on: Option[OpenNums] = None, ons: Option[OpenNumsStr] = None, dbl: Option[Double] = None, slm: Option[Map[String, String]] = None)

object Queries extends ShapeTag.Companion[Queries] {
val id: ShapeId = ShapeId("smithy4s.example", "Queries")

val hints: Hints = Hints.empty

// constructor using the original order from the spec
private def make(str: Option[String], int: Option[Int], ts1: Option[Timestamp], ts2: Option[Timestamp], ts3: Option[Timestamp], ts4: Option[Timestamp], b: Option[Boolean], sl: Option[List[String]], ie: Option[Numbers], on: Option[OpenNums], ons: Option[OpenNumsStr], slm: Option[Map[String, String]]): Queries = Queries(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, slm)
private def make(str: Option[String], int: Option[Int], ts1: Option[Timestamp], ts2: Option[Timestamp], ts3: Option[Timestamp], ts4: Option[Timestamp], b: Option[Boolean], sl: Option[List[String]], ie: Option[Numbers], on: Option[OpenNums], ons: Option[OpenNumsStr], dbl: Option[Double], slm: Option[Map[String, String]]): Queries = Queries(str, int, ts1, ts2, ts3, ts4, b, sl, ie, on, ons, dbl, slm)

implicit val schema: Schema[Queries] = struct(
string.optional[Queries]("str", _.str).addHints(smithy.api.HttpQuery("str")),
Expand All @@ -33,6 +34,7 @@ object Queries extends ShapeTag.Companion[Queries] {
Numbers.schema.optional[Queries]("ie", _.ie).addHints(smithy.api.HttpQuery("nums")),
OpenNums.schema.optional[Queries]("on", _.on).addHints(smithy.api.HttpQuery("openNums")),
OpenNumsStr.schema.optional[Queries]("ons", _.ons).addHints(smithy.api.HttpQuery("openNumsStr")),
double.validated(smithy.api.Range(min = Some(scala.math.BigDecimal(0.0)), max = Some(scala.math.BigDecimal(100.0)))).optional[Queries]("dbl", _.dbl).addHints(smithy.api.HttpQuery("dbl")),
StringMap.underlyingSchema.optional[Queries]("slm", _.slm).addHints(smithy.api.HttpQueryParams()),
)(make).withId(id).addHints(hints)
}
22 changes: 22 additions & 0 deletions modules/bootstrapped/test/src/smithy4s/DocumentSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import munit._
import smithy4s.example.DefaultNullsOperationOutput
import alloy.Untagged
import smithy4s.example.TimestampOperationInput
import scala.util.Try

class DocumentSpec() extends FunSuite {

Expand Down Expand Up @@ -370,6 +371,27 @@ class DocumentSpec() extends FunSuite {
expect.same(roundTripped, Right(mapTest))
}

test("encoding NaN") {
// The Document type cannot hold a `NaN` value since it uses BigDecimal to hold numeric values
// this test exists to show this. For the same reason, a test on decoding from `NaN` is not necessary
// or possible.
implicit val schema: Schema[Double] =
double.validated(smithy.api.Range(None, Some(BigDecimal(3))))

val in = Double.NaN
val error = Try(Document.encode(in)).failed.get
val expectedMessage =
if (weaver.Platform.isJS || weaver.Platform.isNative)
"For input string: \"NaN\""
else
"Character N is neither a decimal digit number, decimal point, nor \"e\" notation exponential mark."

expect.same(
error.getMessage,
expectedMessage
)
}

test(
"optional fields for structs should decode Document.DNull"
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,28 @@ class MetadataSpec() extends FunSuite {
.left
.map(_.getMessage())
expect.same(encoded, expectedEncoding)
expect(result == Right(finished))
expect.same(result, Right(finished))
}

def checkQueryRoundTripError[A](
initial: A,
expectedEncoding: Metadata,
errorMessage: String,
allowNaN: Boolean
)(implicit
s: Schema[A],
loc: Location
): Unit = {
val encoded = Metadata.encode(initial)
val decoder =
if (allowNaN) Metadata.AwsDecoder.fromSchema(s)
else Metadata.Decoder.fromSchema(s)
val result = decoder
.decode(encoded)
.left
.map(_.getMessage())
expect.same(encoded, expectedEncoding)
expect.same(result, Left(errorMessage))
}

def checkRoundTripDefault[A](expectedDecoded: A)(implicit
Expand Down Expand Up @@ -123,6 +144,27 @@ class MetadataSpec() extends FunSuite {
checkQueryRoundTrip(queries, expected, finished)
}

// In this test the Metadata Decoder will allow NaN by creating a `Double.NaN` value.
// The Range RefinementProvider will reject this since `NaN` is not a valid `BigDecimal`
// which it uses
test("Double NaN query parameter - allow NaN in decoder") {
val queries = Queries(dbl = Some(Double.NaN))
val expected = Metadata(query = Map("dbl" -> List("NaN")))
val errorMessage =
"Field dbl, found in Query parameter dbl, failed constraint checks with message: Numeric values must not be NaN or pos/neg infinity. Found NaN"
checkQueryRoundTripError(queries, expected, errorMessage, allowNaN = true)
}

// This test is where the Metadata Decoder will reject NaN itself
// As such the RefinementProvider for Range will not be called in this test
test("Double NaN query parameter - disallow NaN in decoder") {
val queries = Queries(dbl = Some(Double.NaN))
val expected = Metadata(query = Map("dbl" -> List("NaN")))
val errorMessage =
"NaN or pos/neg infinity are not allowed for inputs of type Double"
checkQueryRoundTripError(queries, expected, errorMessage, allowNaN = false)
}

test("String query parameter with default") {
val expectedDecoded = QueriesWithDefaults(dflt = "test")
checkRoundTripDefault(expectedDecoded)
Expand Down
49 changes: 28 additions & 21 deletions modules/core/src/smithy4s/RefinementProvider.scala
Original file line number Diff line number Diff line change
Expand Up @@ -161,27 +161,34 @@ object RefinementProvider extends LowPriorityImplicits {
val N = implicitly[Numeric[N]]

(a: A) =>
val value = BigDecimal(N.toDouble(getValue(a)))
(range.min, range.max) match {
case (Some(min), Some(max)) =>
if (value >= min && value <= max) Right(())
else
Left(
s"Input must be >= $min and <= $max, but was $value"
)
case (None, Some(max)) =>
if (value <= max) Right(())
else
Left(
s"Input must be <= $max, but was $value"
)
case (Some(min), None) =>
if (value >= min) Right(())
else
Left(
s"Input must be >= $min, but was $value"
)
case (None, None) => Right(())
val doubleValue = N.toDouble(getValue(a))
if (doubleValue.isNaN || doubleValue.isInfinite) {
Left(
s"Numeric values must not be NaN or pos/neg infinity. Found $doubleValue"
)
} else {
val value = BigDecimal.apply(d = doubleValue)
(range.min, range.max) match {
case (Some(min), Some(max)) =>
if (value >= min && value <= max) Right(())
else
Left(
s"Input must be >= $min and <= $max, but was $value"
)
case (None, Some(max)) =>
if (value <= max) Right(())
else
Left(
s"Input must be <= $max, but was $value"
)
case (Some(min), None) =>
if (value >= min) Right(())
else
Left(
s"Input must be >= $min, but was $value"
)
case (None, None) => Right(())
}
}
}
}
Expand Down
23 changes: 18 additions & 5 deletions modules/core/src/smithy4s/http/Metadata.scala
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,24 @@ object Metadata {
implicit def decoderFromSchema[A: Schema]: Decoder[A] =
Decoder.derivedImplicitInstance

object Decoder extends CachedDecoderCompilerImpl(awsHeaderEncoding = false) {
object Decoder
extends CachedDecoderCompilerImpl(
awsHeaderEncoding = false,
allowNaNAndInfiniteValues = false
) {
type Compiler = CachedSchemaCompiler[Decoder]
}

private[smithy4s] object AwsDecoder
extends CachedDecoderCompilerImpl(awsHeaderEncoding = true)
extends CachedDecoderCompilerImpl(
awsHeaderEncoding = true,
allowNaNAndInfiniteValues = true
)

private[http] class CachedDecoderCompilerImpl(awsHeaderEncoding: Boolean)
extends CachedSchemaCompiler.DerivingImpl[Decoder] {
private[http] class CachedDecoderCompilerImpl(
awsHeaderEncoding: Boolean,
allowNaNAndInfiniteValues: Boolean
) extends CachedSchemaCompiler.DerivingImpl[Decoder] {
type Aux[A] = internals.MetaDecode[A]

def apply[A](implicit instance: Decoder[A]): Decoder[A] =
Expand All @@ -190,7 +199,11 @@ object Metadata {
cache: CompilationCache[internals.MetaDecode]
): Decoder[A] = {
val metaDecode =
new SchemaVisitorMetadataReader(cache, awsHeaderEncoding)(schema)
new SchemaVisitorMetadataReader(
cache,
awsHeaderEncoding,
allowNaNAndInfiniteValues
)(schema)
metaDecode match {
case internals.MetaDecode.StructureMetaDecode(decodeFunction) =>
decodeFunction(_: Metadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ import java.util.Base64
* contains values such as path-parameters, query-parameters, headers, and status code.
*
* @param awsHeaderEncoding defines whether the AWS encoding of headers should be expected.
* @param allowNaNAndInfiniteValues defines whether or not Double and Float values of 'NaN'
* positive/negative infinity should be accepted.
*/
private[http] class SchemaVisitorMetadataReader(
val cache: CompilationCache[MetaDecode],
awsHeaderEncoding: Boolean
awsHeaderEncoding: Boolean,
allowNaNAndInfiniteValues: Boolean
) extends SchemaVisitor.Cached[MetaDecode]
with ScalaCompat { self =>

Expand All @@ -50,6 +53,38 @@ private[http] class SchemaVisitorMetadataReader(
tag: Primitive[P]
): MetaDecode[P] = {
val desc = SchemaDescription.primitive(shapeId, hints, tag)

tag match {
case Primitive.PDouble =>
val decode: MetaDecode[Double] =
primitiveHandler(shapeId, hints, tag, desc)
decode.map(d =>
if (!allowNaNAndInfiniteValues && (d.isNaN || d.isInfinite))
throw MetadataError.ImpossibleDecoding(
s"NaN or pos/neg infinity are not allowed for inputs of type $desc"
)
else d
)
case Primitive.PFloat =>
val decode: MetaDecode[Float] =
primitiveHandler(shapeId, hints, tag, desc)
decode.map(f =>
if (!allowNaNAndInfiniteValues && (f.isNaN || f.isInfinite))
throw MetadataError.ImpossibleDecoding(
s"NaN or pos/neg infinity are not allowed for inputs of type $desc"
)
else f
)
case _ => primitiveHandler(shapeId, hints, tag, desc)
}
}

private def primitiveHandler[P](
shapeId: ShapeId,
hints: Hints,
tag: Primitive[P],
desc: String
): MetaDecode[P] = {
val hasMedia = hints.has(smithy.api.MediaType)
Primitive.stringParser(tag, hints) match {
case Some(parse) if hasMedia =>
Expand Down
11 changes: 11 additions & 0 deletions modules/json/test/src/smithy4s/json/JsonSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ class JsonSpec() extends FunSuite {
assertEquals(roundTripped, Right(foo))
}

test("Json read - NaN") {
implicit val schemaDouble: Schema[Double] =
double.validated(smithy.api.Range(None, Some(BigDecimal(3))))
val expectedJson = """"NaN""""
val roundTripped = Json.read[Double](Blob(expectedJson))

assert(
roundTripped.left.toOption.get.message.startsWith("illegal number")
)
}

test("Json document read/write") {
val foo =
Document.obj("a" -> Document.fromInt(1), "b" -> Document.fromInt(2))
Expand Down
3 changes: 3 additions & 0 deletions sampleSpecs/metadata.smithy
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ structure Queries {
on: OpenNums
@httpQuery("openNumsStr")
ons: OpenNumsStr
@httpQuery("dbl")
@range(min: 0, max: 100)
dbl: Double
@httpQueryParams
slm: StringMap
}
Expand Down
Loading