Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,13 @@ public abstract class AbstractDatadogSparkListener extends SparkListener {
public static volatile boolean finishTraceOnApplicationEnd = true;
public static volatile boolean isPysparkShell = false;

private final int MAX_COLLECTION_SIZE = 5000;
private static final int MAX_COLLECTION_SIZE = 5000;

/**
 * Returns the size cap that this listener compares its tracking collections against before
 * adding new entries to them.
 *
 * <p>Defaults to {@code MAX_COLLECTION_SIZE}. Exposed as an overridable method so tests can
 * exercise collection-cap behaviour without filling 5000 entries.
 *
 * @return the maximum collection size used for the cap checks
 */
protected int maxCollectionSize() {
return MAX_COLLECTION_SIZE;
}

private final int MAX_ACCUMULATOR_SIZE = 50000;
private final String RUNTIME_TAGS_PREFIX = "spark.datadog.tags.";
private static final String AGENT_OL_ENDPOINT = "openlineage/api/v1/lineage";
Expand Down Expand Up @@ -115,6 +121,7 @@ public abstract class AbstractDatadogSparkListener extends SparkListener {

private final HashMap<Integer, Integer> stageToJob = new HashMap<>();
private final HashMap<Long, Properties> stageProperties = new HashMap<>();
private final HashMap<Integer, String> jobToSessionId = new HashMap<>();

private final SparkAggregatedTaskMetrics applicationMetrics = new SparkAggregatedTaskMetrics();
private final HashMap<String, SparkAggregatedTaskMetrics> streamingBatchMetrics = new HashMap<>();
Expand All @@ -127,6 +134,12 @@ public abstract class AbstractDatadogSparkListener extends SparkListener {
private final HashMap<Long, SparkListenerSQLExecutionStart> sqlQueries = new HashMap<>();
protected final HashMap<Long, SparkPlanInfo> sqlPlans = new HashMap<>();
private final HashMap<String, SparkListenerExecutorAdded> liveExecutors = new HashMap<>();
private final HashMap<String, AgentSpan> perSessionApplicationSpans = new HashMap<>();
private final HashMap<String, SparkAggregatedTaskMetrics> perSessionApplicationMetrics =
new HashMap<>();
private final HashMap<String, Boolean> perSessionLastJobFailed = new HashMap<>();
private final HashMap<String, String> perSessionLastJobFailedMessage = new HashMap<>();
private final HashMap<String, String> perSessionLastJobFailedStackTrace = new HashMap<>();

private final Map<Long, Integer> accumulatorToStageID = new HashMap<>();

Expand Down Expand Up @@ -361,6 +374,39 @@ public synchronized void finishApplication(
}
applicationEnded = true;

// TODO: per-session app spans are closed here (server shutdown) rather than when the session
// actually ends — so their duration is "first job → server shutdown". Spark Connect does emit
// a server-side session-close event, but it is not surfaced through SparkListener today.
// When that hook becomes available, finish the span there and remove it from the map so that
// long-lived servers don't accumulate unbounded open spans.

// Finish per-session application spans before the guard below, because a pure Connect server
// has applicationSpan == null with jobCount > 0, which would cause the guard to return early
// and skip finishing the per-session spans entirely.
for (Map.Entry<String, AgentSpan> entry : perSessionApplicationSpans.entrySet()) {
String sessionId = entry.getKey();
AgentSpan sessionAppSpan = entry.getValue();

if (Boolean.TRUE.equals(perSessionLastJobFailed.get(sessionId))) {
sessionAppSpan.setError(true);
sessionAppSpan.setTag(DDTags.ERROR_TYPE, "Spark Application Failed");
sessionAppSpan.setTag(DDTags.ERROR_MSG, perSessionLastJobFailedMessage.get(sessionId));
sessionAppSpan.setTag(DDTags.ERROR_STACK, perSessionLastJobFailedStackTrace.get(sessionId));
}

SparkAggregatedTaskMetrics sessionMetrics = perSessionApplicationMetrics.get(sessionId);
if (sessionMetrics != null) {
sessionMetrics.setSpanMetrics(sessionAppSpan);
}

sessionAppSpan.finish(time * 1000);
}
perSessionApplicationSpans.clear();
perSessionApplicationMetrics.clear();
perSessionLastJobFailed.clear();
perSessionLastJobFailedMessage.clear();
perSessionLastJobFailedStackTrace.clear();
Comment on lines +402 to +408
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Flush finished session spans before the early return

For a pure Spark Connect server, the comment above notes that applicationSpan == null with jobCount > 0, so after these per-session application spans are finished the existing guard at line 410 returns before reaching the tracer.flush() used at application shutdown. Since these session root spans remain open until finishApplication, they can be left only in the writer queue when the JVM/SparkContext is stopping and may be dropped instead of being synchronously written; flush after finishing the per-session spans when this early-return path is taken.

Useful? React with 👍 / 👎.


if ((applicationSpan == null && jobCount > 0) || isRunningOnDatabricks) {
// If the application span is not initialized, but spark jobs have been executed, all those
// spark jobs were databricks or streaming. In this case we don't send the application span
Expand Down Expand Up @@ -466,6 +512,8 @@ private AgentSpan getOrCreateSqlSpan(
return null;
}

String connectSessionId = getSparkConnectSessionId(jobProperties);

AgentTracer.SpanBuilder spanBuilder =
buildSparkSpan("spark.sql", jobProperties)
.withStartTimestamp(queryStart.time() * 1000)
Expand All @@ -479,6 +527,10 @@ private AgentSpan getOrCreateSqlSpan(
AgentSpan batchSpan =
getOrCreateStreamingBatchSpan(batchKey, queryStart.time(), jobProperties);
spanBuilder.asChildOf(batchSpan.context());
} else if (connectSessionId != null) {
AgentSpan sessionAppSpan =
getOrCreatePerSessionApplicationSpan(connectSessionId, queryStart.time(), jobProperties);
spanBuilder.asChildOf(sessionAppSpan.context());
} else if (isRunningOnDatabricks) {
addDatabricksSpecificTags(spanBuilder, jobProperties, true);
Comment on lines +530 to 535
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve Databricks parenting for Connect SQL jobs

When a SQL job on a Databricks cluster also carries Spark Connect spark.jobTags, this new branch wins before isRunningOnDatabricks, so addDatabricksSpecificTags(...) is skipped for the SQL span; because onJobStart then parents the job under that SQL span via the sqlSpan != null branch, neither span gets the Databricks job/run/task tags or parent context that this listener previously added for Databricks jobs. This only affects Spark Connect SQL workloads running in a Databricks environment; keep the Databricks path ahead of the Connect-session path or apply the Databricks tags/parenting as well.

Useful? React with 👍 / 👎.

} else {
Expand All @@ -492,10 +544,99 @@ private AgentSpan getOrCreateSqlSpan(
return sqlSpan;
}

/**
 * Looks up the {@code spark.application} span tracked for a Spark Connect session, creating and
 * registering it (together with a fresh metrics aggregate) on first use.
 *
 * <p>When the number of tracked session spans has already reached {@link #maxCollectionSize()},
 * no new span is created; the global application span is initialized if needed and returned
 * instead, so callers always have a parent to attach children to.
 *
 * @param sessionId Spark Connect session id used as the map key and as the {@code session_id} tag
 * @param timeMs timestamp in milliseconds of the first child; the span starts just before it
 * @param jobProperties job properties forwarded to {@code buildSparkSpan}
 * @return the span registered for the session, or the global application span once the cap is hit
 */
private AgentSpan getOrCreatePerSessionApplicationSpan(
    String sessionId, long timeMs, Properties jobProperties) {
  final AgentSpan existing = perSessionApplicationSpans.get(sessionId);
  if (existing != null) {
    return existing;
  }

  final boolean capReached = perSessionApplicationSpans.size() >= maxCollectionSize();
  if (capReached) {
    // No room for another session span: parent this session's children under the global
    // application span so that they are never left without a parent.
    initApplicationSpanIfNotInitialized();
    return applicationSpan;
  }

  // Start 1µs ahead of the first child so this span sorts strictly before its children.
  final long startTimestamp = timeMs * 1000 - 1;
  final AgentTracer.SpanBuilder sessionSpanBuilder =
      buildSparkSpan("spark.application", jobProperties)
          .withStartTimestamp(startTimestamp)
          .withTag("session_id", sessionId)
          .withTag("spark.connect.server", true);

  if (applicationStart != null) {
    // Global tags rendered as "key:value" pairs, ordered by key and comma-separated.
    final String globalTags =
        Config.get().getGlobalTags().entrySet().stream()
            .sorted(Map.Entry.comparingByKey())
            .map(tag -> tag.getKey() + ":" + tag.getValue())
            .collect(Collectors.joining(","));

    sessionSpanBuilder.withTag("application_name", applicationStart.appName());
    sessionSpanBuilder.withTag("djm.tags", globalTags);
    sessionSpanBuilder.withTag("spark_user", applicationStart.sparkUser());

    applicationStart
        .appAttemptId()
        .foreach(attemptId -> sessionSpanBuilder.withTag("app_attempt_id", attemptId));
  }

  captureApplicationParameters(sessionSpanBuilder);
  captureEmrStepId(sessionSpanBuilder);
  captureOpenlineageJobInfo(sessionSpanBuilder);

  // Deliberately skipped: captureOpenlineageContextIfPresent and predeterminedTraceIdContext.
  // Per-session spans must be independent trace roots.

  final AgentSpan created = sessionSpanBuilder.start();
  created.setMeasured(true);
  setDataJobsSamplingPriority(created);

  perSessionApplicationSpans.put(sessionId, created);
  perSessionApplicationMetrics.put(sessionId, new SparkAggregatedTaskMetrics());
  return created;
}

// Spark Connect adds
// "SparkConnect_OperationTag_User_{userId}_Session_{sessionId}_Operation_{opId}"
// to every job's spark.jobTags via SparkContext.addJobTag in ExecuteThreadRunner.scala.
private static final String CONNECT_OP_TAG_PREFIX = "SparkConnect_OperationTag_";
private static final String SESSION_MARKER = "_Session_";
private static final String OPERATION_MARKER = "_Operation_";

/**
 * Extracts the Spark Connect session id from the comma-separated {@code spark.jobTags} property.
 *
 * <p>Only tags starting with {@code SparkConnect_OperationTag_} are considered. Within such a
 * tag the session id is the text between the last {@code _Session_} marker and the last
 * {@code _Operation_} marker. Matching from the end of the tag matters because the user id
 * precedes the session id: a user id that itself contains {@code _Session_} would otherwise be
 * mistaken for the start of the session id. (Session and operation ids are presumably UUIDs and
 * therefore never contain the markers — confirm against the Spark Connect server.)
 *
 * @param properties job or stage properties; may be {@code null}
 * @return the first non-empty session id found, or {@code null} when the properties are
 *     {@code null}, carry no {@code spark.jobTags}, or contain no well-formed Connect tag
 */
private static String getSparkConnectSessionId(Properties properties) {
  if (properties == null) {
    return null;
  }
  String jobTags = properties.getProperty("spark.jobTags");
  if (jobTags == null) {
    return null;
  }
  for (String rawTag : jobTags.split(",")) {
    String tag = rawTag.trim();
    if (!tag.startsWith(CONNECT_OP_TAG_PREFIX)) {
      continue;
    }
    // Anchor on the trailing "_Operation_{opId}" segment first.
    int operationIdx = tag.lastIndexOf(OPERATION_MARKER);
    if (operationIdx < 0) {
      continue;
    }
    // Then find the closest "_Session_" that ends at or before the operation marker; a negative
    // fromIndex makes lastIndexOf return -1, which is handled below.
    int sessionIdx = tag.lastIndexOf(SESSION_MARKER, operationIdx - SESSION_MARKER.length());
    if (sessionIdx < 0) {
      continue;
    }
    int sessionStart = sessionIdx + SESSION_MARKER.length();
    if (sessionStart >= operationIdx) {
      // Empty session id segment: not a usable tag.
      continue;
    }
    return tag.substring(sessionStart, operationIdx);
  }
  return null;
}

@Override
public synchronized void onJobStart(SparkListenerJobStart jobStart) {
jobCount++;
if (jobSpans.size() > MAX_COLLECTION_SIZE) {
if (jobSpans.size() > maxCollectionSize()) {
return;
}

Expand All @@ -507,6 +648,7 @@ public synchronized void onJobStart(SparkListenerJobStart jobStart) {

String batchKey = getStreamingBatchKey(jobStart.properties());
Long sqlExecutionId = getSqlExecutionId(jobStart.properties());
String connectSessionId = getSparkConnectSessionId(jobStart.properties());
AgentSpan sqlSpan = null;

if (sqlExecutionId != null) {
Expand All @@ -531,6 +673,11 @@ public synchronized void onJobStart(SparkListenerJobStart jobStart) {
jobSpanBuilder.asChildOf(batchSpan.context());
} else if (isRunningOnDatabricks) {
addDatabricksSpecificTags(jobSpanBuilder, jobStart.properties(), true);
} else if (connectSessionId != null) {
AgentSpan sessionAppSpan =
getOrCreatePerSessionApplicationSpan(
connectSessionId, jobStart.time(), jobStart.properties());
jobSpanBuilder.asChildOf(sessionAppSpan.context());
} else {
// In non-databricks, non-streaming env, the spark application is the local root span
initApplicationSpanIfNotInitialized();
Expand All @@ -546,6 +693,12 @@ public synchronized void onJobStart(SparkListenerJobStart jobStart) {
for (int stageId : getSparkJobStageIds(jobStart)) {
stageToJob.put(stageId, jobStart.jobId());
}
// If the cap is reached the put is dropped; onJobEnd then recovers connectSessionId as null,
// so a failure on that job is attributed to the global lastJobFailed instead of the session.
// This requires >maxCollectionSize() in-flight Connect jobs concurrently, which is unlikely.
if (connectSessionId != null && jobToSessionId.size() < maxCollectionSize()) {
jobToSessionId.put(jobStart.jobId(), connectSessionId);
}
jobSpans.put(jobStart.jobId(), jobSpan);
notifyOl(x -> openLineageSparkListener.onJobStart(x), jobStart);
}
Expand All @@ -557,6 +710,8 @@ public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {
return;
}

String connectSessionId = jobToSessionId.remove(jobEnd.jobId());

if (jobEnd.jobResult() instanceof JobFailed) {
JobFailed jobFailed = (JobFailed) jobEnd.jobResult();
Exception exception = jobFailed.exception();
Expand All @@ -571,13 +726,23 @@ public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {

// Only propagate the error to the application if it is not a cancellation
if (errorMessage != null && !errorMessage.toLowerCase().contains("cancelled")) {
lastJobFailed = true;
lastJobFailedMessage = errorMessage;
lastJobFailedStackTrace = errorStackTrace;
if (connectSessionId != null && perSessionApplicationSpans.containsKey(connectSessionId)) {
perSessionLastJobFailed.put(connectSessionId, true);
perSessionLastJobFailedMessage.put(connectSessionId, errorMessage);
perSessionLastJobFailedStackTrace.put(connectSessionId, errorStackTrace);
} else {
lastJobFailed = true;
lastJobFailedMessage = errorMessage;
lastJobFailedStackTrace = errorStackTrace;
}
}
} else {
lastJobFailed = false;
lastSqlFailed = false;
if (connectSessionId != null && perSessionApplicationSpans.containsKey(connectSessionId)) {
perSessionLastJobFailed.put(connectSessionId, false);
} else {
lastJobFailed = false;
lastSqlFailed = false;
}
}

SparkAggregatedTaskMetrics metrics = jobMetrics.remove(jobEnd.jobId());
Expand All @@ -591,7 +756,7 @@ public synchronized void onJobEnd(SparkListenerJobEnd jobEnd) {

@Override
public synchronized void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {
if (stageSpans.size() > MAX_COLLECTION_SIZE) {
if (stageSpans.size() > maxCollectionSize()) {
return;
}

Expand Down Expand Up @@ -683,12 +848,21 @@ public synchronized void onStageCompleted(SparkListenerStageCompleted stageCompl

Properties prop = stageProperties.remove(stageSpanKey);
Long sqlExecutionId = getSqlExecutionId(prop);
String connectSessionId = getSparkConnectSessionId(prop);

SparkAggregatedTaskMetrics stageMetric = stageMetrics.remove(stageSpanKey);
if (stageMetric != null) {
stageMetric.computeSkew();
stageMetric.setSpanMetrics(span);
applicationMetrics.accumulateStageMetrics(stageMetric);
if (connectSessionId != null) {
SparkAggregatedTaskMetrics sessionMetrics =
perSessionApplicationMetrics.get(connectSessionId);
if (sessionMetrics != null) {
sessionMetrics.accumulateStageMetrics(stageMetric);
}
} else {
applicationMetrics.accumulateStageMetrics(stageMetric);
}

jobMetrics
.computeIfAbsent(jobId, k -> new SparkAggregatedTaskMetrics())
Expand Down Expand Up @@ -820,7 +994,7 @@ public synchronized void onExecutorAdded(SparkListenerExecutorAdded executorAdde
currentExecutorCount += 1;
maxExecutorCount = Math.max(maxExecutorCount, currentExecutorCount);

if (liveExecutors.size() <= MAX_COLLECTION_SIZE) {
if (liveExecutors.size() <= maxCollectionSize()) {
liveExecutors.put(executorAdded.executorId(), executorAdded);
}
}
Expand Down Expand Up @@ -941,7 +1115,7 @@ private synchronized void onSQLExecutionEnd(SparkListenerSQLExecutionEnd sqlEnd)

private synchronized void onStreamingQueryStartedEvent(
StreamingQueryListener.QueryStartedEvent event) {
if (streamingQueries.size() > MAX_COLLECTION_SIZE) {
if (streamingQueries.size() > maxCollectionSize()) {
return;
}

Expand Down
Loading
Loading