forked from Azure/azure-sdk-for-java
-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Cosmos Spark - switching to reflection instead of bridge-approach to …
…avoid SecurityException when installing connector by directly copying into /databricks/jars folder (Azure#37934) * Fixing high number of PKRangeFeed calls when using BulkExecution without SparkConnector * Adding unit test coverage * Update CHANGELOG.md * Switching to reflection for spark internals access (custom metrics) for Spark 3.4 * Switching to reflection for custom metrics in Spark 3.1-3.3 * Changelog * Test fixes
- Loading branch information
1 parent
753adcb
commit 41243ba
Showing
16 changed files
with
516 additions
and
117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
125 changes: 125 additions & 0 deletions
125
...re-cosmos-spark_3-1_2-12/src/main/scala/com/azure/cosmos/spark/SparkInternalsBridge.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
package com.azure.cosmos.spark | ||
|
||
import com.azure.cosmos.implementation.guava25.base.MoreObjects.firstNonNull | ||
import com.azure.cosmos.implementation.guava25.base.Strings.emptyToNull | ||
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait | ||
import org.apache.spark.TaskContext | ||
|
||
import java.lang.reflect.Method | ||
import java.util.Locale | ||
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} | ||
|
||
//scalastyle:off multiple.string.literals | ||
/**
 * Updates Spark's task-level output metrics through reflection.
 *
 * The `outputMetrics`, `setBytesWritten` and `setRecordsWritten` methods are looked up by
 * name at runtime and cached after the first successful lookup. Any failure of a reflective
 * call is logged and switches reflection off for all subsequent calls on this object.
 *
 * Reflection can be disabled up-front through the `COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED`
 * system property or the `COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED` environment variable.
 */
object SparkInternalsBridge extends BasicLoggingTrait {
  private val SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY = "COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED"
  private val SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE = "COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED"

  private val DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED = true

  // NOTE(review): not referenced inside this object - presumably the number of rows callers
  // process between two metric updates; confirm against the call sites.
  val NUM_ROWS_PER_UPDATE = 100

  // Caches for the reflectively resolved methods (null until the first successful lookup).
  private val outputMetricsMethod: AtomicReference[Method] = new AtomicReference[Method]()
  private val setBytesWrittenMethod: AtomicReference[Method] = new AtomicReference[Method]()
  private val setRecordsWrittenMethod: AtomicReference[Method] = new AtomicReference[Method]()

  /**
   * Reads the "reflection allowed" switch.
   *
   * The system property takes precedence; when it is not set, the non-empty environment
   * variable is used; otherwise the default applies.
   *
   * @return `true` only when the effective text equals "true" (ignoring case); any other
   *         text evaluates to `false`
   */
  private def getSparkReflectionAccessAllowed: Boolean = {
    val allowedText = System.getProperty(
      SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY,
      firstNonNull(
        emptyToNull(System.getenv.get(SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE)),
        String.valueOf(DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED)))

    // Boolean.valueOf(String) never throws - unrecognized text simply yields false - and
    // allowedText cannot be null because a non-null default is always supplied. The former
    // try/catch ("Parsing ... failed. Using the default") was therefore unreachable and has
    // been removed; the resulting value is unchanged.
    java.lang.Boolean.valueOf(allowedText.toUpperCase(Locale.ROOT))
  }

  private final lazy val reflectionAccessAllowed = new AtomicBoolean(getSparkReflectionAccessAllowed)

  /**
   * Returns the cached method or resolves it on `owner`, makes it accessible and caches it.
   *
   * The get-then-set sequence is not atomic, so concurrent first calls may each resolve the
   * method; whichever write happens last stays in the cache.
   *
   * @param cache          holder for the resolved method
   * @param owner          class on which the public method is looked up
   * @param name           method name
   * @param parameterTypes formal parameter types of the method
   * @return the resolved method
   */
  private def resolveCachedMethod(
    cache: AtomicReference[Method],
    owner: Class[_],
    name: String,
    parameterTypes: Class[_]*): Method = {
    Option(cache.get).getOrElse {
      val resolved = owner.getMethod(name, parameterTypes: _*)
      resolved.setAccessible(true)
      cache.set(resolved)
      resolved
    }
  }

