Skip to content

Commit

Permalink
Cosmos Spark - switching to reflection instead of bridge-approach to …
Browse files Browse the repository at this point in the history
…avoid SecurityException when installing connector by directly copying into /databricks/jars folder (Azure#37934)

* Fixing high number of PKRangeFeed calls when using BulkExecution without SparkConnector

* Adding unit test coverage

* Update CHANGELOG.md

* Switching to reflection for spark internals access (custom metrics) for Spark 3.4

* Switching to reflection for custom metrics in Spark 3.1-3.3

* Changelog

* Test fixes
  • Loading branch information
FabianMeiswinkel authored Dec 6, 2023
1 parent 753adcb commit 41243ba
Show file tree
Hide file tree
Showing 16 changed files with 516 additions and 117 deletions.
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-1_2-12/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### Breaking Changes

#### Bugs Fixed
* Fixed `SecurityException` with message `java.lang.SecurityException: class "org.apache.spark.SparkInternalsBridge$"'s signer information does not match signer information of other classes in the same package` when deploying the Spark connector in Databricks by copying it directly to `/databricks/jars` instead of going through the usual deployment APIs or UI-deployment. To fix this issue, the connector now uses reflection instead of the `bridge-approach` to access the internal API necessary to publish custom metrics. See [PR 37934](https://github.com/Azure/azure-sdk-for-java/pull/37934)

#### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package com.azure.cosmos.spark

import com.azure.cosmos.CosmosDiagnosticsContext
import com.azure.cosmos.implementation.ImplementationBridgeHelpers
import org.apache.spark.SparkInternalsBridge
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.WriterCommitMessage
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.implementation.guava25.base.MoreObjects.firstNonNull
import com.azure.cosmos.implementation.guava25.base.Strings.emptyToNull
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import org.apache.spark.TaskContext

import java.lang.reflect.Method
import java.util.Locale
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

//scalastyle:off multiple.string.literals
/**
 * Publishes the connector's `recordsWritten` / `bytesWritten` values into the output metrics of
 * the currently running Spark task.
 *
 * The `outputMetrics`, `setBytesWritten` and `setRecordsWritten` members are looked up and invoked
 * via reflection. Reflection can be switched off with the system property
 * `COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED` or the environment variable
 * `COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED`; it also switches itself off after the first reflective
 * failure, so publishing metrics is strictly best-effort and never fails the caller.
 */
object SparkInternalsBridge extends BasicLoggingTrait {
  private val SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY = "COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED"
  private val SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE = "COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED"

  private val DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED = true

  // Not referenced inside this object; presumably the row interval at which callers refresh the
  // task metrics - confirm against the call sites.
  val NUM_ROWS_PER_UPDATE = 100

  // Caches of the reflectively resolved methods. Concurrent first calls may each resolve the
  // method and the last write wins; every resolution yields an equivalent Method.
  private val outputMetricsMethod: AtomicReference[Method] = new AtomicReference[Method]()
  private val setBytesWrittenMethod: AtomicReference[Method] = new AtomicReference[Method]()
  private val setRecordsWrittenMethod: AtomicReference[Method] = new AtomicReference[Method]()

  /**
   * Reads the opt-out switch. Precedence: system property, then non-empty environment variable,
   * then the default (`true`).
   */
  private def getSparkReflectionAccessAllowed: Boolean = {
    val allowedText = System.getProperty(
      SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY,
      firstNonNull(
        emptyToNull(System.getenv.get(SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE)),
        String.valueOf(DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED)))

    try {
      // Boolean.valueOf maps anything other than (case-insensitive) "true" to false rather than
      // rejecting it, so the catch below is purely defensive.
      java.lang.Boolean.valueOf(allowedText.toUpperCase(Locale.ROOT))
    }
    catch {
      case e: Exception =>
        logError(s"Parsing spark reflection access allowed $allowedText failed. Using the default $DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED.", e)
        DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED
    }
  }

  // Starts from the configured value and is flipped to false after the first reflective failure.
  private final lazy val reflectionAccessAllowed = new AtomicBoolean(getSparkReflectionAccessAllowed)

  /**
   * Returns the cached method from `cache`, or resolves the public method `name` on the runtime
   * class of `target`, makes it accessible, stores it in `cache` and returns it.
   */
  private def resolveMethod(
    cache: AtomicReference[Method],
    target: Object,
    name: String,
    parameterTypes: Class[_]*): Method = {

    Option(cache.get).getOrElse {
      val resolved = target.getClass.getMethod(name, parameterTypes: _*)
      resolved.setAccessible(true)
      cache.set(resolved)
      resolved
    }
  }

  /**
   * Invokes the single-`long`-argument setter `setterName` on `outputMetrics`. Any failure is
   * logged and disables reflection for all future calls instead of propagating.
   */
  private def invokeLongSetter(
    cache: AtomicReference[Method],
    setterName: String,
    outputMetrics: Object,
    metricValue: Object): Unit = {

    try {
      resolveMethod(cache, outputMetrics, setterName, java.lang.Long.TYPE)
        .invoke(outputMetrics, metricValue)
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke $setterName via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
    }
  }

  /**
   * Reflectively calls `outputMetrics` on the task metrics of `taskCtx`.
   *
   * @return the output metrics object, or `None` when the call returned null or failed (a
   *         failure also disables reflection for all future calls)
   */
  private def getOutputMetrics(taskCtx: TaskContext): Option[Object] = {
    try {
      val taskMetrics: Object = taskCtx.taskMetrics()
      val method = resolveMethod(outputMetricsMethod, taskMetrics, "outputMetrics")

      Option(method.invoke(taskMetrics))
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke getOutputMetrics via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
        None
    }
  }

  private def setBytesWritten(outputMetrics: Object, metricValue: Object): Unit =
    invokeLongSetter(setBytesWrittenMethod, "setBytesWritten", outputMetrics, metricValue)

  private def setRecordsWritten(outputMetrics: Object, metricValue: Object): Unit =
    invokeLongSetter(setRecordsWrittenMethod, "setRecordsWritten", outputMetrics, metricValue)

  /**
   * Pushes the given snapshots into the output metrics of the current task.
   *
   * Does nothing when reflection is disabled, when no `TaskContext` is active on the calling
   * thread, or when the output metrics cannot be obtained.
   *
   * @param recordsWrittenSnapshot value passed to `setRecordsWritten`
   * @param bytesWrittenSnapshot   value passed to `setBytesWritten`
   */
  def updateInternalTaskMetrics(recordsWrittenSnapshot: Long, bytesWrittenSnapshot: Long): Unit = {
    if (reflectionAccessAllowed.get) {
      // Option(...) turns a null TaskContext into None; the combinators make that case a no-op
      // (a match with only a `Some` arm would throw scala.MatchError here).
      Option(TaskContext.get())
        .flatMap(getOutputMetrics)
        .foreach { outputMetrics =>
          setRecordsWritten(outputMetrics, Long.box(recordsWrittenSnapshot))
          setBytesWritten(outputMetrics, Long.box(bytesWrittenSnapshot))
        }
    }
  }
}
//scalastyle:on multiple.string.literals

This file was deleted.

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-2_2-12/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### Breaking Changes

#### Bugs Fixed
* Fixed `SecurityException` with message `java.lang.SecurityException: class "org.apache.spark.SparkInternalsBridge$"'s signer information does not match signer information of other classes in the same package` when deploying the Spark connector in Databricks by copying it directly to `/databricks/jars` instead of going through the usual deployment APIs or UI-deployment. To fix this issue, the connector now uses reflection instead of the `bridge-approach` to access the internal API necessary to publish custom metrics. See [PR 37934](https://github.com/Azure/azure-sdk-for-java/pull/37934)

#### Other Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package com.azure.cosmos.spark

import com.azure.cosmos.CosmosDiagnosticsContext
import com.azure.cosmos.implementation.ImplementationBridgeHelpers
import org.apache.spark.SparkInternalsBridge
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.metric.CustomTaskMetric
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.implementation.guava25.base.MoreObjects.firstNonNull
import com.azure.cosmos.implementation.guava25.base.Strings.emptyToNull
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import org.apache.spark.TaskContext
import org.apache.spark.sql.connector.metric.CustomTaskMetric
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.util.AccumulatorV2

import java.lang.reflect.Method
import java.util.Locale
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
import scala.collection.mutable.ArrayBuffer

//scalastyle:off multiple.string.literals
/**
 * Gives the connector access to metrics of the currently running Spark task:
 *  - looks up the task's external accumulators to find the `SQLMetric` instances backing known
 *    Cosmos metric names, and
 *  - publishes `bytesWritten` / `recordsWritten` into the task's output metrics.
 *
 * The `externalAccums`, `outputMetrics`, `setBytesWritten` and `setRecordsWritten` members are
 * looked up and invoked via reflection. Reflection can be switched off with the system property
 * `COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED` or the environment variable
 * `COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED`; it also switches itself off after the first reflective
 * failure, so all operations here are best-effort and never fail the caller.
 */
object SparkInternalsBridge extends BasicLoggingTrait {
  private val SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY = "COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED"
  private val SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE = "COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED"

  private val DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED = true

  // Not referenced inside this object; presumably the row interval at which callers refresh the
  // task metrics - confirm against the call sites.
  val NUM_ROWS_PER_UPDATE = 100

  // Metric names that are forwarded to the task's output metrics instead of being ignored.
  private val BUILTIN_OUTPUT_METRICS = Set("bytesWritten", "recordsWritten")

  // Caches of the reflectively resolved methods. Concurrent first calls may each resolve the
  // method and the last write wins; every resolution yields an equivalent Method.
  private val accumulatorsMethod: AtomicReference[Method] = new AtomicReference[Method]()
  private val outputMetricsMethod: AtomicReference[Method] = new AtomicReference[Method]()
  private val setBytesWrittenMethod: AtomicReference[Method] = new AtomicReference[Method]()
  private val setRecordsWrittenMethod: AtomicReference[Method] = new AtomicReference[Method]()

  /**
   * Reads the opt-out switch. Precedence: system property, then non-empty environment variable,
   * then the default (`true`).
   */
  private def getSparkReflectionAccessAllowed: Boolean = {
    val allowedText = System.getProperty(
      SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY,
      firstNonNull(
        emptyToNull(System.getenv.get(SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE)),
        String.valueOf(DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED)))

    try {
      // Boolean.valueOf maps anything other than (case-insensitive) "true" to false rather than
      // rejecting it, so the catch below is purely defensive.
      java.lang.Boolean.valueOf(allowedText.toUpperCase(Locale.ROOT))
    }
    catch {
      case e: Exception =>
        logError(s"Parsing spark reflection access allowed $allowedText failed. Using the default $DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED.", e)
        DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED
    }
  }

  // Starts from the configured value and is flipped to false after the first reflective failure.
  private final lazy val reflectionAccessAllowed = new AtomicBoolean(getSparkReflectionAccessAllowed)

  /**
   * Returns the `SQLMetric` accumulators of the current task whose name is contained in
   * `knownCosmosMetricNames`, keyed by that name.
   *
   * @param knownCosmosMetricNames accumulator names to look for
   * @return the matching metrics; empty when reflection is disabled, no `TaskContext` is active
   *         on the calling thread, or the accumulators cannot be obtained
   */
  def getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames: Set[String]): Map[String, SQLMetric] = {
    if (!reflectionAccessAllowed.get) {
      Map.empty[String, SQLMetric]
    } else {
      Option(TaskContext.get())
        .map(taskCtx => getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskCtx))
        .getOrElse(Map.empty[String, SQLMetric])
    }
  }

  /**
   * Returns the cached method from `cache`, or resolves the public method `name` on the runtime
   * class of `target`, makes it accessible, stores it in `cache` and returns it.
   */
  private def resolveMethod(
    cache: AtomicReference[Method],
    target: Object,
    name: String,
    parameterTypes: Class[_]*): Method = {

    Option(cache.get).getOrElse {
      val resolved = target.getClass.getMethod(name, parameterTypes: _*)
      resolved.setAccessible(true)
      cache.set(resolved)
      resolved
    }
  }

  /**
   * Invokes the single-`long`-argument setter `setterName` on `outputMetrics`. Any failure is
   * logged and disables reflection for all future calls instead of propagating.
   */
  private def invokeLongSetter(
    cache: AtomicReference[Method],
    setterName: String,
    outputMetrics: Object,
    metricValue: Object): Unit = {

    try {
      resolveMethod(cache, outputMetrics, setterName, java.lang.Long.TYPE)
        .invoke(outputMetrics, metricValue)
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke $setterName via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
    }
  }

  /**
   * Reflectively calls `externalAccums` on the task metrics of `taskCtx`.
   *
   * @return the accumulators, or `None` when the call returned null or failed - including the
   *         case where the result is not an `ArrayBuffer` (a failure also disables reflection
   *         for all future calls)
   */
  private def getAccumulators(taskCtx: TaskContext): Option[ArrayBuffer[AccumulatorV2[_, _]]] = {
    try {
      val taskMetrics: Object = taskCtx.taskMetrics()
      val method = resolveMethod(accumulatorsMethod, taskMetrics, "externalAccums")

      // Option(...) rather than Some(...): a null result must not leak to callers as Some(null).
      Option(method.invoke(taskMetrics).asInstanceOf[ArrayBuffer[AccumulatorV2[_, _]]])
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke getAccumulators via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
        None
    }
  }

  /**
   * Reflectively calls `outputMetrics` on the task metrics of `taskCtx`.
   *
   * @return the output metrics object, or `None` when the call returned null or failed (a
   *         failure also disables reflection for all future calls)
   */
  private def getOutputMetrics(taskCtx: TaskContext): Option[Object] = {
    try {
      val taskMetrics: Object = taskCtx.taskMetrics()
      val method = resolveMethod(outputMetricsMethod, taskMetrics, "outputMetrics")

      Option(method.invoke(taskMetrics))
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke getOutputMetrics via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
        None
    }
  }

  private def setBytesWritten(outputMetrics: Object, metricValue: Object): Unit =
    invokeLongSetter(setBytesWrittenMethod, "setBytesWritten", outputMetrics, metricValue)

  private def setRecordsWritten(outputMetrics: Object, metricValue: Object): Unit =
    invokeLongSetter(setRecordsWrittenMethod, "setRecordsWritten", outputMetrics, metricValue)

  /**
   * Filters the external accumulators of `taskCtx` down to named `SQLMetric` instances whose name
   * is in `knownCosmosMetricNames` and indexes them by name.
   */
  private def getInternalCustomTaskMetricsAsSQLMetricInternal(
    knownCosmosMetricNames: Set[String],
    taskCtx: TaskContext): Map[String, SQLMetric] = {

    getAccumulators(taskCtx) match {
      case Some(accumulators) =>
        accumulators
          .collect { case sqlMetric: SQLMetric => sqlMetric }
          .flatMap { sqlMetric =>
            sqlMetric.name
              .filter(knownCosmosMetricNames.contains)
              .map(metricName => metricName -> sqlMetric)
          }
          .toMap
      case None => Map.empty[String, SQLMetric]
    }
  }

  /**
   * Forwards the `bytesWritten` and `recordsWritten` entries of `currentMetricsValues` to the
   * output metrics of the current task; all other entries are ignored.
   *
   * Does nothing when reflection is disabled, when no `TaskContext` is active on the calling
   * thread, or when the output metrics cannot be obtained.
   *
   * @param currentMetricsValues current values of the connector's custom task metrics
   */
  def updateInternalTaskMetrics(currentMetricsValues: Seq[CustomTaskMetric]): Unit = {
    if (reflectionAccessAllowed.get) {
      // Resolved at most once per call, and only if a built-in metric is actually present.
      lazy val outputMetricsOption: Option[Object] =
        Option(TaskContext.get()).flatMap(getOutputMetrics)

      currentMetricsValues.foreach { metric =>
        val metricName = metric.name()
        val metricValue = Long.box(metric.value())

        if (BUILTIN_OUTPUT_METRICS.contains(metricName)) {
          outputMetricsOption.foreach { outputMetrics =>
            metricName match {
              case "bytesWritten" => setBytesWritten(outputMetrics, metricValue)
              case "recordsWritten" => setRecordsWritten(outputMetrics, metricValue)
              case _ => // no-op
            }
          }
        }
      }
    }
  }
}
//scalastyle:on multiple.string.literals
Loading

0 comments on commit 41243ba

Please sign in to comment.