From 2eaf0587883ac3c65e77d01ffbb39f64c6152f87 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 13 Oct 2018 14:49:38 +0800 Subject: [PATCH] [SPARK-25718][SQL] Detect recursive reference in Avro schema and throw exception ## What changes were proposed in this pull request? Avro schemas allow recursive references, e.g. the linked-list schema in https://avro.apache.org/docs/1.8.2/spec.html#schema_record ``` { "type": "record", "name": "LongList", "aliases": ["LinkedLongs"], // old name for this "fields" : [ {"name": "value", "type": "long"}, // each element has a long {"name": "next", "type": ["null", "LongList"]} // optional next element ] } ``` In current Spark SQL, it is impossible to convert such a schema to a `StructType`: running `SchemaConverters.toSqlType(avroSchema)` fails with a stack overflow exception. We should detect the recursive reference and throw an exception for it. ## How was this patch tested? New unit test case. Closes #22709 from gengliangwang/avroRecursiveRef. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../spark/sql/avro/SchemaConverters.scala | 26 ++++++-- .../org/apache/spark/sql/avro/AvroSuite.scala | 65 +++++++++++++++++++ 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index bd1576587d7fa..64127af73881b 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -43,6 +43,10 @@ object SchemaConverters { * This function takes an avro schema and returns a sql schema. 
*/ def toSqlType(avroSchema: Schema): SchemaType = { + toSqlTypeHelper(avroSchema, Set.empty) + } + + def toSqlTypeHelper(avroSchema: Schema, existingRecordNames: Set[String]): SchemaType = { avroSchema.getType match { case INT => avroSchema.getLogicalType match { case _: Date => SchemaType(DateType, nullable = false) @@ -67,21 +71,28 @@ object SchemaConverters { case ENUM => SchemaType(StringType, nullable = false) case RECORD => + if (existingRecordNames.contains(avroSchema.getFullName)) { + throw new IncompatibleSchemaException(s""" + |Found recursive reference in Avro schema, which can not be processed by Spark: + |${avroSchema.toString(true)} + """.stripMargin) + } + val newRecordNames = existingRecordNames + avroSchema.getFullName val fields = avroSchema.getFields.asScala.map { f => - val schemaType = toSqlType(f.schema()) + val schemaType = toSqlTypeHelper(f.schema(), newRecordNames) StructField(f.name, schemaType.dataType, schemaType.nullable) } SchemaType(StructType(fields), nullable = false) case ARRAY => - val schemaType = toSqlType(avroSchema.getElementType) + val schemaType = toSqlTypeHelper(avroSchema.getElementType, existingRecordNames) SchemaType( ArrayType(schemaType.dataType, containsNull = schemaType.nullable), nullable = false) case MAP => - val schemaType = toSqlType(avroSchema.getValueType) + val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames) SchemaType( MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), nullable = false) @@ -91,13 +102,14 @@ object SchemaConverters { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL) if (remainingUnionTypes.size == 1) { - toSqlType(remainingUnionTypes.head).copy(nullable = true) + toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true) } else { - toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = 
true) + toSqlTypeHelper(Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames) + .copy(nullable = true) } } else avroSchema.getTypes.asScala.map(_.getType) match { case Seq(t1) => - toSqlType(avroSchema.getTypes.get(0)) + toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames) case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => SchemaType(LongType, nullable = false) case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => @@ -107,7 +119,7 @@ object SchemaConverters { // This is consistent with the behavior when converting between Avro and Parquet. val fields = avroSchema.getTypes.asScala.zipWithIndex.map { case (s, i) => - val schemaType = toSqlType(s) + val schemaType = toSqlTypeHelper(s, existingRecordNames) // All fields are nullable because only one of them is set at a time StructField(s"member$i", schemaType.dataType, nullable = true) } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 1e08f7b50b115..4fea2cb969446 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1309,4 +1309,69 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { checkCodec(df, path, "xz") } } + + private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { + val message = intercept[IncompatibleSchemaException] { + SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema)) + }.getMessage + + assert(message.contains("Found recursive reference in Avro schema")) + } + + test("Detect recursive loop") { + checkSchemaWithRecursiveLoop(""" + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, // each element has a long + | {"name": "next", "type": ["null", "LongList"]} // optional next element + | ] + |} + """.stripMargin) + + 
checkSchemaWithRecursiveLoop(""" + |{ + | "type": "record", + | "name": "LongList", + | "fields": [ + | { + | "name": "value", + | "type": { + | "type": "record", + | "name": "foo", + | "fields": [ + | { + | "name": "parent", + | "type": "LongList" + | } + | ] + | } + | } + | ] + |} + """.stripMargin) + + checkSchemaWithRecursiveLoop(""" + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "array", "type": {"type": "array", "items": "LongList"}} + | ] + |} + """.stripMargin) + + checkSchemaWithRecursiveLoop(""" + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "map", "type": {"type": "map", "values": "LongList"}} + | ] + |} + """.stripMargin) + } }