  /**
   * Reflectively invokes `outputMetrics` on the task metrics of the given task context.
   *
   * @param taskCtx the task context whose metrics are accessed
   * @return the output metrics object, or `None` when the call returned null or failed
   */
  private def getOutputMetrics(taskCtx: TaskContext): Option[Object] = {
    try {
      val taskMetrics: Object = taskCtx.taskMetrics()
      val method = resolveCachedMethod(outputMetricsMethod, taskMetrics.getClass, "outputMetrics")
      Option(method.invoke(taskMetrics))
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke getOutputMetrics via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
        None
    }
  }

  /**
   * Reflectively invokes `setBytesWritten(long)` on the output metrics object.
   *
   * @param outputMetrics the object returned by [[getOutputMetrics]]
   * @param metricValue   the boxed value to pass to the setter
   */
  private def setBytesWritten(outputMetrics: Object, metricValue: Object): Unit = {
    try {
      resolveCachedMethod(
        setBytesWrittenMethod,
        outputMetrics.getClass,
        "setBytesWritten",
        java.lang.Long.TYPE).invoke(outputMetrics, metricValue)
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke setBytesWritten via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
    }
  }

  /**
   * Reflectively invokes `setRecordsWritten(long)` on the output metrics object.
   *
   * @param outputMetrics the object returned by [[getOutputMetrics]]
   * @param metricValue   the boxed value to pass to the setter
   */
  private def setRecordsWritten(outputMetrics: Object, metricValue: Object): Unit = {
    try {
      resolveCachedMethod(
        setRecordsWrittenMethod,
        outputMetrics.getClass,
        "setRecordsWritten",
        java.lang.Long.TYPE).invoke(outputMetrics, metricValue)
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke setRecordsWritten via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
    }
  }

  /**
   * Pushes the given snapshots into the output metrics of the current task.
   *
   * Does nothing when reflection is disabled, when `TaskContext.get()` returns null or when
   * the output metrics cannot be obtained.
   *
   * @param recordsWrittenSnapshot value passed to `setRecordsWritten`
   * @param bytesWrittenSnapshot   value passed to `setBytesWritten`
   */
  def updateInternalTaskMetrics(recordsWrittenSnapshot: Long, bytesWrittenSnapshot: Long): Unit = {
    if (reflectionAccessAllowed.get) {
      // The previous pattern match had no `None` case and raised a MatchError whenever
      // TaskContext.get() returned null; the for-comprehension simply skips that case.
      for {
        taskContext <- Option(TaskContext.get())
        outputMetrics <- getOutputMetrics(taskContext)
      } {
        setRecordsWritten(outputMetrics, Long.box(recordsWrittenSnapshot))
        setBytesWritten(outputMetrics, Long.box(bytesWrittenSnapshot))
      }
    }
  }
}
//scalastyle:on multiple.string.literals |
18 changes: 0 additions & 18 deletions
18
...os/azure-cosmos-spark_3-1_2-12/src/main/scala/org/apache/spark/SparkInternalsBridge.scala
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
196 changes: 196 additions & 0 deletions
196
...re-cosmos-spark_3-2_2-12/src/main/scala/com/azure/cosmos/spark/SparkInternalsBridge.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
package com.azure.cosmos.spark | ||
|
||
import com.azure.cosmos.implementation.guava25.base.MoreObjects.firstNonNull | ||
import com.azure.cosmos.implementation.guava25.base.Strings.emptyToNull | ||
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait | ||
import org.apache.spark.TaskContext | ||
import org.apache.spark.sql.connector.metric.CustomTaskMetric | ||
import org.apache.spark.sql.execution.metric.SQLMetric | ||
import org.apache.spark.util.AccumulatorV2 | ||
|
||
import java.lang.reflect.Method | ||
import java.util.Locale | ||
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} | ||
import scala.collection.mutable.ArrayBuffer | ||
|
||
//scalastyle:off multiple.string.literals | ||
/**
 * Reads and updates Spark's task-level metrics through reflection.
 *
 * The `externalAccums`, `outputMetrics`, `setBytesWritten` and `setRecordsWritten` methods
 * are looked up by name at runtime and cached after the first successful lookup. Any failure
 * of a reflective call is logged and switches reflection off for all subsequent calls on
 * this object.
 *
 * Reflection can be disabled up-front through the `COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED`
 * system property or the `COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED` environment variable.
 */
object SparkInternalsBridge extends BasicLoggingTrait {
  private val SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY = "COSMOS.SPARK_REFLECTION_ACCESS_ALLOWED"
  private val SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE = "COSMOS_SPARK_REFLECTION_ACCESS_ALLOWED"

