From 84dff10a3ee7f47e30a40409e56b5e9365c69815 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Wed, 25 Mar 2026 08:02:43 -0700 Subject: [PATCH] fix: Fixing tracing for function calls Fixing when the execute tool happens in the graph. Refactor and simplify the tracing logic for function and tool calls within the ADK. The primary goal is to consolidate multiple tracing events into more cohesive operations, specifically merging "tool_call" and "tool_response" into a single "execute_tool" operation. ### Key Changes: * **Consolidated Tool Tracing:** Replaced the separate `traceToolCall` and `traceToolResponse` methods with a unified `traceToolExecution` in `Tracing.java`. This reduces span noise by representing a tool's lifecycle as a single "execute_tool" operation containing both arguments and results (or errors). * **Standardized Operation Names:** Introduced constants for core Gen AI operations: `invoke_agent`, `execute_tool`, `send_data`, and `call_llm`. * **Improved Error Tracing:** `traceToolExecution` and `traceCallLlm` now explicitly accept an optional `Exception`, allowing them to automatically set the span status to error and record the exception. * **Refactored Tracing API:** * `traceSendData` and other methods now require an explicit `Span` argument, moving away from implicit context lookups where appropriate. * Added `traceMergedToolCalls` to specifically handle the telemetry for parallel tool executions. * **Flow Logic Cleanup:** Simplified `Functions.java` and `BaseLlmFlow.java` by removing redundant context passing and aligning with the new consolidated tracing methods. * **Test Suite Updates:** Significantly updated `ContextPropagationTest.java` to reflect the new tracing model. Several manual hierarchy tests were removed in favor of testing the consolidated `execute_tool` logic and updated attributes. PiperOrigin-RevId: 889246953 --- .../adk/flows/llmflows/BaseLlmFlow.java | 17 +- .../google/adk/flows/llmflows/Functions.java | 42 +-- .../com/google/adk/telemetry/Tracing.java | 184 ++++++----- .../com/google/adk/agents/LlmAgentTest.java | 10 +- .../com/google/adk/runner/RunnerTest.java | 30 +- .../adk/telemetry/ContextPropagationTest.java | 303 +++--------------- 6 files changed, 208 insertions(+), 378 deletions(-) diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index d4fe1b838..aa62b9f31 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -233,7 +233,7 @@ private Flowable callLlm( callLlmContext) .doOnSubscribe( s -> - Tracing.traceCallLlm( + traceCallLlm( span, context, eventForCallbackUsage.id(), @@ -520,6 +520,7 @@ public Flowable runLive(InvocationContext invocationContext) { .doOnComplete( () -> Tracing.traceSendData( + Span.current(), invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents())) @@ -529,6 +530,7 @@ public Flowable runLive(InvocationContext invocationContext) { span.setStatus(StatusCode.ERROR, error.getMessage()); span.recordException(error); Tracing.traceSendData( + Span.current(), invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents()); @@ -706,6 +708,19 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } + /** + * Traces an LLM call without an associated exception. This is an overload for {@link + * Tracing#traceCallLlm} for successful calls. + */ + private void traceCallLlm( + Span span, + InvocationContext context, + String eventId, + LlmRequest llmRequest, + LlmResponse llmResponse) { + Tracing.traceCallLlm(span, context, eventId, llmRequest, llmResponse, null); + } + private Event buildModelResponseEvent( Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) { Event.Builder eventBuilder = diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 84a8141ea..0b0e5b4d5 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -178,8 +178,12 @@ public static Maybe handleFunctionCalls( if (events.size() > 1) { return Maybe.just(mergedEvent) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)) - .compose(Tracing.trace("tool_response").setParent(parentContext)); + .compose( + Tracing.trace("execute_tool (merged)") + .setParent(parentContext) + .onSuccess( + (span, event) -> + Tracing.traceMergedToolCalls(span, event.id(), event))); } return Maybe.just(mergedEvent); }); @@ -269,10 +273,8 @@ private static Function> getFunctionCallMapper( tool, toolContext, functionCall, - functionArgs, - parentContext) - : callTool( - tool, functionArgs, toolContext, parentContext)) + functionArgs) + : callTool(tool, functionArgs, toolContext)) .compose(Tracing.withContext(parentContext))); return postProcessFunctionResult( @@ -296,8 +298,7 @@ private static Maybe> processFunctionLive( BaseTool tool, ToolContext toolContext, FunctionCall functionCall, - Map args, - Context parentContext) { + Map args) { // Case 1: Handle a call to stopStreaming if (functionCall.name().get().equals("stopStreaming") && args.containsKey("functionName")) { String functionNameToStop = (String) args.get("functionName"); @@ -365,7 +366,7 @@ private static Maybe> processFunctionLive( } // Case 3: Fallback for regular, non-streaming tools - return callTool(tool, args, toolContext, parentContext); + return callTool(tool, args, toolContext); } public static Set getLongRunningFunctionCalls( @@ -426,12 +427,22 @@ private static Maybe postProcessFunctionResult( Event event = buildResponseEvent( tool, finalFunctionResult, toolContext, invocationContext); - Tracing.traceToolResponse(event.id(), event); return Maybe.just(event); }); }) .compose( - Tracing.trace("tool_response [" + tool.name() + "]").setParent(parentContext)); + Tracing.trace("execute_tool [" + tool.name() + "]") + .setParent(parentContext) + .onSuccess( + (span, event) -> + Tracing.traceToolExecution( + span, + tool.name(), + tool.description(), + tool.getClass().getSimpleName(), + functionArgs, + event, + null))); } private static Optional mergeParallelFunctionResponseEvents( @@ -579,17 +590,10 @@ private static Maybe> maybeInvokeAfterToolCall( } private static Maybe> callTool( - BaseTool tool, Map args, ToolContext toolContext, Context parentContext) { + BaseTool tool, Map args, ToolContext toolContext) { return tool.runAsync(args, toolContext) .toMaybe() - .doOnSubscribe( - d -> - Tracing.traceToolCall( - tool.name(), tool.description(), tool.getClass().getSimpleName(), args)) .doOnError(t -> Span.current().recordException(t)) - .compose( - Tracing.>trace("tool_call [" + tool.name() + "]") - .setParent(parentContext)) .onErrorResumeNext( e -> Maybe.error( diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 215e317e1..589215073 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -33,6 +33,7 @@ import io.opentelemetry.api.GlobalOpenTelemetry; import io.opentelemetry.api.common.AttributeKey; import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.StatusCode; import io.opentelemetry.api.trace.Tracer; import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; @@ -61,6 +62,7 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; +import org.jspecify.annotations.Nullable; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -77,6 +79,11 @@ public class Tracing { private static final Logger log = LoggerFactory.getLogger(Tracing.class); + private static final String INVOKE_AGENT_OPERATION = "invoke_agent"; + private static final String EXECUTE_TOOL_OPERATION = "execute_tool"; + private static final String SEND_DATA_OPERATION = "send_data"; + private static final String CALL_LLM_OPERATION = "call_llm"; + private static final AttributeKey> GEN_AI_RESPONSE_FINISH_REASONS = AttributeKey.stringArrayKey("gen_ai.response.finish_reasons"); @@ -134,15 +141,6 @@ public class Tracing { private Tracing() {} - private static void traceWithSpan(String methodName, Consumer traceAction) { - Span span = Span.current(); - if (!span.getSpanContext().isValid()) { - log.trace("{}: No valid span in current context.", methodName); - return; - } - traceAction.accept(span); - } - private static void setInvocationAttributes( Span span, InvocationContext invocationContext, String eventId) { span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); @@ -159,12 +157,6 @@ private static void setInvocationAttributes( } } - private static void setToolExecutionAttributes(Span span) { - span.setAttribute(GEN_AI_OPERATION_NAME, "execute_tool"); - span.setAttribute(ADK_LLM_REQUEST, "{}"); - span.setAttribute(ADK_LLM_RESPONSE, "{}"); - } - private static void setJsonAttribute(Span span, AttributeKey key, Object value) { if (!CAPTURE_MESSAGE_CONTENT_IN_SPANS) { span.setAttribute(key, "{}"); @@ -198,7 +190,7 @@ public static void setTracerForTesting(Tracer tracer) { */ public static void traceAgentInvocation( Span span, String agentName, String agentDescription, InvocationContext invocationContext) { - span.setAttribute(GEN_AI_OPERATION_NAME, "invoke_agent"); + span.setAttribute(GEN_AI_OPERATION_NAME, INVOKE_AGENT_OPERATION); span.setAttribute(GEN_AI_AGENT_DESCRIPTION, agentDescription); span.setAttribute(GEN_AI_AGENT_NAME, agentName); if (invocationContext.session() != null && invocationContext.session().id() != null) { @@ -207,58 +199,62 @@ public static void traceAgentInvocation( } /** - * Traces tool call arguments. - * - * @param args The arguments to the tool call. - */ - public static void traceToolCall( - String toolName, String toolDescription, String toolType, Map args) { - traceWithSpan( - "traceToolCall", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - - setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); - }); - } - - /** - * Traces tool response event. + * Traces a tool execution, including its arguments, response, and any potential error. * - * @param eventId The ID of the event. - * @param functionResponseEvent The function response event. + * @param span The span representing the tool execution. + * @param toolName The name of the tool. + * @param toolDescription The tool's description. + * @param toolType The tool's type (e.g., "FunctionTool"). + * @param args The arguments passed to the tool. + * @param functionResponseEvent The event containing the tool's response, if successful. + * @param error The exception thrown during execution, if any. */ - public static void traceToolResponse(String eventId, Event functionResponseEvent) { - traceWithSpan( - "traceToolResponse", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(ADK_EVENT_ID, eventId); - - FunctionResponse functionResponse = - functionResponseEvent.functionResponses().stream().findFirst().orElse(null); - - String toolCallId = ""; - Object toolResponse = ""; - if (functionResponse != null) { - toolCallId = functionResponse.id().orElse(toolCallId); - if (functionResponse.response().isPresent()) { - toolResponse = functionResponse.response().get(); - } - } - - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + public static void traceToolExecution( + Span span, + String toolName, + String toolDescription, + String toolType, + Map args, + @Nullable Event functionResponseEvent, + @Nullable Exception error) { + span.setAttribute(GEN_AI_OPERATION_NAME, EXECUTE_TOOL_OPERATION); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + + if (functionResponseEvent != null) { + span.setAttribute(ADK_EVENT_ID, functionResponseEvent.id()); + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); + + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + Object finalToolResponse = + (toolResponse instanceof Map) ? toolResponse : ImmutableMap.of("result", toolResponse); + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + } else { + // Set placeholder if no response event is available (e.g., due to an error) + span.setAttribute(GEN_AI_TOOL_CALL_ID, ""); + setJsonAttribute(span, ADK_TOOL_RESPONSE, "{}"); + } - Object finalToolResponse = - (toolResponse instanceof Map) - ? toolResponse - : ImmutableMap.of("result", toolResponse); + // Also set empty LLM attributes for UI compatibility, like in traceToolResponse + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); - setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); - }); + if (error != null) { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + } } /** @@ -303,8 +299,10 @@ public static void traceCallLlm( InvocationContext invocationContext, String eventId, LlmRequest llmRequest, - LlmResponse llmResponse) { + LlmResponse llmResponse, + @Nullable Exception error) { span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + span.setAttribute(GEN_AI_OPERATION_NAME, CALL_LLM_OPERATION); llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); setInvocationAttributes(span, invocationContext, eventId); @@ -312,6 +310,11 @@ public static void traceCallLlm( setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + if (error != null) { + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + } + llmRequest .config() .ifPresent( @@ -352,18 +355,45 @@ public static void traceCallLlm( * @param data A list of content objects being sent. */ public static void traceSendData( - InvocationContext invocationContext, String eventId, List data) { - traceWithSpan( - "traceSendData", - span -> { - setInvocationAttributes(span, invocationContext, eventId); - - ImmutableList safeData = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(Objects::nonNull) - .collect(toImmutableList()); - setJsonAttribute(span, ADK_DATA, safeData); - }); + Span span, InvocationContext invocationContext, String eventId, List data) { + if (!span.getSpanContext().isValid()) { + log.trace("traceSendData: No valid span in current context."); + return; + } + setInvocationAttributes(span, invocationContext, eventId); + span.setAttribute(GEN_AI_OPERATION_NAME, SEND_DATA_OPERATION); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + } + + /** + * Traces merged tool call events. + * + *

