Skip to content

Commit

Permalink
[SPARK-25718][SQL] Detect recursive reference in Avro schema and thro…
Browse files Browse the repository at this point in the history
…w exception

## What changes were proposed in this pull request?

Avro schema allows recursive reference, e.g. the schema for linked-list in https://avro.apache.org/docs/1.8.2/spec.html#schema_record
```
{
  "type": "record",
  "name": "LongList",
  "aliases": ["LinkedLongs"],                      // old name for this
  "fields" : [
    {"name": "value", "type": "long"},             // each element has a long
    {"name": "next", "type": ["null", "LongList"]} // optional next element
  ]
}
```

In current Spark SQL, it is impossible to convert the schema as `StructType` . Run `SchemaConverters.toSqlType(avroSchema)` and we will get stack overflow exception.

We should detect the recursive reference and throw exception for it.
## How was this patch tested?

New unit test case.

Closes #22709 from gengliangwang/avroRecursiveRef.

Authored-by: Gengliang Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
gengliangwang authored and cloud-fan committed Oct 13, 2018
1 parent 8812746 commit 2eaf058
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ object SchemaConverters {
* This function takes an avro schema and returns a sql schema.
*/
def toSqlType(avroSchema: Schema): SchemaType = {
toSqlTypeHelper(avroSchema, Set.empty)
}

def toSqlTypeHelper(avroSchema: Schema, existingRecordNames: Set[String]): SchemaType = {
avroSchema.getType match {
case INT => avroSchema.getLogicalType match {
case _: Date => SchemaType(DateType, nullable = false)
Expand All @@ -67,21 +71,28 @@ object SchemaConverters {
case ENUM => SchemaType(StringType, nullable = false)

case RECORD =>
if (existingRecordNames.contains(avroSchema.getFullName)) {
throw new IncompatibleSchemaException(s"""
|Found recursive reference in Avro schema, which can not be processed by Spark:
|${avroSchema.toString(true)}
""".stripMargin)
}
val newRecordNames = existingRecordNames + avroSchema.getFullName
val fields = avroSchema.getFields.asScala.map { f =>
val schemaType = toSqlType(f.schema())
val schemaType = toSqlTypeHelper(f.schema(), newRecordNames)
StructField(f.name, schemaType.dataType, schemaType.nullable)
}

SchemaType(StructType(fields), nullable = false)

case ARRAY =>
val schemaType = toSqlType(avroSchema.getElementType)
val schemaType = toSqlTypeHelper(avroSchema.getElementType, existingRecordNames)
SchemaType(
ArrayType(schemaType.dataType, containsNull = schemaType.nullable),
nullable = false)

case MAP =>
val schemaType = toSqlType(avroSchema.getValueType)
val schemaType = toSqlTypeHelper(avroSchema.getValueType, existingRecordNames)
SchemaType(
MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable),
nullable = false)
Expand All @@ -91,13 +102,14 @@ object SchemaConverters {
// In case of a union with null, eliminate it and make a recursive call
val remainingUnionTypes = avroSchema.getTypes.asScala.filterNot(_.getType == NULL)
if (remainingUnionTypes.size == 1) {
toSqlType(remainingUnionTypes.head).copy(nullable = true)
toSqlTypeHelper(remainingUnionTypes.head, existingRecordNames).copy(nullable = true)
} else {
toSqlType(Schema.createUnion(remainingUnionTypes.asJava)).copy(nullable = true)
toSqlTypeHelper(Schema.createUnion(remainingUnionTypes.asJava), existingRecordNames)
.copy(nullable = true)
}
} else avroSchema.getTypes.asScala.map(_.getType) match {
case Seq(t1) =>
toSqlType(avroSchema.getTypes.get(0))
toSqlTypeHelper(avroSchema.getTypes.get(0), existingRecordNames)
case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) =>
SchemaType(LongType, nullable = false)
case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) =>
Expand All @@ -107,7 +119,7 @@ object SchemaConverters {
// This is consistent with the behavior when converting between Avro and Parquet.
val fields = avroSchema.getTypes.asScala.zipWithIndex.map {
case (s, i) =>
val schemaType = toSqlType(s)
val schemaType = toSqlTypeHelper(s, existingRecordNames)
// All fields are nullable because only one of them is set at a time
StructField(s"member$i", schemaType.dataType, nullable = true)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1309,4 +1309,69 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
checkCodec(df, path, "xz")
}
}

private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = {
val message = intercept[IncompatibleSchemaException] {
SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema))
}.getMessage

assert(message.contains("Found recursive reference in Avro schema"))
}

test("Detect recursive loop") {
checkSchemaWithRecursiveLoop("""
|{
| "type": "record",
| "name": "LongList",
| "fields" : [
| {"name": "value", "type": "long"}, // each element has a long
| {"name": "next", "type": ["null", "LongList"]} // optional next element
| ]
|}
""".stripMargin)

checkSchemaWithRecursiveLoop("""
|{
| "type": "record",
| "name": "LongList",
| "fields": [
| {
| "name": "value",
| "type": {
| "type": "record",
| "name": "foo",
| "fields": [
| {
| "name": "parent",
| "type": "LongList"
| }
| ]
| }
| }
| ]
|}
""".stripMargin)

checkSchemaWithRecursiveLoop("""
|{
| "type": "record",
| "name": "LongList",
| "fields" : [
| {"name": "value", "type": "long"},
| {"name": "array", "type": {"type": "array", "items": "LongList"}}
| ]
|}
""".stripMargin)

checkSchemaWithRecursiveLoop("""
|{
| "type": "record",
| "name": "LongList",
| "fields" : [
| {"name": "value", "type": "long"},
| {"name": "map", "type": {"type": "map", "values": "LongList"}}
| ]
|}
""".stripMargin)
}
}

0 comments on commit 2eaf058

Please sign in to comment.