  private val DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED = true

  // NOTE(review): not referenced inside this object - presumably the number of rows callers
  // process between two metric updates; confirm against the call sites.
  val NUM_ROWS_PER_UPDATE = 100

  // Metric names that are written to the task's output metrics instead of a custom metric.
  private val BUILTIN_OUTPUT_METRICS = Set("bytesWritten", "recordsWritten")

  // Caches for the reflectively resolved methods (null until the first successful lookup).
  private val accumulatorsMethod : AtomicReference[Method] = new AtomicReference[Method]()
  private val outputMetricsMethod : AtomicReference[Method] = new AtomicReference[Method]()
  private val setBytesWrittenMethod : AtomicReference[Method] = new AtomicReference[Method]()
  private val setRecordsWrittenMethod : AtomicReference[Method] = new AtomicReference[Method]()

  /**
   * Reads the "reflection allowed" switch.
   *
   * The system property takes precedence; when it is not set, the non-empty environment
   * variable is used; otherwise the default applies.
   *
   * @return `true` only when the effective text equals "true" (ignoring case); any other
   *         text evaluates to `false`
   */
  private def getSparkReflectionAccessAllowed: Boolean = {
    val allowedText = System.getProperty(
      SPARK_REFLECTION_ACCESS_ALLOWED_PROPERTY,
      firstNonNull(
        emptyToNull(System.getenv.get(SPARK_REFLECTION_ACCESS_ALLOWED_VARIABLE)),
        String.valueOf(DEFAULT_SPARK_REFLECTION_ACCESS_ALLOWED)))

    // Boolean.valueOf(String) never throws - unrecognized text simply yields false - and
    // allowedText cannot be null because a non-null default is always supplied. The former
    // try/catch ("Parsing ... failed. Using the default") was therefore unreachable and has
    // been removed; the resulting value is unchanged.
    java.lang.Boolean.valueOf(allowedText.toUpperCase(Locale.ROOT))
  }

  private final lazy val reflectionAccessAllowed = new AtomicBoolean(getSparkReflectionAccessAllowed)

  /**
   * Returns the cached method or resolves it on `owner`, makes it accessible and caches it.
   *
   * The get-then-set sequence is not atomic, so concurrent first calls may each resolve the
   * method; whichever write happens last stays in the cache.
   *
   * @param cache          holder for the resolved method
   * @param owner          class on which the public method is looked up
   * @param name           method name
   * @param parameterTypes formal parameter types of the method
   * @return the resolved method
   */
  private def resolveCachedMethod(
    cache: AtomicReference[Method],
    owner: Class[_],
    name: String,
    parameterTypes: Class[_]*): Method = {
    Option(cache.get).getOrElse {
      val resolved = owner.getMethod(name, parameterTypes: _*)
      resolved.setAccessible(true)
      cache.set(resolved)
      resolved
    }
  }

  /**
   * Returns the SQL metrics registered on the current task whose names are contained in
   * `knownCosmosMetricNames`.
   *
   * @param knownCosmosMetricNames names of the metrics to look for
   * @return metrics keyed by name; empty when reflection is disabled, when
   *         `TaskContext.get()` returns null or when the accumulators cannot be read
   */
  def getInternalCustomTaskMetricsAsSQLMetric(knownCosmosMetricNames: Set[String]): Map[String, SQLMetric] = {
    if (!reflectionAccessAllowed.get) {
      Map.empty[String, SQLMetric]
    } else {
      Option(TaskContext.get())
        .map(taskCtx => getInternalCustomTaskMetricsAsSQLMetricInternal(knownCosmosMetricNames, taskCtx))
        .getOrElse(Map.empty[String, SQLMetric])
    }
  }

  /**
   * Reflectively invokes `externalAccums` on the task metrics of the given task context.
   *
   * @param taskCtx the task context whose metrics are accessed
   * @return the accumulators, or `None` when the call returned null or failed
   */
  private def getAccumulators(taskCtx: TaskContext): Option[ArrayBuffer[AccumulatorV2[_, _]]] = {
    try {
      val taskMetrics: Object = taskCtx.taskMetrics()
      val method = resolveCachedMethod(accumulatorsMethod, taskMetrics.getClass, "externalAccums")

      // Option(...) instead of Some(...): a null result must not be handed to callers, which
      // dereference the buffer outside of any try block.
      Option(method.invoke(taskMetrics).asInstanceOf[ArrayBuffer[AccumulatorV2[_, _]]])
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke getAccumulators via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
        None
    }
  }