Calling this function is not needed for telemetry purposes. This is provided for preventing + * /debug/trace requests (typically sent by web UI). + * + * @param responseEventId The ID of the response event. + * @param functionResponseEvent The merged response event. + */ + public static void traceMergedToolCalls( + Span span, String responseEventId, Event functionResponseEvent) { + if (!span.getSpanContext().isValid()) { + log.trace("traceMergedToolCalls: No valid span in current context."); + return; + } + span.setAttribute(GEN_AI_OPERATION_NAME, EXECUTE_TOOL_OPERATION); + span.setAttribute(GEN_AI_TOOL_NAME, "(merged tools)"); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, "(merged tools)"); + span.setAttribute(GEN_AI_TOOL_CALL_ID, responseEventId); + span.setAttribute(ADK_TOOL_CALL_ARGS, "N/A"); + span.setAttribute(ADK_EVENT_ID, responseEventId); + setJsonAttribute(span, ADK_TOOL_RESPONSE, functionResponseEvent); + span.setAttribute(ADK_LLM_REQUEST, "{}"); + span.setAttribute(ADK_LLM_RESPONSE, "{}"); } /** diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index a9e7a6f8d..e40a83aa0 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -494,12 +494,10 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { List spans = openTelemetryRule.getSpans(); SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent"); List llmSpans = findSpansByName(spans, "call_llm"); - List toolCallSpans = findSpansByName(spans, "tool_call [echo_tool]"); - List toolResponseSpans = findSpansByName(spans, "tool_response [echo_tool]"); + List toolSpans = findSpansByName(spans, "execute_tool [echo_tool]"); assertThat(llmSpans).hasSize(2); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); String agentSpanId = agentSpan.getSpanContext().getSpanId(); llmSpans.forEach(s -> assertEquals(agentSpanId, s.getParentSpanContext().getSpanId())); @@ -507,9 +505,7 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { // The tool calls and responses are children of the first LLM call that produced the function // call. String firstLlmSpanId = llmSpans.get(0).getSpanContext().getSpanId(); - toolCallSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); - toolResponseSpans.forEach( - s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); + toolSpans.forEach(s -> assertEquals(firstLlmSpanId, s.getParentSpanContext().getSpanId())); } @Test diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index efd565c16..b68b6ff5f 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -1061,21 +1061,16 @@ public void runAsync_createsToolSpansWithCorrectParent() { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); - List toolCallSpans = - spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); - List toolResponseSpans = - spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + List toolSpans = + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); assertThat(llmSpans).hasSize(2); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); - String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); - String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + String toolParentId = toolSpans.get(0).getParentSpanContext().getSpanId(); - assertThat(toolCallParentId).isEqualTo(toolResponseParentId); - assertThat(llmSpanIds).contains(toolCallParentId); + assertThat(llmSpanIds).contains(toolParentId); } @Test @@ -1101,22 +1096,17 @@ public void runLive_createsToolSpansWithCorrectParent() throws Exception { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); - List toolCallSpans = - spans.stream().filter(s -> s.getName().equals("tool_call [echo_tool]")).toList(); - List toolResponseSpans = - spans.stream().filter(s -> s.getName().equals("tool_response [echo_tool]")).toList(); + List toolSpans = + spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); // In runLive, there is one call_llm span for the execution assertThat(llmSpans).hasSize(1); - assertThat(toolCallSpans).hasSize(1); - assertThat(toolResponseSpans).hasSize(1); + assertThat(toolSpans).hasSize(1); List llmSpanIds = llmSpans.stream().map(s -> s.getSpanContext().getSpanId()).toList(); - String toolCallParentId = toolCallSpans.get(0).getParentSpanContext().getSpanId(); - String toolResponseParentId = toolResponseSpans.get(0).getParentSpanContext().getSpanId(); + String toolParentId = toolSpans.get(0).getParentSpanContext().getSpanId(); - assertThat(toolCallParentId).isEqualTo(toolResponseParentId); - assertThat(llmSpanIds).contains(toolCallParentId); + assertThat(llmSpanIds).contains(toolParentId); } @Test diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index b13904934..44877e972 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import com.google.adk.agents.BaseAgent; @@ -100,242 +99,6 @@ public void tearDown() { Tracing.setTracerForTesting(originalTracer); } - @Test - public void testToolCallSpanLinksToParent() { - // Given: Parent span is active - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope scope = parentSpan.makeCurrent()) { - // When: ADK creates tool_call span with setParent(Context.current()) - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope toolScope = toolCallSpan.makeCurrent()) { - // Simulate tool execution - } finally { - toolCallSpan.end(); - } - } finally { - parentSpan.end(); - } - - // Then: tool_call should be child of parent - SpanData parentSpanData = findSpanByName("parent"); - SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); - - // Verify parent-child relationship - assertEquals( - "Tool call should have same trace ID as parent", - parentSpanData.getSpanContext().getTraceId(), - toolCallSpanData.getSpanContext().getTraceId()); - - assertParent(parentSpanData, toolCallSpanData); - } - - @Test - public void testToolCallWithoutParentCreatesRootSpan() { - // Given: No parent span active - // When: ADK creates tool_call span with setParent(Context.current()) - try (Scope s = Context.root().makeCurrent()) { - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope scope = toolCallSpan.makeCurrent()) { - // Work - } finally { - toolCallSpan.end(); - } - } - - // Then: Should create root span (backward compatible) - List spans = openTelemetryRule.getSpans(); - assertThat(spans).hasSize(1); - - SpanData toolCallSpanData = spans.get(0); - assertFalse( - "Tool call should be root span when no parent exists", - toolCallSpanData.getParentSpanContext().isValid()); - } - - @Test - public void testNestedSpanHierarchy() { - // Test: parent → invocation → tool_call → tool_response hierarchy - - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope parentScope = parentSpan.makeCurrent()) { - - Span invocationSpan = - tracer.spanBuilder("invocation").setParent(Context.current()).startSpan(); - - try (Scope invocationScope = invocationSpan.makeCurrent()) { - - Span toolCallSpan = - tracer.spanBuilder("tool_call [testTool]").setParent(Context.current()).startSpan(); - - try (Scope toolScope = toolCallSpan.makeCurrent()) { - - Span toolResponseSpan = - tracer - .spanBuilder("tool_response [testTool]") - .setParent(Context.current()) - .startSpan(); - - toolResponseSpan.end(); - } finally { - toolCallSpan.end(); - } - } finally { - invocationSpan.end(); - } - } finally { - parentSpan.end(); - } - - // Verify complete hierarchy - List spans = openTelemetryRule.getSpans(); - // The 4 spans are: "parent", "invocation", "tool_call [testTool]", and "tool_response - // [testTool]". - assertThat(spans).hasSize(4); - - SpanData parentSpanData = findSpanByName("parent"); - String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - - // All spans should have same trace ID - for (SpanData span : openTelemetryRule.getSpans()) { - assertEquals( - "All spans should be in same trace", parentTraceId, span.getSpanContext().getTraceId()); - } - - // Verify parent-child relationships - SpanData invocationSpanData = findSpanByName("invocation"); - SpanData toolCallSpanData = findSpanByName("tool_call [testTool]"); - SpanData toolResponseSpanData = findSpanByName("tool_response [testTool]"); - - // invocation should be child of parent - assertParent(parentSpanData, invocationSpanData); - - // tool_call should be child of invocation - assertParent(invocationSpanData, toolCallSpanData); - - // tool_response should be child of tool_call - assertParent(toolCallSpanData, toolResponseSpanData); - } - - @Test - public void testMultipleSpansInParallel() { - // Test: Multiple tool calls in parallel should all link to same parent - - Span parentSpan = tracer.spanBuilder("parent").startSpan(); - - try (Scope parentScope = parentSpan.makeCurrent()) { - // Simulate parallel tool calls - Span toolCall1 = - tracer.spanBuilder("tool_call [tool1]").setParent(Context.current()).startSpan(); - Span toolCall2 = - tracer.spanBuilder("tool_call [tool2]").setParent(Context.current()).startSpan(); - Span toolCall3 = - tracer.spanBuilder("tool_call [tool3]").setParent(Context.current()).startSpan(); - - toolCall1.end(); - toolCall2.end(); - toolCall3.end(); - } finally { - parentSpan.end(); - } - - // Verify all tool calls link to same parent - SpanData parentSpanData = findSpanByName("parent"); - String parentTraceId = parentSpanData.getSpanContext().getTraceId(); - - // All tool calls should have same trace ID and parent span ID - List toolCallSpans = - openTelemetryRule.getSpans().stream() - .filter(s -> s.getName().startsWith("tool_call")) - .toList(); - - assertThat(toolCallSpans).hasSize(3); - - toolCallSpans.forEach( - span -> { - assertEquals( - "Tool call should have same trace ID as parent", - parentTraceId, - span.getSpanContext().getTraceId()); - assertParent(parentSpanData, span); - }); - } - - @Test - public void testInvokeAgentSpanLinksToInvocation() { - // Test: invoke_agent span should link to invocation span - - Span invocationSpan = tracer.spanBuilder("invocation").startSpan(); - - try (Scope invocationScope = invocationSpan.makeCurrent()) { - Span invokeAgentSpan = - tracer.spanBuilder("invoke_agent test-agent").setParent(Context.current()).startSpan(); - - try (Scope agentScope = invokeAgentSpan.makeCurrent()) { - // Simulate agent work - } finally { - invokeAgentSpan.end(); - } - } finally { - invocationSpan.end(); - } - - SpanData invocationSpanData = findSpanByName("invocation"); - SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - - assertParent(invocationSpanData, invokeAgentSpanData); - } - - @Test - public void testCallLlmSpanLinksToAgentRun() { - // Test: call_llm span should link to agent_run span - - Span invokeAgentSpan = tracer.spanBuilder("invoke_agent test-agent").startSpan(); - - try (Scope agentScope = invokeAgentSpan.makeCurrent()) { - Span callLlmSpan = tracer.spanBuilder("call_llm").setParent(Context.current()).startSpan(); - - try (Scope llmScope = callLlmSpan.makeCurrent()) { - // Simulate LLM call - } finally { - callLlmSpan.end(); - } - } finally { - invokeAgentSpan.end(); - } - - List spans = openTelemetryRule.getSpans(); - assertThat(spans).hasSize(2); - - SpanData invokeAgentSpanData = findSpanByName("invoke_agent test-agent"); - SpanData callLlmSpanData = findSpanByName("call_llm"); - - assertParent(invokeAgentSpanData, callLlmSpanData); - } - - @Test - public void testSpanCreatedWithinParentScopeIsCorrectlyParented() { - // Test: Simulates creating a span within the scope of a parent - - Span parentSpan = tracer.spanBuilder("invocation").startSpan(); - try (Scope scope = parentSpan.makeCurrent()) { - Span agentSpan = tracer.spanBuilder("invoke_agent").setParent(Context.current()).startSpan(); - agentSpan.end(); - } finally { - parentSpan.end(); - } - - SpanData parentSpanData = findSpanByName("invocation"); - SpanData agentSpanData = findSpanByName("invoke_agent"); - - assertParent(parentSpanData, agentSpanData); - } - @Test public void testTraceFlowable() throws InterruptedException { Span parentSpan = tracer.spanBuilder("parent").startSpan(); @@ -475,8 +238,14 @@ public void testTraceAgentInvocation() { public void testTraceToolCall() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { - Tracing.traceToolCall( - "tool-name", "tool-description", "tool-type", ImmutableMap.of("arg1", "value1")); + Tracing.traceToolExecution( + span, + "tool-name", + "tool-description", + "tool-type", + ImmutableMap.of("arg1", "value1"), + null, + null); } finally { span.end(); } @@ -513,7 +282,14 @@ public void testTraceToolResponse() { .build()) .build())) .build(); - Tracing.traceToolResponse("event-1", functionResponseEvent); + Tracing.traceToolExecution( + span, + "tool-name", + "tool-description", + "tool-type", + ImmutableMap.of(), + functionResponseEvent, + null); } finally { span.end(); } @@ -524,6 +300,10 @@ public void testTraceToolResponse() { assertEquals("execute_tool", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); assertEquals("tool-call-id", attrs.get(AttributeKey.stringKey("gen_ai.tool_call.id"))); + assertEquals("tool-name", attrs.get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals("tool-description", attrs.get(AttributeKey.stringKey("gen_ai.tool.description"))); + assertEquals("tool-type", attrs.get(AttributeKey.stringKey("gen_ai.tool.type"))); + assertEquals("{}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_call_args"))); assertEquals( "{\"result\":\"tool-result\"}", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.tool_response"))); @@ -550,7 +330,8 @@ public void testTraceCallLlm() { .totalTokenCount(30) .build()) .build(); - Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse); + Tracing.traceCallLlm( + span, buildInvocationContext(), "event-1", llmRequest, llmResponse, null); } finally { span.end(); } @@ -559,6 +340,7 @@ public void testTraceCallLlm() { SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); assertEquals("gcp.vertex.agent", attrs.get(AttributeKey.stringKey("gen_ai.system"))); + assertEquals("call_llm", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals("gemini-pro", attrs.get(AttributeKey.stringKey("gen_ai.request.model"))); assertEquals( "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); @@ -581,6 +363,7 @@ public void testTraceSendData() { Span span = tracer.spanBuilder("test").startSpan(); try (Scope scope = span.makeCurrent()) { Tracing.traceSendData( + span, buildInvocationContext(), "event-1", ImmutableList.of(Content.fromParts(Part.fromText("hello")))); @@ -591,6 +374,7 @@ public void testTraceSendData() { assertThat(spans).hasSize(1); SpanData spanData = spans.get(0); Attributes attrs = spanData.getAttributes(); + assertEquals("send_data", attrs.get(AttributeKey.stringKey("gen_ai.operation.name"))); assertEquals( "test-invocation-id", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.invocation_id"))); assertEquals("event-1", attrs.get(AttributeKey.stringKey("gcp.vertex.agent.event_id"))); @@ -687,8 +471,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent test_agent // ├── call_llm - // │ ├── tool_call [search_flights] - // │ └── tool_response [search_flights] + // │ └── execute_tool [search_flights] // └── call_llm SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); @@ -716,8 +499,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); - SpanData toolCall = findSpanByName("tool_call [search_flights]"); - SpanData toolResponse = findSpanByName("tool_response [search_flights]"); + SpanData toolResponse = findSpanByName("execute_tool [search_flights]"); List callLlmSpans = openTelemetryRule.getSpans().stream() .filter(s -> s.getName().equals("call_llm")) @@ -733,12 +515,28 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { assertParent(invocation, invokeAgent); // ├── call_llm 1 assertParent(invokeAgent, callLlm1); - // │ ├── tool_call [search_flights] - assertParent(callLlm1, toolCall); - // │ └── tool_response [search_flights] + // │ └── execute_tool [search_flights] assertParent(callLlm1, toolResponse); // └── call_llm 2 assertParent(invokeAgent, callLlm2); + + // Assert attributes + assertEquals( + "invoke_agent", + invokeAgent.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "call_llm", callLlm1.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "execute_tool", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "search_flights", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))); + assertEquals( + "execute_tool", + toolResponse.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); + assertEquals( + "call_llm", callLlm2.getAttributes().get(AttributeKey.stringKey("gen_ai.operation.name"))); } @Test @@ -748,8 +546,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent AgentA // ├── call_llm - // │ ├── tool_call [transfer_to_agent] - // │ └── tool_response [transfer_to_agent] + // │ └── execute_tool [transfer_to_agent] // └── invoke_agent AgentB // └── call_llm TestLlm llm = @@ -776,9 +573,8 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData agentASpan = findSpanByName("invoke_agent AgentA"); - SpanData toolCall = findSpanByName("tool_call [transfer_to_agent]"); + SpanData executeTool = findSpanByName("execute_tool [transfer_to_agent]"); SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); - SpanData toolResponse = findSpanByName("tool_response [transfer_to_agent]"); List callLlmSpans = openTelemetryRule.getSpans().stream() @@ -792,8 +588,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { assertParent(invocation, agentASpan); assertParent(agentASpan, agentACallLlm1); - assertParent(agentACallLlm1, toolCall); - assertParent(agentACallLlm1, toolResponse); + assertParent(agentACallLlm1, executeTool); assertParent(agentASpan, agentBSpan); assertParent(agentBSpan, agentBCallLlm); }