diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Analysis.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Analysis.scala
index d8d7b7f62..3c3966f8d 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Analysis.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Analysis.scala
@@ -29,7 +29,7 @@ class Analysis(apps: Seq[ApplicationInfo]) {
 
   def getDurations(tcs: ArrayBuffer[TaskCase]): (Long, Long, Long, Double) = {
     val durations = tcs.map(_.duration)
-    if (durations.size > 0 ) {
+    if (durations.nonEmpty ) {
       (durations.sum, durations.max, durations.min,
         ToolUtils.calculateAverage(durations.sum, durations.size, 1))
     } else {
@@ -49,10 +49,9 @@ class Analysis(apps: Seq[ApplicationInfo]) {
   def jobAndStageMetricsAggregation(): Seq[JobStageAggTaskMetricsProfileResult] = {
     val allJobRows = apps.flatMap { app =>
       app.jobIdToInfo.map { case (id, jc) =>
-        val stageIdsInJob = jc.stageIds
         val stagesInJob = app.stageIdToInfo.filterKeys { case (sid, _) =>
-          stageIdsInJob.contains(sid)
-        }.keys.map(_._1).toSeq
+          jc.stageIds.contains(sid)
+        }.keys.map(_._1).toSet
         if (stagesInJob.isEmpty) {
           None
         } else {
@@ -60,11 +59,10 @@ class Analysis(apps: Seq[ApplicationInfo]) {
             stagesInJob.contains(tc.stageId)
           }
           // count duplicate task attempts
-          val numTaskAttempt = tasksInJob.size
           val (durSum, durMax, durMin, durAvg) = getDurations(tasksInJob)
           Some(JobStageAggTaskMetricsProfileResult(app.index,
             s"job_$id",
-            numTaskAttempt,
+            tasksInJob.size,
             jc.duration,
             tasksInJob.map(_.diskBytesSpilled).sum,
             durSum,
@@ -100,9 +98,8 @@ class Analysis(apps: Seq[ApplicationInfo]) {
     }
     val allJobStageRows = apps.flatMap { app =>
       app.jobIdToInfo.flatMap { case (_, jc) =>
-        val stageIdsInJob = jc.stageIds
         val stagesInJob = app.stageIdToInfo.filterKeys { case (sid, _) =>
-          stageIdsInJob.contains(sid)
+          jc.stageIds.contains(sid)
         }
         if (stagesInJob.isEmpty) {
           None
@@ -111,12 +108,10 @@ class Analysis(apps: Seq[ApplicationInfo]) {
           val tasksInStage = app.taskEnd.filter { tc =>
             tc.stageId == id
           }
-          // count duplicate task attempts
-          val numAttempts = tasksInStage.size
           val (durSum, durMax, durMin, durAvg) = getDurations(tasksInStage)
           Some(JobStageAggTaskMetricsProfileResult(app.index,
             s"stage_$id",
-            numAttempts,
+            tasksInStage.size,
             sc.duration,
             tasksInStage.map(_.diskBytesSpilled).sum,
             durSum,
@@ -153,17 +148,16 @@ class Analysis(apps: Seq[ApplicationInfo]) {
     }
     // stages that are missing from a job, perhaps dropped events
     val stagesWithoutJobs = apps.flatMap { app =>
-      val allStageinJobs = app.jobIdToInfo.flatMap { case (_, jc) =>
-        val stageIdsInJob = jc.stageIds
+      val allStageInJobs = app.jobIdToInfo.flatMap { case (_, jc) =>
         app.stageIdToInfo.filterKeys { case (sid, _) =>
-          stageIdsInJob.contains(sid)
+          jc.stageIds.contains(sid)
         }
       }
-      val missing = app.stageIdToInfo.keys.toSeq.diff(allStageinJobs.keys.toSeq)
+      val missing = app.stageIdToInfo.keys.toSet.diff(allStageInJobs.keys.toSet)
       if (missing.isEmpty) {
         Seq.empty
       } else {
-        missing.map { case ((id, saId)) =>
+        missing.map { case (id, saId) =>
           val scOpt = app.stageIdToInfo.get((id, saId))
           scOpt match {
             case None =>
@@ -214,11 +208,11 @@ class Analysis(apps: Seq[ApplicationInfo]) {
 
     }
     val allRows = allJobRows ++ allJobStageRows ++ stagesWithoutJobs
-    val filteredRows = allRows.filter(_.isDefined).map(_.get)
-    if (filteredRows.size > 0) {
+    val filteredRows = allRows.flatMap(row => row)
+    if (filteredRows.nonEmpty) {
       val sortedRows = filteredRows.sortBy { cols =>
         val sortDur = cols.duration.getOrElse(0L)
-        (cols.appIndex, -(sortDur), cols.id)
+        (cols.appIndex, -sortDur, cols.id)
       }
       sortedRows
     } else {
@@ -231,12 +225,12 @@ class Analysis(apps: Seq[ApplicationInfo]) {
     val allRows = apps.flatMap { app =>
       app.sqlIdToInfo.map { case (sqlId, sqlCase) =>
         val jcs = app.jobIdToInfo.filter { case (_, jc) =>
-          jc.sqlID.getOrElse(-1) == sqlId
+          jc.sqlID.isDefined && jc.sqlID.get == sqlId
         }
         if (jcs.isEmpty) {
           None
         } else {
-          val stageIdsForSQL = jcs.flatMap(_._2.stageIds).toSeq
+          val stageIdsForSQL = jcs.flatMap(_._2.stageIds).toSet
           val tasksInSQL = app.taskEnd.filter { tc =>
             stageIdsForSQL.contains(tc.stageId)
           }
@@ -298,7 +292,7 @@ class Analysis(apps: Seq[ApplicationInfo]) {
         }
       }
     }
-    val allFiltered = allRows.filter(_.isDefined).map(_.get)
+    val allFiltered = allRows.flatMap(row => row)
     if (allFiltered.size > 0) {
       val sortedRows = allFiltered.sortBy { cols =>
         val sortDur = cols.duration.getOrElse(0L)
@@ -314,12 +308,12 @@ class Analysis(apps: Seq[ApplicationInfo]) {
 
     val allRows = apps.flatMap { app =>
       app.sqlIdToInfo.map { case (sqlId, _) =>
         val jcs = app.jobIdToInfo.filter { case (_, jc) =>
-          jc.sqlID.getOrElse(-1) == sqlId
+          jc.sqlID.isDefined && jc.sqlID.get == sqlId
         }
         if (jcs.isEmpty) {
           None
         } else {
-          val stageIdsForSQL = jcs.flatMap(_._2.stageIds).toSeq
+          val stageIdsForSQL = jcs.flatMap(_._2.stageIds).toSet
           val tasksInSQL = app.taskEnd.filter { tc =>
             stageIdsForSQL.contains(tc.stageId)
@@ -344,7 +338,7 @@ class Analysis(apps: Seq[ApplicationInfo]) {
         }
       }
     }
-    val allFiltered = allRows.filter(_.isDefined).map(_.get)
+    val allFiltered = allRows.flatMap(row => row)
     if (allFiltered.size > 0) {
       val sortedRows = allFiltered.sortBy { cols =>
         (cols.appIndex, cols.sqlId)
@@ -359,12 +353,12 @@ class Analysis(apps: Seq[ApplicationInfo]) {
     apps.map { app =>
       val maxOfSqls = app.sqlIdToInfo.map { case (sqlId, _) =>
         val jcs = app.jobIdToInfo.filter { case (_, jc) =>
-          jc.sqlID.getOrElse(-1) == sqlId
+          jc.sqlID.isDefined && jc.sqlID.get == sqlId
         }
         if (jcs.isEmpty) {
           0L
         } else {
-          val stageIdsForSQL = jcs.flatMap(_._2.stageIds).toSeq
+          val stageIdsForSQL = jcs.flatMap(_._2.stageIds).toSet
           val tasksInSQL = app.taskEnd.filter { tc =>
             stageIdsForSQL.contains(tc.stageId)
           }
@@ -394,7 +388,7 @@ class Analysis(apps: Seq[ApplicationInfo]) {
           sqlCase.sqlCpuTimePercent)
       }
     }
-    if (allRows.size > 0) {
+    if (allRows.nonEmpty) {
       val sortedRows = allRows.sortBy { cols =>
         val sortDur = cols.duration.getOrElse(0L)
         (cols.appIndex, cols.sqlID, sortDur)
@@ -443,8 +437,8 @@ class Analysis(apps: Seq[ApplicationInfo]) {
       }
     }
 
-    val allNonEmptyRows = allRows.filter(_.isDefined).map(_.get)
-    if (allNonEmptyRows.size > 0) {
+    val allNonEmptyRows = allRows.flatMap(row => row)
+    if (allNonEmptyRows.nonEmpty) {
       val sortedRows = allNonEmptyRows.sortBy { cols =>
         (cols.appIndex, cols.stageId, cols.stageAttemptId, cols.taskId, cols.taskAttemptId)
       }
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala
index 7a30a80a0..d564139d5 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/AutoTuner.scala
@@ -856,8 +856,8 @@ class AutoTuner(
 
   /**
    * Recommendation for 'spark.sql.files.maxPartitionBytes' based on input size for each task.
-   * Note that the logic can be disabled by adding the property to [[limitedLogicRecommendations]]
-   * which is one of the arguments of [[getRecommendedProperties()]].
+   * Note that the logic can be disabled by adding the property to "limitedLogicRecommendations"
+   * which is one of the arguments of [[getRecommendedProperties]].
    */
   private def recommendMaxPartitionBytes(): Unit = {
     val maxPartitionProp =
@@ -873,8 +873,8 @@ class AutoTuner(
 
   /**
   * Recommendations for 'spark.sql.shuffle.partitions' based on spills and skew in shuffle stages.
-   * Note that the logic can be disabled by adding the property to [[limitedLogicRecommendations]]
-   * which is one of the arguments of [[getRecommendedProperties()]].
+   * Note that the logic can be disabled by adding the property to "limitedLogicRecommendations"
+   * which is one of the arguments of [[getRecommendedProperties]].
    */
   def recommendShufflePartitions(): Unit = {
     val lookup = "spark.sql.shuffle.partitions"
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CollectInformation.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CollectInformation.scala
index 55ce1037c..3049dfac4 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CollectInformation.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/CollectInformation.scala
@@ -236,7 +236,7 @@ class CollectInformation(apps: Seq[ApplicationInfo]) extends Logging {
       CollectInformation.addNewProps(propsToKeep, props, numApps)
     }
     val allRows = props.map { case (k, v) => Seq(k) ++ v }.toSeq
-    if (allRows.size > 0) {
+    if (allRows.nonEmpty) {
       val resRows = allRows.map(r => RapidsPropertyProfileResult(r(0), outputHeaders, r))
       resRows.sortBy(cols => cols.key)
     } else {
@@ -259,7 +259,7 @@ class CollectInformation(apps: Seq[ApplicationInfo]) extends Logging {
     val allWholeStages = apps.flatMap { app =>
       app.wholeStage
     }
-    if (allWholeStages.size > 0) {
+    if (allWholeStages.nonEmpty) {
       allWholeStages.sortBy(cols => (cols.appIndex, cols.sqlID, cols.nodeID))
     } else {
       Seq.empty
@@ -269,7 +269,7 @@ class CollectInformation(apps: Seq[ApplicationInfo]) extends Logging {
   // Print SQL Plan Metrics
   def getSQLPlanMetrics: Seq[SQLAccumProfileResults] = {
     val sqlAccums = CollectInformation.generateSQLAccums(apps)
-    if (sqlAccums.size > 0) {
+    if (sqlAccums.nonEmpty) {
       sqlAccums.sortBy(cols => (cols.appIndex, cols.sqlID, cols.nodeID,
         cols.nodeName, cols.accumulatorId, cols.metricType))
     } else {
@@ -286,11 +286,11 @@ object CollectInformation extends Logging {
   def generateSQLAccums(apps: Seq[ApplicationInfo]): Seq[SQLAccumProfileResults] = {
     val allRows = apps.flatMap { app =>
       app.allSQLMetrics.map { metric =>
-        val sqlId = metric.sqlID
         val jobsForSql = app.jobIdToInfo.filter { case (_, jc) =>
-          jc.sqlID.getOrElse(-1) == sqlId
+          // Avoid getOrElse to reduce memory allocations
+          jc.sqlID.isDefined && jc.sqlID.get == metric.sqlID
         }
-        val stageIdsForSQL = jobsForSql.flatMap(_._2.stageIds).toSeq
+        val stageIdsForSQL = jobsForSql.flatMap(_._2.stageIds).toSet
         val accumsOpt = app.taskStageAccumMap.get(metric.accumulatorId)
         val taskMax = accumsOpt match {
           case Some(accums) =>
@@ -326,7 +326,7 @@ object CollectInformation extends Logging {
         val driverMax = driverAccumsOpt match {
           case Some(accums) =>
             val filtered = accums.filter { a =>
-              a.sqlID == sqlId
+              a.sqlID == metric.sqlID
             }
             val accumValues = filtered.map(_.value).sortWith(_ < _)
             if (accumValues.isEmpty) {
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/GenerateDot.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/GenerateDot.scala
index afae57c59..1453429d8 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/GenerateDot.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/GenerateDot.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -144,7 +144,7 @@ case class QueryPlanWithMetrics(plan: SparkPlanInfoWithStage, metrics: Map[Long,
 * Each graph is defined with a set of nodes and a set of edges. Each node represents a node in the
 * SparkPlan tree, and each edge represents a parent-child relationship between two nodes.
 */
-case class SparkPlanGraph(
+case class SparkPlanGraphForDot(
     nodes: Seq[SparkPlanGraphNode],
     edges: Seq[SparkPlanGraphEdge],
     appId: String,
@@ -187,14 +187,14 @@ object SparkPlanGraph {
       appId: String,
       sqlId: String,
       physicalPlan: String,
-      stageIdToStageMetrics: Map[Int, StageMetrics]): SparkPlanGraph = {
+      stageIdToStageMetrics: Map[Int, StageMetrics]): SparkPlanGraphForDot = {
     val nodeIdGenerator = new AtomicLong(0)
     val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]()
     val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]()
     val exchanges = mutable.HashMap[SparkPlanInfoWithStage, SparkPlanGraphNode]()
     buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null, null,
       exchanges, stageIdToStageMetrics)
-    new SparkPlanGraph(nodes, edges, appId, sqlId, physicalPlan)
+    SparkPlanGraphForDot(nodes, edges, appId, sqlId, physicalPlan)
   }
 
   @tailrec
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala
index 47f233fd5..088294ac6 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala
@@ -176,7 +176,11 @@ class SQLExecutionInfoClass(
     var duration: Option[Long],
     var hasDatasetOrRDD: Boolean,
     var problematic: String = "",
-    var sqlCpuTimePercent: Double = -1)
+    var sqlCpuTimePercent: Double = -1) {
+  def setDsOrRdd(value: Boolean): Unit = {
+    hasDatasetOrRDD = value
+  }
+}
 
 case class SQLAccumProfileResults(appIndex: Int, sqlID: Long, nodeID: Long,
     nodeName: String, accumulatorId: Long, name: String, min: Long, median:Long,
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala
index 5b0f6cc47..37fbdd6b2 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/Profiler.scala
@@ -306,7 +306,7 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea
       val app = new ApplicationInfo(hadoopConf, path, index)
       EventLogPathProcessor.logApplicationInfo(app)
       val endTime = System.currentTimeMillis()
-      logInfo(s"Took ${endTime - startTime}ms to process ${path.eventLog.toString}")
+      logInfo(s"Took ${endTime - startTime}ms to create App for ${path.eventLog.toString}")
       Some(app)
     } catch {
       case _: com.fasterxml.jackson.core.JsonParseException =>
@@ -327,9 +327,12 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea
    * and returns the summary information. The summary information is much smaller than
    * the ApplicationInfo because it has processed and combined many of the raw events.
    */
-  private def processApps(apps: Seq[ApplicationInfo], printPlans: Boolean,
-      profileOutputWriter: ProfileOutputWriter): (ApplicationSummaryInfo,
-      Option[CompareSummaryInfo]) = {
+  private def processApps(
+      apps: Seq[ApplicationInfo],
+      printPlans: Boolean,
+      profileOutputWriter: ProfileOutputWriter)
+    : (ApplicationSummaryInfo, Option[CompareSummaryInfo]) = {
+    val startTime = System.currentTimeMillis()
 
     val collect = new CollectInformation(apps)
     val appInfo = collect.getAppInfo
@@ -403,7 +406,9 @@ class Profiler(hadoopConf: Configuration, appArgs: ProfileArgs, enablePB: Boolea
           s"to $outputDir in $duration second(s)\n")
       }
     }
-    (ApplicationSummaryInfo(appInfo, dsInfo, execInfo, jobInfo, rapidsProps,
+    val endTime = System.currentTimeMillis()
+    logInfo(s"Took ${endTime - startTime}ms to Process [${appInfo.head.appId}]")
+    (ApplicationSummaryInfo(appInfo, dsInfo, execInfo, jobInfo, rapidsProps,
       rapidsJar, sqlMetrics, jsMetAgg, sqlTaskAggMetrics, durAndCpuMet, skewInfo,
       failedTasks, failedStages, failedJobs, removedBMs, removedExecutors,
       unsupportedOps, sparkProps, sqlStageInfo, wholeStage, maxTaskInputInfo,
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
index 20f197739..6f0ad9d98 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
@@ -331,7 +331,7 @@ abstract class AppBase(
     // at this point all paths should be valid event logs or event log dirs
     val hconf = hadoopConf.getOrElse(RapidsToolsConfUtil.newHadoopConf)
     val fs = eventLogPath.getFileSystem(hconf)
-    var totalNumEvents = 0
+    var totalNumEvents = 0L
     val readerOpt = eventLog match {
       case _: DatabricksEventLog =>
         Some(new DatabricksRollingEventLogFilesFileReader(fs, eventLogPath))
@@ -340,15 +340,14 @@ abstract class AppBase(
 
     if (readerOpt.isDefined) {
       val reader = readerOpt.get
-      val logFiles = reader.listEventLogFiles
-      logFiles.foreach { file =>
+      val runtimeGetFromJsonMethod = EventUtils.getEventFromJsonMethod
+      reader.listEventLogFiles.foreach { file =>
         Utils.tryWithResource(openEventLogInternal(file.getPath, fs)) { in =>
-          val lines = Source.fromInputStream(in)(Codec.UTF8).getLines().toIterator
-          // Using find as foreach with conditional to exit early if we are done.
-          // Do NOT use a while loop as it is much much slower.
-          lines.find { line =>
+          Source.fromInputStream(in)(Codec.UTF8).getLines().find { line =>
+            // Using find as foreach with conditional to exit early if we are done.
+            // Do NOT use a while loop as it is much much slower.
             totalNumEvents += 1
-            EventUtils.getEventFromJsonMethod(line) match {
+            runtimeGetFromJsonMethod.apply(line) match {
              case Some(e) => processEvent(e)
              case None => false
            }
@@ -413,11 +412,12 @@ abstract class AppBase(
   }
 
   // The ReadSchema metadata is only in the eventlog for DataSource V1 readers
-  protected def checkMetadataForReadSchema(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
+  protected def checkMetadataForReadSchema(
+      sqlPlanInfoGraph: SqlPlanInfoGraphEntry): ArrayBuffer[DataSourceCase] = {
     // check if planInfo has ReadSchema
-    val allMetaWithSchema = getPlanMetaWithSchema(planInfo)
-    val planGraph = ToolsPlanGraph(planInfo)
-    val allNodes = planGraph.allNodes
+    val allMetaWithSchema = getPlanMetaWithSchema(sqlPlanInfoGraph.planInfo)
+    val allNodes = sqlPlanInfoGraph.sparkPlanGraph.allNodes
+    val results = ArrayBuffer[DataSourceCase]()
 
     allMetaWithSchema.foreach { plan =>
       val meta = plan.metadata
@@ -432,7 +432,8 @@ abstract class AppBase(
       // add it to the dataSourceInfo
       // Processing Photon eventlogs issue: https://github.com/NVIDIA/spark-rapids-tools/issues/251
       if (scanNode.nonEmpty) {
-        dataSourceInfo += DataSourceCase(sqlID,
+        results += DataSourceCase(
+          sqlPlanInfoGraph.sqlID,
          scanNode.head.id,
          meta.getOrElse("Format", "unknown"),
          meta.getOrElse("Location", "unknown"),
@@ -444,12 +445,13 @@ abstract class AppBase(
     // "scan hive" has no "ReadSchema" defined. So, we need to look explicitly for nodes
     // that are scan hive and add them one by one to the dataSource
     if (hiveEnabled) { // only scan for hive when the CatalogImplementation is using hive
-      val allPlanWithHiveScan = getPlanInfoWithHiveScan(planInfo)
+      val allPlanWithHiveScan = getPlanInfoWithHiveScan(sqlPlanInfoGraph.planInfo)
       allPlanWithHiveScan.foreach { hiveReadPlan =>
         val sqlGraph = ToolsPlanGraph(hiveReadPlan)
         val hiveScanNode = sqlGraph.allNodes.head
         val scanHiveMeta = HiveParseHelper.parseReadNode(hiveScanNode)
-        dataSourceInfo += DataSourceCase(sqlID,
+        results += DataSourceCase(
+          sqlPlanInfoGraph.sqlID,
          hiveScanNode.id,
          scanHiveMeta.format,
          scanHiveMeta.location,
@@ -458,26 +460,32 @@ abstract class AppBase(
         )
       }
     }
+    dataSourceInfo ++= results
+    results
   }
 
   // This will find scans for DataSource V2, if the schema is very large it
   // will likely be incomplete and have ... at the end.
-  protected def checkGraphNodeForReads(sqlID: Long, node: SparkPlanGraphNode): Unit = {
+  protected def checkGraphNodeForReads(
+      sqlID: Long, node: SparkPlanGraphNode): Option[DataSourceCase] = {
     if (ReadParser.isDataSourceV2Node(node)) {
       val res = ReadParser.parseReadNode(node)
-
-      dataSourceInfo += DataSourceCase(sqlID,
+      val dsCase = DataSourceCase(
+        sqlID,
         node.id,
         res.format,
         res.location,
         res.filters,
-        res.schema
-      )
+        res.schema)
+      dataSourceInfo += dsCase
+      Some(dsCase)
+    } else {
+      None
     }
   }
 
   protected def reportComplexTypes: (Seq[String], Seq[String]) = {
-    if (dataSourceInfo.size != 0) {
+    if (dataSourceInfo.nonEmpty) {
       val schema = dataSourceInfo.map { ds => ds.schema }
       AppBase.parseReadSchemaForNestedTypes(schema)
     } else {
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/ToolUtils.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/ToolUtils.scala
index f9143bc3a..f2f47087f 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/ToolUtils.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/ToolUtils.scala
@@ -16,6 +16,7 @@
 
 package org.apache.spark.sql.rapids.tool
 
+import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Success, Try}
 import scala.util.control.NonFatal
 
@@ -28,7 +29,9 @@ import org.json4s.jackson.JsonMethods.parse
 
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
+import org.apache.spark.sql.execution.SparkPlanInfo
+import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphNode}
+import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
 
 object ToolUtils extends Logging {
   // List of recommended file-encodings on the GPUs.
@@ -472,3 +475,29 @@ object SupportedMLFuncsName {
 }
 
 case class GpuEventLogException(message: String) extends Exception(message)
+
+// Class used a container to hold the information of the Tuple
+// to simplify arguments of methods and caching.
+case class SqlPlanInfoGraphEntry(
+    sqlID: Long,
+    planInfo: SparkPlanInfo,
+    sparkPlanGraph: SparkPlanGraph
+)
+
+// A class used to cache the SQLPlanInfoGraphs
+class SqlPlanInfoGraphBuffer {
+  val sqlPlanInfoGraphs = ArrayBuffer[SqlPlanInfoGraphEntry]()
+  def addSqlPlanInfoGraph(sqlID: Long, planInfo: SparkPlanInfo): SqlPlanInfoGraphEntry = {
+    val newEntry = SqlPlanInfoGraphBuffer.createEntry(sqlID, planInfo)
+    sqlPlanInfoGraphs += newEntry
+    newEntry
+  }
+}
+
+object SqlPlanInfoGraphBuffer {
+  def apply(): SqlPlanInfoGraphBuffer = new SqlPlanInfoGraphBuffer()
+  def createEntry(sqlID: Long, planInfo: SparkPlanInfo): SqlPlanInfoGraphEntry = {
+    val planGraph = ToolsPlanGraph(planInfo)
+    SqlPlanInfoGraphEntry(sqlID, planInfo, planGraph)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
index fca67ea98..990cdc1f5 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/profiling/ApplicationInfo.scala
@@ -28,7 +28,9 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.metric.SQLMetricInfo
+import org.apache.spark.sql.execution.ui.SparkPlanGraph
 import org.apache.spark.sql.rapids.tool.{AppBase, RDDCheckHelper, ToolUtils}
+import org.apache.spark.sql.rapids.tool.SqlPlanInfoGraphBuffer
 import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
 import org.apache.spark.ui.UIUtils
 
@@ -226,10 +228,14 @@ class ApplicationInfo(
       false
     }
 
-  // Connects Operators to Stages using AccumulatorIDs
-  def connectOperatorToStage(): Unit = {
+  /**
+   * Connects Operators to Stages using AccumulatorIDs
+   * @param cb function that creates a SparkPlanGraph. This can be used as a cacheHolder for the
+   *           object created to be used later.
+   */
+  private def connectOperatorToStage(cb: (Long, SparkPlanInfo) => SparkPlanGraph): Unit = {
     for ((sqlId, planInfo) <- sqlPlans) {
-      val planGraph = ToolsPlanGraph(planInfo)
+      val planGraph: SparkPlanGraph = cb.apply(sqlId, planInfo)
       // Maps stages to operators by checking for non-zero intersection
       // between nodeMetrics and stageAccumulateIDs
       val nodeIdToStage = planGraph.allNodes.map { node =>
@@ -244,66 +250,78 @@ class ApplicationInfo(
   * Function to process SQL Plan Metrics after all events are processed
   */
  def processSQLPlanMetrics(): Unit = {
-    connectOperatorToStage()
-    for ((sqlID, planInfo) <- sqlPlans) {
-      checkMetadataForReadSchema(sqlID, planInfo)
-      val planGraph = ToolsPlanGraph(planInfo)
-      // SQLPlanMetric is a case Class of
-      // (name: String,accumulatorId: Long,metricType: String)
-      val allnodes = planGraph.allNodes
-      planGraph.nodes.foreach { n =>
-        if (n.isInstanceOf[org.apache.spark.sql.execution.ui.SparkPlanGraphCluster]) {
-          val ch = n.asInstanceOf[org.apache.spark.sql.execution.ui.SparkPlanGraphCluster].nodes
+    // Define a buffer to cache the SQLPlanInfoGraphs
+    val sqlPlanInfoBuffer = SqlPlanInfoGraphBuffer()
+    // Define a function used to fill in the buffer while executing "connectOperatorToStage"
+    val createGraphFunc = (sqlId: Long, planInfo: SparkPlanInfo) => {
+      sqlPlanInfoBuffer.addSqlPlanInfoGraph(sqlId, planInfo).sparkPlanGraph
+    }
+    connectOperatorToStage(createGraphFunc)
+    for (sqlPIGEntry <- sqlPlanInfoBuffer.sqlPlanInfoGraphs) {
+      var sqlIsDsOrRDD = false
+      val potentialProblems = collection.mutable.Set[String]()
+      // store all datasources of the given SQL in a variable so that we won't have to iterate
+      // through the entire list
+      // get V1 dataSources for that SQLId
+      val sqlDataSources = checkMetadataForReadSchema(sqlPIGEntry)
+      for (node <- sqlPIGEntry.sparkPlanGraph.allNodes) {
+        var nodeIsDsOrRDD = false
+        if (node.isInstanceOf[org.apache.spark.sql.execution.ui.SparkPlanGraphCluster]) {
+          val ch = node.asInstanceOf[org.apache.spark.sql.execution.ui.SparkPlanGraphCluster].nodes
           ch.foreach { c =>
-            wholeStage += WholeStageCodeGenResults(index, sqlID, n.id, n.name, c.name, c.id)
+            wholeStage += WholeStageCodeGenResults(
+              index, sqlPIGEntry.sqlID, node.id, node.name, c.name, c.id)
           }
         }
-      }
-      for (node <- allnodes) {
-        checkGraphNodeForReads(sqlID, node)
-        if (RDDCheckHelper.isDatasetOrRDDPlan(node.name, node.desc).isRDD) {
-          sqlIdToInfo.get(sqlID).foreach { sql =>
-            sqlIDToDataSetOrRDDCase += sqlID
-            sql.hasDatasetOrRDD = true
-          }
-          if (gpuMode) {
-            val thisPlan = UnsupportedSQLPlan(sqlID, node.id, node.name, node.desc,
+        // get V2 dataSources for that node
+        val nodeV2Reads = checkGraphNodeForReads(sqlPIGEntry.sqlID, node)
+        if (nodeV2Reads.isDefined) {
+          sqlDataSources += nodeV2Reads.get
+        }
+        nodeIsDsOrRDD = RDDCheckHelper.isDatasetOrRDDPlan(node.name, node.desc).isRDD
+        if (nodeIsDsOrRDD) {
+          if (gpuMode) { // we want to report every node that is an RDD
+            val thisPlan = UnsupportedSQLPlan(sqlPIGEntry.sqlID, node.id, node.name, node.desc,
              "Contains Dataset or RDD")
            unsupportedSQLplan += thisPlan
          }
+          // If one node is RDD, the Sql should be set too
+          if (!sqlIsDsOrRDD) { // We need to set the flag only once for the given sqlID
+            sqlIsDsOrRDD = true
+            sqlIdToInfo.get(sqlPIGEntry.sqlID).foreach { sql =>
+              sql.setDsOrRdd(sqlIsDsOrRDD)
+              sqlIDToDataSetOrRDDCase += sqlPIGEntry.sqlID
+              // Clear the potential problems since it is an RDD to free memory
+              potentialProblems.clear()
+            }
+          }
         }
-
-        // find potential problems
-        val issues = findPotentialIssues(node.desc)
-        if (issues.nonEmpty) {
-          val existingIssues = sqlIDtoProblematic.getOrElse(sqlID, Set.empty[String])
-          sqlIDtoProblematic(sqlID) = existingIssues ++ issues
-        }
-        val (_, nestedComplexTypes) = reportComplexTypes
-        val potentialProbs = getAllPotentialProblems(getPotentialProblemsForDf, nestedComplexTypes)
-        sqlIdToInfo.get(sqlID).foreach { sql =>
-          sql.problematic = ToolUtils.formatPotentialProblems(potentialProbs)
+        if (!sqlIsDsOrRDD) {
+          // Append current node's potential problems to the Sql problems only if the SQL is not an
+          // RDD. This is an optimization since the potentialProblems won't be used any more.
+          potentialProblems ++= findPotentialIssues(node.desc)
         }
-        // Then process SQL plan metric type
        for (metric <- node.metrics) {
-          val stages = sqlPlanNodeIdToStageIds.get((sqlID, node.id)).getOrElse(Set.empty)
-          val allMetric = SQLMetricInfoCase(sqlID, metric.name,
+          val stages =
+            sqlPlanNodeIdToStageIds.get((sqlPIGEntry.sqlID, node.id)).getOrElse(Set.empty)
+          val allMetric = SQLMetricInfoCase(sqlPIGEntry.sqlID, metric.name,
            metric.accumulatorId, metric.metricType, node.id,
            node.name, node.desc, stages)
 
          allSQLMetrics += allMetric
          if (this.sqlPlanMetricsAdaptive.nonEmpty) {
            val adaptive = sqlPlanMetricsAdaptive.filter { adaptiveMetric =>
-              adaptiveMetric.sqlID == sqlID && adaptiveMetric.accumulatorId == metric.accumulatorId
+              adaptiveMetric.sqlID == sqlPIGEntry.sqlID &&
+                adaptiveMetric.accumulatorId == metric.accumulatorId
            }
            adaptive.foreach { adaptiveMetric =>
-              val allMetric = SQLMetricInfoCase(sqlID, adaptiveMetric.name,
+              val allMetric = SQLMetricInfoCase(sqlPIGEntry.sqlID, adaptiveMetric.name,
                adaptiveMetric.accumulatorId, adaptiveMetric.metricType, node.id,
                node.name, node.desc, stages)
              // could make this more efficient but seems ok for now
              val exists = allSQLMetrics.filter { a =>
-                ((a.accumulatorId == adaptiveMetric.accumulatorId) && (a.sqlID == sqlID)
+                ((a.accumulatorId == adaptiveMetric.accumulatorId) && (a.sqlID == sqlPIGEntry.sqlID)
                  && (a.nodeID == node.id && adaptiveMetric.metricType == a.metricType))
              }
              if (exists.isEmpty) {
@@ -313,6 +331,19 @@ class ApplicationInfo(
          }
        }
      }
+      // Check if readsSchema is complex for the given sql
+      val sqlNestedComplexTypes =
+        AppBase.parseReadSchemaForNestedTypes(sqlDataSources.map { ds => ds.schema })
+      // Append problematic issues to the global variable for that SqlID
+      if (sqlNestedComplexTypes._2.nonEmpty) {
+        potentialProblems += "NESTED COMPLEX TYPE"
+      }
+      // Finally, add the local potentialProblems to the global data structure if any.
+      sqlIDtoProblematic(sqlPIGEntry.sqlID) = potentialProblems.toSet
+      // Convert the problematic issues to a string and update the SQLInfo
+      sqlIdToInfo.get(sqlPIGEntry.sqlID).foreach { sqlInfoClass =>
+        sqlInfoClass.problematic = ToolUtils.formatPotentialProblems(potentialProblems.toSeq)
+      }
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
index 5abb8e669..5586f63d7 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/qualification/QualificationAppInfo.scala
@@ -28,8 +28,8 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEnvironmentUpdate, SparkListenerEvent, SparkListenerJobStart}
 import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.rapids.tool.{AppBase, ClusterSummary, GpuEventLogException, SupportedMLFuncsName, ToolUtils}
-import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph
+import org.apache.spark.sql.rapids.tool.{AppBase, ClusterSummary, GpuEventLogException, SqlPlanInfoGraphBuffer, SupportedMLFuncsName, ToolUtils}
+
 
 class QualificationAppInfo(
     eventLogInfo: Option[EventLogInfo],
@@ -803,10 +803,9 @@ class QualificationAppInfo(
   }
 
   private[qualification] def processSQLPlan(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
-    checkMetadataForReadSchema(sqlID, planInfo)
-    val planGraph = ToolsPlanGraph(planInfo)
-    val allnodes = planGraph.allNodes
-    for (node <- allnodes) {
+    val sqlPlanInfoGraphEntry = SqlPlanInfoGraphBuffer.createEntry(sqlID, planInfo)
+    checkMetadataForReadSchema(sqlPlanInfoGraphEntry)
+    for (node <- sqlPlanInfoGraphEntry.sparkPlanGraph.allNodes) {
       checkGraphNodeForReads(sqlID, node)
       val issues = findPotentialIssues(node.desc)
       if (issues.nonEmpty) {
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/EventUtils.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/EventUtils.scala
index 363b96178..45a7beafc 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/EventUtils.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/EventUtils.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.rapids.tool.util
 
 import java.lang.reflect.InvocationTargetException
 
+import scala.collection.mutable
 import scala.util.{Failure, Success, Try}
 import scala.util.control.NonFatal
 
@@ -32,6 +33,15 @@ import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
 * Utility containing the implementation of helpers used for parsing data from event.
 */
 object EventUtils extends Logging {
+  // Set to keep track of missing classes
+  private val missingEventClasses = mutable.HashSet[String]()
+
+  private def reportMissingEventClass(className: String): Unit = {
+    if (!missingEventClasses.contains(className)) {
+      missingEventClasses.add(className)
+      logWarning(s"ClassNotFoundException while parsing an event: ${className}")
+    }
+  }
 
   /**
   * Used to parse (value/update) fields of the AccumulableInfo object. If the data is not
@@ -109,8 +119,7 @@ object EventUtils extends Logging {
     field
   }
 
-  lazy val getEventFromJsonMethod:
-    String => Option[org.apache.spark.scheduler.SparkListenerEvent] = {
+  private lazy val runtimeEventFromJsonMethod = {
     // Spark 3.4 and Databricks changed the signature on sparkEventFromJson
     // Note that it is preferred we use reflection rather than checking Spark-runtime
     // because some vendors may back-port features.
@@ -128,15 +137,19 @@ object EventUtils extends Logging {
       (line: String) =>
         b.invoke(null, line).asInstanceOf[org.apache.spark.scheduler.SparkListenerEvent]
     }
+    m
+  }
+
+  lazy val getEventFromJsonMethod:
+    String => Option[org.apache.spark.scheduler.SparkListenerEvent] = {
     // At this point, the method is already defined.
     // Note that the Exception handling is moved within the method to make it easier
     // to isolate the exception reason.
     (line: String) => Try {
-      m.apply(line)
+      runtimeEventFromJsonMethod.apply(line)
     } match {
       case Success(i) => Some(i)
       case Failure(e) =>
-
         e match {
           case i: InvocationTargetException =>
             val targetEx = i.getTargetException
@@ -151,7 +164,8 @@ object EventUtils extends Logging {
              // malformed
              throw k
          case z: ClassNotFoundException if z.getMessage != null =>
-            logWarning(s"ClassNotFoundException while parsing an event: ${z.getMessage}")
+            // Avoid reporting missing classes more than once to reduce the noise in the logs
+            reportMissingEventClass(z.getMessage)
          case t: Throwable =>
            // We do not want to swallow unknown exceptions so that we can handle later
            logError(s"Unknown exception while parsing an event", t)
diff --git a/core/src/test/resources/ProfilingExpectations/rapids_duration_and_cpu_expectation.csv b/core/src/test/resources/ProfilingExpectations/rapids_duration_and_cpu_expectation.csv
index 5202c48dd..b4da60fe4 100644
--- a/core/src/test/resources/ProfilingExpectations/rapids_duration_and_cpu_expectation.csv
+++ b/core/src/test/resources/ProfilingExpectations/rapids_duration_and_cpu_expectation.csv
@@ -5,5 +5,5 @@ appIndex,App ID,sqlID,SQL Duration,Contains Dataset or RDD Op,App Duration,Poten
 1,"local-1626104300434",3,76,false,131104,"NESTED COMPLEX TYPE",97.56
 1,"local-1626104300434",4,65,false,131104,"NESTED COMPLEX TYPE",100.0
 1,"local-1626104300434",5,479,false,131104,"NESTED COMPLEX TYPE",87.32
-1,"local-1626104300434",6,95,false,131104,"NESTED COMPLEX TYPE",96.3
-1,"local-1626104300434",7,65,false,131104,"NESTED COMPLEX TYPE",95.24
+1,"local-1626104300434",6,95,false,131104,"",96.3
+1,"local-1626104300434",7,65,false,131104,"",95.24
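
Illustration (not part of the patch): a minimal, self-contained Scala sketch of the caching pattern that SqlPlanInfoGraphEntry / SqlPlanInfoGraphBuffer and the connectOperatorToStage callback introduce above — build each plan graph once, cache it, and reuse it on later passes instead of rebuilding it per pass. The types MyPlanInfo, MyPlanGraph and buildGraph are hypothetical stand-ins for SparkPlanInfo, SparkPlanGraph and ToolsPlanGraph.

import scala.collection.mutable.ArrayBuffer

object SqlPlanGraphCacheSketch extends App {
  // Hypothetical stand-ins for SparkPlanInfo, SparkPlanGraph and ToolsPlanGraph(...).
  case class MyPlanInfo(nodeName: String)
  case class MyPlanGraph(info: MyPlanInfo)
  def buildGraph(info: MyPlanInfo): MyPlanGraph = MyPlanGraph(info)

  // Mirrors SqlPlanInfoGraphEntry: one cached (sqlID, planInfo, graph) triple.
  case class Entry(sqlID: Long, planInfo: MyPlanInfo, graph: MyPlanGraph)

  // Mirrors SqlPlanInfoGraphBuffer: builds each graph once and remembers the entry.
  class EntryBuffer {
    val entries = ArrayBuffer[Entry]()
    def add(sqlID: Long, planInfo: MyPlanInfo): Entry = {
      val entry = Entry(sqlID, planInfo, buildGraph(planInfo))
      entries += entry
      entry
    }
  }

  val sqlPlans = Seq(0L -> MyPlanInfo("scan parquet"), 1L -> MyPlanInfo("hash join"))
  val buffer = new EntryBuffer

  // First pass (connectOperatorToStage in the patch) receives a callback that
  // fills the cache as a side effect of creating each graph.
  val createGraphFunc = (sqlID: Long, planInfo: MyPlanInfo) => buffer.add(sqlID, planInfo).graph
  sqlPlans.foreach { case (sqlID, planInfo) => createGraphFunc(sqlID, planInfo) }

  // Later passes (processSQLPlanMetrics in the patch) iterate the cached entries
  // instead of rebuilding each graph from the plan info.
  buffer.entries.foreach { e => println(s"sqlID=${e.sqlID} -> ${e.graph}") }
}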