  /**
   * Reflectively invokes `outputMetrics` on the task metrics of the given task context.
   *
   * @param taskCtx the task context whose metrics are accessed
   * @return the output metrics object, or `None` when the call returned null or failed
   */
  private def getOutputMetrics(taskCtx: TaskContext): Option[Object] = {
    try {
      val taskMetrics: Object = taskCtx.taskMetrics()
      val method = resolveCachedMethod(outputMetricsMethod, taskMetrics.getClass, "outputMetrics")
      Option(method.invoke(taskMetrics))
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke getOutputMetrics via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
        None
    }
  }

  /**
   * Reflectively invokes `setBytesWritten(long)` on the output metrics object.
   *
   * @param outputMetrics the object returned by [[getOutputMetrics]]
   * @param metricValue   the boxed value to pass to the setter
   */
  private def setBytesWritten(outputMetrics: Object, metricValue: Object): Unit = {
    try {
      resolveCachedMethod(
        setBytesWrittenMethod,
        outputMetrics.getClass,
        "setBytesWritten",
        java.lang.Long.TYPE).invoke(outputMetrics, metricValue)
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke setBytesWritten via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
    }
  }

  /**
   * Reflectively invokes `setRecordsWritten(long)` on the output metrics object.
   *
   * @param outputMetrics the object returned by [[getOutputMetrics]]
   * @param metricValue   the boxed value to pass to the setter
   */
  private def setRecordsWritten(outputMetrics: Object, metricValue: Object): Unit = {
    try {
      resolveCachedMethod(
        setRecordsWrittenMethod,
        outputMetrics.getClass,
        "setRecordsWritten",
        java.lang.Long.TYPE).invoke(outputMetrics, metricValue)
    } catch {
      case e: Exception =>
        logInfo(s"Could not invoke setRecordsWritten via reflection - Error ${e.getMessage}", e)

        // reflection failed - disabling it for the future
        reflectionAccessAllowed.set(false)
    }
  }

  /**
   * Selects the named [[SQLMetric]] accumulators of the task that are listed in
   * `knownCosmosMetricNames`.
   *
   * @param knownCosmosMetricNames names of the metrics to keep
   * @param taskCtx                the task context whose accumulators are inspected
   * @return metrics keyed by name; when several accumulators share a name the last one wins
   */
  private def getInternalCustomTaskMetricsAsSQLMetricInternal(
    knownCosmosMetricNames: Set[String],
    taskCtx: TaskContext): Map[String, SQLMetric] = {
    getAccumulators(taskCtx) match {
      case Some(accumulators) =>
        accumulators.foldLeft(Map.empty[String, SQLMetric]) {
          case (selected, sqlMetric: SQLMetric) =>
            sqlMetric.name
              .filter(knownCosmosMetricNames.contains)
              .fold(selected)(metricName => selected + (metricName -> sqlMetric))
          case (selected, _) => selected
        }
      case None => Map.empty[String, SQLMetric]
    }
  }

  /**
   * Writes the `bytesWritten` / `recordsWritten` values contained in `currentMetricsValues`
   * into the output metrics of the current task; all other metrics are ignored here.
   *
   * Does nothing when reflection is disabled, when `TaskContext.get()` returns null or when
   * the output metrics cannot be obtained.
   *
   * @param currentMetricsValues the current values of the custom task metrics
   */
  def updateInternalTaskMetrics(currentMetricsValues: Seq[CustomTaskMetric]): Unit = {
    if (reflectionAccessAllowed.get) {
      currentMetricsValues.foreach { metric =>
        val metricName = metric.name()
        val metricValue = Long.box(metric.value())

        if (BUILTIN_OUTPUT_METRICS.contains(metricName)) {
          Option(TaskContext.get()).flatMap(getOutputMetrics).foreach { outputMetrics =>
            metricName match {
              case "bytesWritten" => setBytesWritten(outputMetrics, metricValue)
              case "recordsWritten" => setRecordsWritten(outputMetrics, metricValue)
              case _ => // no-op
            }
          }
        }
      }
    }
  }
}
//scalastyle:on multiple.string.literals |
Oops, something went wrong.