
Commit 634bd96

Adds aggregation across metrics for failed/succeeded and non-completed stages (#1558)

Fixes #1552.

Currently we store stage information using the `StageModelManager` class, which maps incoming stage information during the following events:

1. doSparkListenerStageCompleted — https://github.com/NVIDIA/spark-rapids-tools/blob/1f037fa867e4df0952e29d82164cc7fc507c9b4e/core/src/main/scala/org/apache/spark/sql/rapids/tool/EventProcessorBase.scala#L475
2. doSparkListenerStageSubmitted — https://github.com/NVIDIA/spark-rapids-tools/blob/1f037fa867e4df0952e29d82164cc7fc507c9b4e/core/src/main/scala/org/apache/spark/sql/rapids/tool/EventProcessorBase.scala#L464

So a stage's information is updated once when the stage is submitted and once when it completes. A StageCompleted event arrives for every attempt of a stage (e.g., a stage that fails on its first attempt and succeeds on its second produces two StageSubmitted and two StageCompleted events). This PR changes that behavior to aggregate all attempts of a stage (failed + succeeded).

### Changes

This pull request includes several changes to improve the handling of stage attempts and task metrics in the Spark RAPIDS tool. The most important changes are adding logic to handle multiple stage attempts, modifying methods to aggregate metrics across those attempts, and updating the `AccumManager` to simplify task accumulation.

Handling multiple stage attempts:

* [`core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSparkMetricsAnalyzer.scala`](diffhunk://#diff-4b0aab10a86746bb7480cc3bde4e013c04707758c61782934c07604443160b40L450-R455): Added logic to handle multiple stage attempts by aggregating the metrics of each attempt.
* [`core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala`](diffhunk://#diff-8d5819c9445c1489d61ee8d03fd2b1ee1e0cb33896f402f4ceb7782c35deed69R688-R746): Introduced the `aggregateStageProfileMetric` method to combine metrics for multiple attempts of the same stage.

Simplifying task accumulation:

* [`core/src/main/scala/org/apache/spark/sql/rapids/tool/EventProcessorBase.scala`](diffhunk://#diff-9b551b7ad326fd9175e0c5b0ba69e947058ee2587922f1fe059e85623604e9c1L372-R372): Modified the `addAccToTask` method to remove the `taskId` parameter and simplify task accumulation.
* [`core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala`](diffhunk://#diff-2cdf5cec29c5cfc15962382b2134c8e88b6623afdfd7cc6a81ec3dfc5663b4a1L87-R89): Updated the `addAccumToTask` method to remove the `taskId` parameter.
* [`core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumManager.scala`](diffhunk://#diff-ff390301f53c6470012e1c36878c1987f176c7eeaa52e30e18f93f76e58587b3L43-R45): Simplified the `addAccToTask` method by removing the `taskId` parameter.

### Testing

This change has been tested against internal event logs, and integration tests have been updated to keep this behavior covered in the future.

---------

Signed-off-by: Sayed Bilal Bari <[email protected]>
1 parent 6cf44b3 commit 634bd96

6 files changed (+87, -11 lines changed)
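Before the per-file diffs, a toy sketch may help pin down the behavior change described in the commit message. The `StageCompleted` case class below is a hypothetical stand-in, not Spark's actual listener event types or the tool's classes:

```scala
// Hypothetical miniature of the retry scenario from the commit message: a stage
// that fails on attempt 0 and succeeds on attempt 1 emits one submitted/completed
// event pair per attempt.
case class StageCompleted(stageId: Int, attempt: Int, failed: Boolean, taskTimeMs: Long)

object RetryScenario extends App {
  val events = Seq(
    StageCompleted(stageId = 4, attempt = 0, failed = true, taskTimeMs = 1200L),
    StageCompleted(stageId = 4, attempt = 1, failed = false, taskTimeMs = 900L)
  )

  // Keyed by stageId alone, the later attempt silently overwrites the earlier one.
  val lastAttemptOnly = events.map(e => e.stageId -> e.taskTimeMs).toMap

  // Aggregating across attempts (failed + succeeded) keeps the full cost of the stage.
  val allAttempts = events.groupBy(_.stageId).map { case (id, es) =>
    id -> es.map(_.taskTimeMs).sum
  }

  println(lastAttemptOnly) // Map(4 -> 900)
  println(allAttempts)     // Map(4 -> 2100)
}
```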

`core/src/main/scala/com/nvidia/spark/rapids/tool/analysis/AppSparkMetricsAnalyzer.scala` (11 additions, 6 deletions)

```diff
@@ -320,9 +320,7 @@ class AppSparkMetricsAnalyzer(app: AppBase) extends AppAnalysisBase(app) {
       AccumProfileResults(0, 0, AccumMetaRef.EMPTY_ACCUM_META_REF, 0L, 0L, 0L, 0L)
     val emptyNodeNames = Seq.empty[String]
     val emptyDiagnosticMetrics = HashMap.empty[String, AccumProfileResults]
-    // TODO: this has stage attempts. we should handle different attempts
     app.stageManager.getAllStages.map { sm =>
-      // TODO: Should we only consider successful tasks?
       val tasksInStage = app.taskManager.getTasks(sm.stageInfo.stageId,
         sm.stageInfo.attemptNumber())
       // count duplicate task attempts
@@ -358,13 +356,12 @@
   }
 
   /**
-   * Aggregates the SparkMetrics by stage. This is an internal method to populate the cached metrics
+   * Aggregates the SparkMetrics by completed stage information.
+   * This is an internal method to populate the cached metrics
    * to be used by other aggregators.
    * @param index AppIndex (used by the profiler tool)
    */
   private def aggregateSparkMetricsByStageInternal(index: Int): Unit = {
-    // TODO: this has stage attempts. we should handle different attempts
-
     // For Photon apps, peak memory and shuffle write time need to be calculated from accumulators
     // instead of task metrics.
     // Approach:
@@ -447,7 +444,15 @@
         perStageRec.swBytesWrittenSum,
         perStageRec.swRecordsWrittenSum,
         perStageRec.swWriteTimeSum) // converted to milliseconds by the aggregator
-      stageLevelSparkMetrics(index).put(sm.stageInfo.stageId, stageRow)
+      // This logic is to handle the case where there are multiple attempts for a stage.
+      // We check if the StageLevelCache already has a row for the stage.
+      // If yes, we aggregate the metrics of the new row with the existing row.
+      // If no, we just store the new row.
+      val rowToStore = stageLevelSparkMetrics(index)
+        .get(sm.stageInfo.stageId)
+        .map(_.aggregateStageProfileMetric(stageRow))
+        .getOrElse(stageRow)
+      stageLevelSparkMetrics(index).put(sm.stageInfo.stageId, rowToStore)
     }
   }
 }
```
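The merge-or-insert step at the end of this diff is the core of the fix. As a rough illustration of the same pattern, here is a self-contained sketch with a hypothetical `MetricsRow` type standing in for the tool's `StageAggTaskMetricsProfileResult`:

```scala
import scala.collection.mutable

// Hypothetical simplified row; the real tool uses StageAggTaskMetricsProfileResult.
case class MetricsRow(stageId: Int, numTasks: Int, durationSum: Long) {
  // Combine two aggregated rows for the same stage (e.g., two attempts).
  def merge(other: MetricsRow): MetricsRow =
    MetricsRow(stageId, numTasks + other.numTasks, durationSum + other.durationSum)
}

object MergeOrInsertExample extends App {
  val cache = mutable.HashMap.empty[Int, MetricsRow]

  def store(row: MetricsRow): Unit = {
    // If the stage already has a row (an earlier attempt), merge; otherwise insert.
    val rowToStore = cache.get(row.stageId).map(_.merge(row)).getOrElse(row)
    cache.put(row.stageId, rowToStore)
  }

  store(MetricsRow(stageId = 4, numTasks = 8, durationSum = 1200L)) // attempt 0 (failed)
  store(MetricsRow(stageId = 4, numTasks = 8, durationSum = 900L))  // attempt 1 (succeeded)
  println(cache(4)) // MetricsRow(4,16,2100)
}
```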

`core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala` (59 additions, 0 deletions)

```diff
@@ -694,6 +694,65 @@ case class StageAggTaskMetricsProfileResult(
   swRecordsWrittenSum: Long,
   swWriteTimeSum: Long // milliseconds
 ) extends BaseJobStageAggTaskMetricsProfileResult {
+
+  /**
+   * Combines two StageAggTaskMetricsProfileResults for the same stage.
+   * This method aggregates the metrics from the current instance and the provided `other` instance.
+   *
+   * Detailed explanation:
+   * 1. A stage can have multiple attempts (failed and successful).
+   * 2. We store the information for each of those attempts using the StageManager.
+   * 3. During aggregation, we combine the metrics for a stage at a stageID
+   *    level.
+   * 4. To combine aggregated information for multiple stage attempts, we merge the
+   *    per-attempt aggregates into one using the method below.
+   *
+   * @param other The StageAggTaskMetricsProfileResult to be combined with the current instance.
+   * @return A new StageAggTaskMetricsProfileResult with aggregated metrics.
+   */
+  def aggregateStageProfileMetric(
+      other: StageAggTaskMetricsProfileResult
+  ): StageAggTaskMetricsProfileResult = {
+    StageAggTaskMetricsProfileResult(
+      appIndex = this.appIndex,
+      id = this.id,
+      numTasks = this.numTasks + other.numTasks,
+      duration = Option(this.duration.getOrElse(0L) + other.duration.getOrElse(0L)),
+      diskBytesSpilledSum = this.diskBytesSpilledSum + other.diskBytesSpilledSum,
+      durationSum = this.durationSum + other.durationSum,
+      durationMax = Math.max(this.durationMax, other.durationMax),
+      durationMin = Math.min(this.durationMin, other.durationMin),
+      durationAvg = (this.durationAvg + other.durationAvg) / 2,
+      executorCPUTimeSum = this.executorCPUTimeSum + other.executorCPUTimeSum,
+      executorDeserializeCpuTimeSum = this.executorDeserializeCpuTimeSum +
+        other.executorDeserializeCpuTimeSum,
+      executorDeserializeTimeSum = this.executorDeserializeTimeSum +
+        other.executorDeserializeTimeSum,
+      executorRunTimeSum = this.executorRunTimeSum + other.executorRunTimeSum,
+      inputBytesReadSum = this.inputBytesReadSum + other.inputBytesReadSum,
+      inputRecordsReadSum = this.inputRecordsReadSum + other.inputRecordsReadSum,
+      jvmGCTimeSum = this.jvmGCTimeSum + other.jvmGCTimeSum,
+      memoryBytesSpilledSum = this.memoryBytesSpilledSum + other.memoryBytesSpilledSum,
+      outputBytesWrittenSum = this.outputBytesWrittenSum + other.outputBytesWrittenSum,
+      outputRecordsWrittenSum = this.outputRecordsWrittenSum + other.outputRecordsWrittenSum,
+      peakExecutionMemoryMax = Math.max(this.peakExecutionMemoryMax, other.peakExecutionMemoryMax),
+      resultSerializationTimeSum = this.resultSerializationTimeSum +
+        other.resultSerializationTimeSum,
+      resultSizeMax = Math.max(this.resultSizeMax, other.resultSizeMax),
+      srFetchWaitTimeSum = this.srFetchWaitTimeSum + other.srFetchWaitTimeSum,
+      srLocalBlocksFetchedSum = this.srLocalBlocksFetchedSum + other.srLocalBlocksFetchedSum,
+      srRemoteBlocksFetchSum = this.srRemoteBlocksFetchSum + other.srRemoteBlocksFetchSum,
+      srRemoteBytesReadSum = this.srRemoteBytesReadSum + other.srRemoteBytesReadSum,
+      srRemoteBytesReadToDiskSum = this.srRemoteBytesReadToDiskSum +
+        other.srRemoteBytesReadToDiskSum,
+      srTotalBytesReadSum = this.srTotalBytesReadSum + other.srTotalBytesReadSum,
+      srcLocalBytesReadSum = this.srcLocalBytesReadSum + other.srcLocalBytesReadSum,
+      swBytesWrittenSum = this.swBytesWrittenSum + other.swBytesWrittenSum,
+      swRecordsWrittenSum = this.swRecordsWrittenSum + other.swRecordsWrittenSum,
+      swWriteTimeSum = this.swWriteTimeSum + other.swWriteTimeSum
+    )
+  }
+
   override def idHeader = "stageId"
 }
```
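As a hedged sketch of how the combine rules above behave on two attempts of one stage: totals are summed, extrema keep the max/min, and the average is the unweighted mean of the two per-attempt averages. `Attempt` is a hypothetical reduced type, not the full result class:

```scala
// Hypothetical reduced view of the aggregation rules in aggregateStageProfileMetric.
case class Attempt(numTasks: Int, durationSum: Long, durationMax: Long,
    durationMin: Long, durationAvg: Double) {
  def combine(other: Attempt): Attempt = Attempt(
    numTasks = numTasks + other.numTasks,                   // totals are summed
    durationSum = durationSum + other.durationSum,
    durationMax = math.max(durationMax, other.durationMax), // extrema keep max/min
    durationMin = math.min(durationMin, other.durationMin),
    durationAvg = (durationAvg + other.durationAvg) / 2     // unweighted mean of averages
  )
}

object CombineRules extends App {
  val failedAttempt = Attempt(numTasks = 4, durationSum = 400L,
    durationMax = 150L, durationMin = 50L, durationAvg = 100.0)
  val retriedAttempt = Attempt(numTasks = 4, durationSum = 800L,
    durationMax = 300L, durationMin = 100L, durationAvg = 200.0)
  println(failedAttempt.combine(retriedAttempt))
  // Attempt(8,1200,300,50,150.0)
}
```

Note that the `(a + b) / 2` average mirrors the committed code; in this example both attempts run the same number of tasks, so it coincides with the task-weighted mean.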

`core/src/main/scala/org/apache/spark/sql/rapids/tool/EventProcessorBase.scala` (1 addition, 1 deletion)

```diff
@@ -369,7 +369,7 @@ abstract class EventProcessorBase[T <: AppBase](app: T) extends SparkListener wi
     // Parse task accumulables
     for (res <- event.taskInfo.accumulables) {
       try {
-        app.accumManager.addAccToTask(event.stageId, event.taskInfo.taskId, res)
+        app.accumManager.addAccToTask(event.stageId, res)
       } catch {
         case NonFatal(e) =>
           logWarning("Exception when parsing accumulables on task-completed "
```

`core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumInfo.scala` (1 addition, 2 deletions)

```diff
@@ -84,10 +84,9 @@ class AccumInfo(val infoRef: AccumMetaRef) {
    * attempt information with give no Stats at accumulable level
    *
    * @param stageId The ID of the stage containing the task
-   * @param taskId The ID of the completed task
    * @param accumulableInfo Accumulator information from the TaskEnd event
    */
-  def addAccumToTask(stageId: Int, taskId: Long, accumulableInfo: AccumulableInfo): Unit = {
+  def addAccumToTask(stageId: Int, accumulableInfo: AccumulableInfo): Unit = {
     // 1. We first extract the incoming task update value
     // 2. Then allocate a new Statistic metric object with min,max as incoming update
     // 3. Use count to calculate rolling average
```

`core/src/main/scala/org/apache/spark/sql/rapids/tool/store/AccumManager.scala` (2 additions, 2 deletions)

```diff
@@ -40,9 +40,9 @@ class AccumManager {
     accumInfoRef.addAccumToStage(stageId, accumulableInfo)
   }
 
-  def addAccToTask(stageId: Int, taskId: Long, accumulableInfo: AccumulableInfo): Unit = {
+  def addAccToTask(stageId: Int, accumulableInfo: AccumulableInfo): Unit = {
     val accumInfoRef = getOrCreateAccumInfo(accumulableInfo.id, accumulableInfo.name)
-    accumInfoRef.addAccumToTask(stageId, taskId, accumulableInfo)
+    accumInfoRef.addAccumToTask(stageId, accumulableInfo)
   }
 
   def getAccStageIds(id: Long): Set[Int] = {
```
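To make the simplified call chain concrete, here is a minimal sketch of the manager-delegates-to-per-accumulator-store pattern, using hypothetical simplified types in place of `AccumulableInfo` and `AccumInfo`:

```scala
import scala.collection.mutable

// Hypothetical simplified stand-in for Spark's AccumulableInfo.
case class AccInfo(id: Long, name: Option[String], update: Option[Long])

// Hypothetical per-accumulator store; after this change, task updates are
// keyed by stage only, so no taskId needs to be threaded through.
class AccStore(val name: Option[String]) {
  private val stageTotals = mutable.HashMap.empty[Int, Long]
  def addToTask(stageId: Int, info: AccInfo): Unit =
    stageTotals(stageId) = stageTotals.getOrElse(stageId, 0L) + info.update.getOrElse(0L)
}

class Manager {
  private val stores = mutable.HashMap.empty[Long, AccStore]
  private def getOrCreate(id: Long, name: Option[String]): AccStore =
    stores.getOrElseUpdate(id, new AccStore(name))
  // Mirrors the two-argument addAccToTask signature introduced above.
  def addAccToTask(stageId: Int, info: AccInfo): Unit =
    getOrCreate(info.id, info.name).addToTask(stageId, info)
}
```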

`core/src/main/scala/org/apache/spark/sql/rapids/tool/store/StageModel.scala` (13 additions, 0 deletions)

```diff
@@ -69,10 +69,23 @@ class StageModel private(sInfo: StageInfo) {
     ProfileUtils.optionLongMinusOptionLong(stageInfo.completionTime, stageInfo.submissionTime)
   }
 
+  /**
+   * Returns true if this stage attempt has failed.
+   * There can be multiple attempts (retries) of a stage,
+   * and earlier attempts may fail before a later attempt succeeds.
+   *
+   * @return true if this stage attempt has failed.
+   */
   def hasFailed: Boolean = {
     stageInfo.failureReason.isDefined
   }
 
+  /**
+   * Returns the failure reason if the stage attempt has failed.
+   * A set failure reason is the definitive indicator of a failed stage attempt.
+   *
+   * @return the failure reason if the stage attempt has failed, or an empty string otherwise
+   */
   def getFailureReason: String = {
     stageInfo.failureReason.getOrElse("")
   }
```
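As a usage illustration of the two accessors documented above, here is a hypothetical sketch with a stand-in for the stage-attempt data (not the tool's actual StageModel):

```scala
// Hypothetical stand-in mirroring the documented accessors.
case class StageAttempt(stageId: Int, attemptNumber: Int, failureReason: Option[String]) {
  def hasFailed: Boolean = failureReason.isDefined
  def getFailureReason: String = failureReason.getOrElse("")
}

object FailedAttempts extends App {
  val attempts = Seq(
    StageAttempt(4, 0, Some("FetchFailed: executor lost")), // failed first attempt
    StageAttempt(4, 1, None)                                // successful retry
  )
  // Report every failed attempt; a later attempt of the same stage may still succeed.
  attempts.filter(_.hasFailed).foreach { a =>
    println(s"stage ${a.stageId} attempt ${a.attemptNumber} failed: ${a.getFailureReason}")
  }
}
```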
