diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index c947d948b4cf3..0740334724e82 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -85,7 +85,13 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM |""".stripMargin) .collect() } - assert(ex.getErrorClass != null) + assert( + ex.getErrorClass === + "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER") + assert( + ex.getMessageParameters.asScala == Map( + "datetime" -> "'02-29'", + "config" -> "\"spark.sql.legacy.timeParserPolicy\"")) if (enrichErrorEnabled) { assert(ex.getCause.isInstanceOf[DateTimeException]) } else { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 075526e7521d9..cc47924de3b04 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -372,10 +372,14 @@ private[client] object GrpcExceptionConverter { .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava) if (errorClass != null) { + val messageParameters = JsonMethods + .parse(info.getMetadataOrDefault("messageParameters", "{}")) + .extract[Map[String, String]] builder.setSparkThrowable( FetchErrorDetailsResponse.SparkThrowable .newBuilder() .setErrorClass(errorClass) + .putAllMessageParameters(messageParameters.asJava) .build()) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index ab4f06d508a06..39bf1a630af62 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -256,4 +256,13 @@ object Connect { .version("4.0.0") .booleanConf .createWithDefault(true) + + val CONNECT_GRPC_MAX_METADATA_SIZE = + buildStaticConf("spark.connect.grpc.maxMetadataSize") + .doc( + "Sets the maximum size of metadata fields. For instance, it restricts metadata fields " + + "in `ErrorInfo`.") + .version("4.0.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(1024) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 703b11c0c736b..f489551a1dbab 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -172,6 +172,7 @@ private[connect] object ErrorUtils extends Logging { "classes", JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName)))) + val maxMetadataSize = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_MAX_METADATA_SIZE) // Add the SQL State and Error Class to the response metadata of the ErrorInfoObject. st match { case e: SparkThrowable => @@ -181,7 +182,12 @@ private[connect] object ErrorUtils extends Logging { } val errorClass = e.getErrorClass if (errorClass != null && errorClass.nonEmpty) { - errorInfo.putMetadata("errorClass", errorClass) + val messageParameters = JsonMethods.compact( + JsonMethods.render(map2jvalue(e.getMessageParameters.asScala.toMap))) + if (messageParameters.length <= maxMetadataSize) { + errorInfo.putMetadata("errorClass", errorClass) + errorInfo.putMetadata("messageParameters", messageParameters) + } } case _ => } @@ -200,8 +206,10 @@ private[connect] object ErrorUtils extends Logging { val withStackTrace = if (sessionHolderOpt.exists( _.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty)) { - val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE) - errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize)) + val maxSize = Math.min( + SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE), + maxMetadataSize) + errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize.toInt)) } else { errorInfo } diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 32cd4ed624958..7275f40b39a93 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3452,7 +3452,7 @@ def test_error_stack_trace(self): self.spark.stop() spark = ( PySparkSession.builder.config(conf=self.conf()) - .config("spark.connect.jvmStacktrace.maxSize", 128) + .config("spark.connect.grpc.maxMetadataSize", 128) .remote("local[4]") .getOrCreate() )