From 3c8f4886f0e4c76abdbeb64a348bfccd5c16120e Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 9 Mar 2026 18:45:28 -0700 Subject: [PATCH] feat: Fixing the spans produced by agent calls to have the right parent spans PiperOrigin-RevId: 881142814 --- .../java/com/google/adk/agents/BaseAgent.java | 107 ++-- .../adk/flows/llmflows/BaseLlmFlow.java | 390 ++++++-------- .../google/adk/flows/llmflows/Functions.java | 227 ++++---- .../com/google/adk/plugins/PluginManager.java | 16 +- .../java/com/google/adk/runner/Runner.java | 167 +++--- .../com/google/adk/telemetry/Tracing.java | 488 +++++------------- 6 files changed, 508 insertions(+), 887 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index c527eeab3..d74ba9ca5 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -29,10 +29,10 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import com.google.genai.types.Content; -import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -312,47 +312,38 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { - Context otelParentContext = Context.current(); - InvocationContext invocationContext = createInvocationContext(parentContext); - return Flowable.defer( - () -> { - return callCallback( - beforeCallbacksToFunctions( - invocationContext.pluginManager(), beforeAgentCallback), - invocationContext) - .flatMapPublisher( - beforeEvent -> { - if (invocationContext.endInvocation()) { - return Flowable.just(beforeEvent); - } - - return Flowable.just(beforeEvent) - .concatWith(runMainAndAfter(invocationContext, runImplementation)); - }) - .switchIfEmpty( - Flowable.defer(() -> runMainAndAfter(invocationContext, runImplementation))); - }) - .compose( - Tracing.traceAgent( - otelParentContext, - "invoke_agent " + name(), - name(), - description(), - invocationContext)); - } - - private Flowable runMainAndAfter( - InvocationContext invocationContext, - Function> runImplementation) { - Flowable mainEvents = runImplementation.apply(invocationContext); - Flowable afterEvents = - callCallback( - afterCallbacksToFunctions(invocationContext.pluginManager(), afterAgentCallback), - invocationContext) - .flatMapPublisher(Flowable::just); - - return Flowable.concat(mainEvents, afterEvents); + () -> { + InvocationContext invocationContext = createInvocationContext(parentContext); + + return callCallback( + beforeCallbacksToFunctions( + invocationContext.pluginManager(), beforeAgentCallback), + invocationContext) + .flatMapPublisher( + beforeEventOpt -> { + if (invocationContext.endInvocation()) { + return Flowable.fromOptional(beforeEventOpt); + } + + Flowable beforeEvents = Flowable.fromOptional(beforeEventOpt); + Flowable mainEvents = + Flowable.defer(() -> runImplementation.apply(invocationContext)); + Flowable afterEvents = + Flowable.defer( + () -> + callCallback( + afterCallbacksToFunctions( + invocationContext.pluginManager(), afterAgentCallback), + invocationContext) + .flatMapPublisher(Flowable::fromOptional)); + + return Flowable.concat(beforeEvents, mainEvents, afterEvents); + }) + .compose( + Tracing.traceAgent( + "invoke_agent " + name(), name(), description(), invocationContext)); + }); } /** @@ -392,13 +383,13 @@ private ImmutableList>> callbacksTo * * @param agentCallbacks Callback functions. * @param invocationContext Current invocation context. - * @return Maybe emitting first event, or empty if none. + * @return single emitting first event, or empty if none. */ - private Maybe callCallback( + private Single> callCallback( List>> agentCallbacks, InvocationContext invocationContext) { if (agentCallbacks.isEmpty()) { - return Maybe.empty(); + return Single.just(Optional.empty()); } CallbackContext callbackContext = @@ -407,25 +398,27 @@ private Maybe callCallback( return Flowable.fromIterable(agentCallbacks) .concatMap( callback -> { - return callback - .apply(callbackContext) + Maybe maybeContent = callback.apply(callbackContext); + + return maybeContent .map( content -> { invocationContext.setEndInvocation(true); - return Event.builder() - .id(Event.generateEventId()) - .invocationId(invocationContext.invocationId()) - .author(name()) - .branch(invocationContext.branch().orElse(null)) - .actions(callbackContext.eventActions()) - .content(content) - .build(); + return Optional.of( + Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationContext.invocationId()) + .author(name()) + .branch(invocationContext.branch().orElse(null)) + .actions(callbackContext.eventActions()) + .content(content) + .build()); }) .toFlowable(); }) .firstElement() .switchIfEmpty( - Maybe.defer( + Single.defer( () -> { if (callbackContext.state().hasDelta()) { Event.Builder eventBuilder = @@ -436,9 +429,9 @@ private Maybe callCallback( .branch(invocationContext.branch().orElse(null)) .actions(callbackContext.eventActions()); - return Maybe.just(eventBuilder.build()); + return Single.just(Optional.of(eventBuilder.build())); } else { - return Maybe.empty(); + return Single.just(Optional.empty()); } })); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index fba7f10e0..6ed9ccaa3 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -92,9 +92,7 @@ public BaseLlmFlow( * events generated by them. */ protected Flowable preprocess( - InvocationContext context, - AtomicReference llmRequestRef, - Context otelParentContext) { + InvocationContext context, AtomicReference llmRequestRef) { LlmAgent agent = (LlmAgent) context.agent(); RequestProcessor toolsProcessor = @@ -106,8 +104,7 @@ protected Flowable preprocess( tool -> tool.processLlmRequest(builder, ToolContext.builder(ctx).build())) .andThen( Single.fromCallable( - () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))) - .compose(Tracing.withContext(otelParentContext)); + () -> RequestProcessingResult.create(builder.build(), ImmutableList.of()))); }; Iterable allProcessors = @@ -116,9 +113,7 @@ protected Flowable preprocess( return Flowable.fromIterable(allProcessors) .concatMap( processor -> - processor - .processRequest(context, llmRequestRef.get()) - .compose(Tracing.withContext(otelParentContext)) + Single.defer(() -> processor.processRequest(context, llmRequestRef.get())) .doOnSuccess(result -> llmRequestRef.set(result.updatedRequest())) .flattenAsFlowable( result -> result.events() != null ? result.events() : ImmutableList.of())); @@ -134,32 +129,13 @@ protected Flowable postprocess( Event baseEventForLlmResponse, LlmRequest llmRequest, LlmResponse llmResponse) { - return postprocess( - context, baseEventForLlmResponse, llmRequest, llmResponse, Context.current()); - } - - /** - * Post-processes the LLM response after receiving it from the LLM. Executes all registered {@link - * ResponseProcessor} instances. Emits events for the model response and any subsequent function - * calls. - */ - private Flowable postprocess( - InvocationContext context, - Event baseEventForLlmResponse, - LlmRequest llmRequest, - LlmResponse llmResponse, - Context otelParentContext) { List> eventIterables = new ArrayList<>(); Single currentLlmResponse = Single.just(llmResponse); for (ResponseProcessor processor : responseProcessors) { currentLlmResponse = currentLlmResponse - .flatMap( - response -> - processor - .processResponse(context, response) - .compose(Tracing.withContext(otelParentContext))) + .flatMap(response -> processor.processResponse(context, response)) .doOnSuccess( result -> { if (result.events() != null) { @@ -168,16 +144,15 @@ private Flowable postprocess( }) .map(ResponseProcessingResult::updatedResponse); } + Context parentContext = Context.current(); return currentLlmResponse.flatMapPublisher( - updatedResponse -> - buildPostprocessingEvents( - updatedResponse, - eventIterables, - context, - baseEventForLlmResponse, - llmRequest, - otelParentContext)); + updatedResponse -> { + try (Scope scope = parentContext.makeCurrent()) { + return buildPostprocessingEvents( + updatedResponse, eventIterables, context, baseEventForLlmResponse, llmRequest); + } + }); } /** @@ -189,100 +164,84 @@ private Flowable postprocess( * callbacks. Callbacks should not rely on its ID if they create their own separate events. */ private Flowable callLlm( - InvocationContext context, - LlmRequest llmRequest, - Event eventForCallbackUsage, - Context otelParentContext) { + InvocationContext context, LlmRequest llmRequest, Event eventForCallbackUsage) { LlmAgent agent = (LlmAgent) context.agent(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); - return handleBeforeModelCallback( - context, llmRequestBuilder, eventForCallbackUsage, otelParentContext) - .flatMapPublisher(Flowable::just) - .switchIfEmpty( - Flowable.defer( - () -> { - BaseLlm llm = - agent.resolvedModel().model().isPresent() - ? agent.resolvedModel().model().get() - : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); - return llm.generateContent( - llmRequestBuilder.build(), - context.runConfig().streamingMode() == StreamingMode.SSE) - .onErrorResumeNext( - exception -> - handleOnModelErrorCallback( - context, - llmRequestBuilder, - eventForCallbackUsage, - exception, - otelParentContext) - .switchIfEmpty(Single.error(exception)) - .toFlowable()) - .compose( - Tracing.trace("call_llm", otelParentContext) - .onSuccess( - (span, llmResp) -> - Tracing.traceCallLlm( - span, - context, - eventForCallbackUsage.id(), - llmRequestBuilder.build(), - llmResp))) - .doOnError( - error -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, error.getMessage()); - span.recordException(error); - }) - .concatMap( - llmResp -> - handleAfterModelCallback( - context, llmResp, eventForCallbackUsage, otelParentContext) - .toFlowable()); - })); + return handleBeforeModelCallback(context, llmRequestBuilder, eventForCallbackUsage) + .flatMapPublisher( + beforeResponse -> { + if (beforeResponse.isPresent()) { + return Flowable.just(beforeResponse.get()); + } + BaseLlm llm = + agent.resolvedModel().model().isPresent() + ? agent.resolvedModel().model().get() + : LlmRegistry.getLlm(agent.resolvedModel().modelName().get()); + return llm.generateContent( + llmRequestBuilder.build(), + context.runConfig().streamingMode() == StreamingMode.SSE) + .onErrorResumeNext( + exception -> + handleOnModelErrorCallback( + context, llmRequestBuilder, eventForCallbackUsage, exception) + .switchIfEmpty(Single.error(exception)) + .toFlowable()) + .doOnNext( + llmResp -> + Tracing.traceCallLlm( + context, + eventForCallbackUsage.id(), + llmRequestBuilder.build(), + llmResp)) + .doOnError( + error -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, error.getMessage()); + span.recordException(error); + }) + .compose(Tracing.trace("call_llm")) + .concatMap( + llmResp -> + handleAfterModelCallback(context, llmResp, eventForCallbackUsage) + .toFlowable()); + }); } /** * Invokes {@link BeforeModelCallback}s. If any returns a response, it's used instead of calling * the LLM. * - * @return A {@link Maybe} with the callback result. + * @return A {@link Single} with the callback result or {@link Optional#empty()}. */ - private Maybe handleBeforeModelCallback( - InvocationContext context, - LlmRequest.Builder llmRequestBuilder, - Event modelResponseEvent, - Context otelParentContext) { - try (Scope scope = otelParentContext.makeCurrent()) { - Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = - new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); + private Single> handleBeforeModelCallback( + InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent) { + Event callbackEvent = modelResponseEvent.toBuilder().build(); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Maybe pluginResult = - context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); + Maybe pluginResult = + context.pluginManager().beforeModelCallback(callbackContext, llmRequestBuilder); - LlmAgent agent = (LlmAgent) context.agent(); - - List callbacks = agent.canonicalBeforeModelCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + LlmAgent agent = (LlmAgent) context.agent(); - Maybe callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(callbackContext, llmRequestBuilder) - .compose(Tracing.withContext(otelParentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + List callbacks = agent.canonicalBeforeModelCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult.map(Optional::of).defaultIfEmpty(Optional.empty()); } + + Maybe callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe(callback -> callback.call(callbackContext, llmRequestBuilder)) + .firstElement()); + + return pluginResult + .switchIfEmpty(callbackResult) + .map(Optional::of) + .defaultIfEmpty(Optional.empty()); } /** @@ -295,41 +254,32 @@ private Maybe handleOnModelErrorCallback( InvocationContext context, LlmRequest.Builder llmRequestBuilder, Event modelResponseEvent, - Throwable throwable, - Context otelParentContext) { - - try (Scope scope = otelParentContext.makeCurrent()) { - Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = - new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); - Maybe pluginResult = - context - .pluginManager() - .onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); - - LlmAgent agent = (LlmAgent) context.agent(); - List callbacks = agent.canonicalOnModelErrorCallbacks(); - - if (callbacks.isEmpty()) { - return pluginResult; - } + Throwable throwable) { + Event callbackEvent = modelResponseEvent.toBuilder().build(); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); + Exception ex = throwable instanceof Exception e ? e : new Exception(throwable); + + Maybe pluginResult = + context.pluginManager().onModelErrorCallback(callbackContext, llmRequestBuilder, throwable); - Maybe callbackResult = - Maybe.defer( - () -> { - LlmRequest llmRequest = llmRequestBuilder.build(); - return Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(callbackContext, llmRequest, ex) - .compose(Tracing.withContext(otelParentContext))) - .firstElement(); - }); - - return pluginResult.switchIfEmpty(callbackResult); + LlmAgent agent = (LlmAgent) context.agent(); + List callbacks = agent.canonicalOnModelErrorCallbacks(); + + if (callbacks.isEmpty()) { + return pluginResult; } + + Maybe callbackResult = + Maybe.defer( + () -> { + LlmRequest llmRequest = llmRequestBuilder.build(); + return Flowable.fromIterable(callbacks) + .concatMapMaybe(callback -> callback.call(callbackContext, llmRequest, ex)) + .firstElement(); + }); + + return pluginResult.switchIfEmpty(callbackResult); } /** @@ -339,39 +289,29 @@ private Maybe handleOnModelErrorCallback( * @return A {@link Single} with the final {@link LlmResponse}. */ private Single handleAfterModelCallback( - InvocationContext context, - LlmResponse llmResponse, - Event modelResponseEvent, - Context otelParentContext) { + InvocationContext context, LlmResponse llmResponse, Event modelResponseEvent) { + Event callbackEvent = modelResponseEvent.toBuilder().build(); + CallbackContext callbackContext = + new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); - try (Scope scope = otelParentContext.makeCurrent()) { - Event callbackEvent = modelResponseEvent.toBuilder().build(); - CallbackContext callbackContext = - new CallbackContext(context, callbackEvent.actions(), callbackEvent.id()); + Maybe pluginResult = + context.pluginManager().afterModelCallback(callbackContext, llmResponse); - Maybe pluginResult = - context.pluginManager().afterModelCallback(callbackContext, llmResponse); + LlmAgent agent = (LlmAgent) context.agent(); + List callbacks = agent.canonicalAfterModelCallbacks(); - LlmAgent agent = (LlmAgent) context.agent(); - List callbacks = agent.canonicalAfterModelCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult.defaultIfEmpty(llmResponse); + } - if (callbacks.isEmpty()) { - return pluginResult.defaultIfEmpty(llmResponse); - } + Maybe callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe(callback -> callback.call(callbackContext, llmResponse)) + .firstElement()); - Maybe callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(callbackContext, llmResponse) - .compose(Tracing.withContext(otelParentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); - } + return pluginResult.switchIfEmpty(callbackResult).defaultIfEmpty(llmResponse); } /** @@ -383,12 +323,13 @@ private Single handleAfterModelCallback( * @throws LlmCallsLimitExceededException if the agent exceeds allowed LLM invocations. * @throws IllegalStateException if a transfer agent is specified but not found. */ - private Flowable runOneStep(InvocationContext context, Context otelParentContext) { + private Flowable runOneStep(InvocationContext context) { AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); return Flowable.defer( () -> { - return preprocess(context, llmRequestRef, otelParentContext) + Context currentContext = Context.current(); + return preprocess(context, llmRequestRef) .concatWith( Flowable.defer( () -> { @@ -414,19 +355,15 @@ private Flowable runOneStep(InvocationContext context, Context otelParent .build(); mutableEventTemplate.setTimestamp(0L); - return callLlm( - context, - llmRequestAfterPreprocess, - mutableEventTemplate, - otelParentContext) + return callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate) .concatMap( - llmResponse -> - postprocess( + llmResponse -> { + try (Scope postScope = currentContext.makeCurrent()) { + return postprocess( context, mutableEventTemplate, llmRequestAfterPreprocess, - llmResponse, - otelParentContext) + llmResponse) .doFinally( () -> { String oldId = mutableEventTemplate.id(); @@ -434,7 +371,9 @@ private Flowable runOneStep(InvocationContext context, Context otelParent logger.debug( "Resetting event ID from {} to {}", oldId, newId); mutableEventTemplate.setId(newId); - })) + }); + } + }) .concatMap( event -> { Flowable postProcessedEvents = Flowable.just(event); @@ -468,12 +407,11 @@ private Flowable runOneStep(InvocationContext context, Context otelParent */ @Override public Flowable run(InvocationContext invocationContext) { - return run(invocationContext, Context.current(), 0); + return run(invocationContext, 0); } - private Flowable run( - InvocationContext invocationContext, Context otelParentContext, int stepsCompleted) { - Flowable currentStepEvents = runOneStep(invocationContext, otelParentContext).cache(); + private Flowable run(InvocationContext invocationContext, int stepsCompleted) { + Flowable currentStepEvents = runOneStep(invocationContext).cache(); if (stepsCompleted + 1 >= maxSteps) { logger.debug("Ending flow execution because max steps reached."); return currentStepEvents; @@ -493,7 +431,7 @@ private Flowable run( return Flowable.empty(); } else { logger.debug("Continuing to next step of the flow."); - return run(invocationContext, otelParentContext, stepsCompleted + 1); + return run(invocationContext, stepsCompleted + 1); } })); } @@ -508,10 +446,8 @@ private Flowable run( */ @Override public Flowable runLive(InvocationContext invocationContext) { - Context otelParentContext = Context.current(); AtomicReference llmRequestRef = new AtomicReference<>(LlmRequest.builder().build()); - Flowable preprocessEvents = - preprocess(invocationContext, llmRequestRef, otelParentContext); + Flowable preprocessEvents = preprocess(invocationContext, llmRequestRef); return preprocessEvents.concatWith( Flowable.defer( @@ -533,7 +469,6 @@ public Flowable runLive(InvocationContext invocationContext) { ? Completable.complete() : connection .sendHistory(llmRequestAfterPreprocess.contents()) - .compose(Tracing.trace("send_data", otelParentContext)) .doOnComplete( () -> Tracing.traceSendData( @@ -549,7 +484,8 @@ public Flowable runLive(InvocationContext invocationContext) { invocationContext, eventIdForSendData, llmRequestAfterPreprocess.contents()); - }); + }) + .compose(Tracing.trace("send_data")); Flowable liveRequests = invocationContext @@ -606,16 +542,13 @@ public void onError(Throwable e) { .receive() .flatMap( llmResponse -> { - try (Scope scope = otelParentContext.makeCurrent()) { - Event baseEventForLlmResponse = - liveEventBuilderTemplate.id(Event.generateEventId()).build(); - return postprocess( - invocationContext, - baseEventForLlmResponse, - llmRequestAfterPreprocess, - llmResponse, - otelParentContext); - } + Event baseEventForThisLlmResponse = + liveEventBuilderTemplate.id(Event.generateEventId()).build(); + return postprocess( + invocationContext, + baseEventForThisLlmResponse, + llmRequestAfterPreprocess, + llmResponse); }) .flatMap( event -> { @@ -667,8 +600,7 @@ private Flowable buildPostprocessingEvents( List> eventIterables, InvocationContext context, Event baseEventForLlmResponse, - LlmRequest llmRequest, - Context otelParentContext) { + LlmRequest llmRequest) { Flowable processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables)); if (updatedResponse.content().isEmpty() && updatedResponse.errorCode().isEmpty() @@ -684,27 +616,23 @@ private Flowable buildPostprocessingEvents( return processorEvents.concatWith(Flowable.just(modelResponseEvent)); } - try (Scope scope = otelParentContext.makeCurrent()) { - Maybe maybeFunctionResponseEvent = - context.runConfig().streamingMode() == StreamingMode.BIDI - ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) - : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); - - Flowable functionEvents = - maybeFunctionResponseEvent.flatMapPublisher( - functionResponseEvent -> { - Optional toolConfirmationEvent = - Functions.generateRequestConfirmationEvent( - context, modelResponseEvent, functionResponseEvent); - return toolConfirmationEvent.isPresent() - ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) - : Flowable.just(functionResponseEvent); - }); - - return processorEvents - .concatWith(Flowable.just(modelResponseEvent)) - .concatWith(functionEvents); - } + Maybe maybeFunctionResponseEvent = + context.runConfig().streamingMode() == StreamingMode.BIDI + ? Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools()) + : Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools()); + + Flowable functionEvents = + maybeFunctionResponseEvent.flatMapPublisher( + functionResponseEvent -> { + Optional toolConfirmationEvent = + Functions.generateRequestConfirmationEvent( + context, modelResponseEvent, functionResponseEvent); + return toolConfirmationEvent.isPresent() + ? Flowable.just(toolConfirmationEvent.get(), functionResponseEvent) + : Flowable.just(functionResponseEvent); + }); + + return processorEvents.concatWith(Flowable.just(modelResponseEvent)).concatWith(functionEvents); } private Event buildModelResponseEvent( diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index f8b9e180d..ecc2bb412 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -257,8 +257,7 @@ private static Function> getFunctionCallMapper( functionCall.args().map(HashMap::new).orElse(new HashMap<>()); Maybe> maybeFunctionResult = - maybeInvokeBeforeToolCall( - invocationContext, tool, functionArgs, toolContext, parentContext) + maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) .switchIfEmpty( Maybe.defer( () -> { @@ -396,49 +395,48 @@ private static Maybe postProcessFunctionResult( .defaultIfEmpty(Optional.empty()) .onErrorResumeNext( t -> { - try (Scope scope = parentContext.makeCurrent()) { - Maybe> errorCallbackResult = - handleOnToolErrorCallback( - invocationContext, tool, functionArgs, toolContext, t, parentContext); - Maybe>> mappedResult; - if (isLive) { - // In live mode, handle null results from the error callback gracefully. - mappedResult = errorCallbackResult.map(Optional::ofNullable); - } else { - // In non-live mode, a null result from the error callback will cause an NPE - // when wrapped with Optional.of(), potentially matching prior behavior. - mappedResult = errorCallbackResult.map(Optional::of); - } - return mappedResult.switchIfEmpty(Single.error(t)); + Maybe> errorCallbackResult = + handleOnToolErrorCallback(invocationContext, tool, functionArgs, toolContext, t); + Maybe>> mappedResult; + if (isLive) { + // In live mode, handle null results from the error callback gracefully. + mappedResult = errorCallbackResult.map(Optional::ofNullable); + } else { + // In non-live mode, a null result from the error callback will cause an NPE + // when wrapped with Optional.of(), potentially matching prior behavior. + mappedResult = errorCallbackResult.map(Optional::of); } + return mappedResult.switchIfEmpty(Single.error(t)); }) .flatMapMaybe( optionalInitialResult -> { - Map initialFunctionResult = optionalInitialResult.orElse(null); - - return maybeInvokeAfterToolCall( - invocationContext, - tool, - functionArgs, - toolContext, - initialFunctionResult, - parentContext) - .map(Optional::of) - .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) - .flatMapMaybe( - finalOptionalResult -> { - Map finalFunctionResult = finalOptionalResult.orElse(null); - if (tool.longRunning() && finalFunctionResult == null) { - return Maybe.empty(); - } - return Maybe.fromCallable( - () -> - buildResponseEvent( - tool, finalFunctionResult, toolContext, invocationContext)) - .compose( - Tracing.trace("tool_response [" + tool.name() + "]", parentContext)) - .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); - }); + try (Scope scope = parentContext.makeCurrent()) { + Map initialFunctionResult = optionalInitialResult.orElse(null); + + return maybeInvokeAfterToolCall( + invocationContext, tool, functionArgs, toolContext, initialFunctionResult) + .map(Optional::of) + .defaultIfEmpty(Optional.ofNullable(initialFunctionResult)) + .flatMapMaybe( + finalOptionalResult -> { + Map finalFunctionResult = + finalOptionalResult.orElse(null); + if (tool.longRunning() && finalFunctionResult == null) { + return Maybe.empty(); + } + return Maybe.fromCallable( + () -> + buildResponseEvent( + tool, + finalFunctionResult, + toolContext, + invocationContext)) + .compose( + Tracing.trace( + "tool_response [" + tool.name() + "]", parentContext)) + .doOnSuccess(event -> Tracing.traceToolResponse(event.id(), event)); + }); + } }); } @@ -481,32 +479,28 @@ private static Maybe> maybeInvokeBeforeToolCall( InvocationContext invocationContext, BaseTool tool, Map functionArgs, - ToolContext toolContext, - Context parentContext) { - if (invocationContext.agent() instanceof LlmAgent agent) { - try (Scope scope = parentContext.makeCurrent()) { - - Maybe> pluginResult = - invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); + ToolContext toolContext) { + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); - List callbacks = agent.canonicalBeforeToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + Maybe> pluginResult = + invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(invocationContext, tool, functionArgs, toolContext) - .compose(Tracing.withContext(parentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + List callbacks = agent.canonicalBeforeToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; } + + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback.call(invocationContext, tool, functionArgs, toolContext)) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); } return Maybe.empty(); } @@ -522,39 +516,34 @@ private static Maybe> handleOnToolErrorCallback( BaseTool tool, Map functionArgs, ToolContext toolContext, - Throwable throwable, - Context parentContext) { + Throwable throwable) { Exception ex = throwable instanceof Exception exception ? exception : new Exception(throwable); - try (Scope scope = parentContext.makeCurrent()) { - Maybe> pluginResult = - invocationContext - .pluginManager() - .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - - if (invocationContext.agent() instanceof LlmAgent) { - LlmAgent agent = (LlmAgent) invocationContext.agent(); + Maybe> pluginResult = + invocationContext + .pluginManager() + .onToolErrorCallback(tool, functionArgs, toolContext, throwable); - List callbacks = agent.canonicalOnToolErrorCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call(invocationContext, tool, functionArgs, toolContext, ex) - .compose(Tracing.withContext(parentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + List callbacks = agent.canonicalOnToolErrorCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; } - return pluginResult; + + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback.call(invocationContext, tool, functionArgs, toolContext, ex)) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); } + return pluginResult; } private static Maybe> maybeInvokeAfterToolCall( @@ -562,39 +551,35 @@ private static Maybe> maybeInvokeAfterToolCall( BaseTool tool, Map functionArgs, ToolContext toolContext, - Map functionResult, - Context parentContext) { - if (invocationContext.agent() instanceof LlmAgent agent) { + Map functionResult) { + if (invocationContext.agent() instanceof LlmAgent) { + LlmAgent agent = (LlmAgent) invocationContext.agent(); - try (Scope scope = parentContext.makeCurrent()) { - Maybe> pluginResult = - invocationContext - .pluginManager() - .afterToolCallback(tool, functionArgs, toolContext, functionResult); - - List callbacks = agent.canonicalAfterToolCallbacks(); - if (callbacks.isEmpty()) { - return pluginResult; - } + Maybe> pluginResult = + invocationContext + .pluginManager() + .afterToolCallback(tool, functionArgs, toolContext, functionResult); - Maybe> callbackResult = - Maybe.defer( - () -> - Flowable.fromIterable(callbacks) - .concatMapMaybe( - callback -> - callback - .call( - invocationContext, - tool, - functionArgs, - toolContext, - functionResult) - .compose(Tracing.withContext(parentContext))) - .firstElement()); - - return pluginResult.switchIfEmpty(callbackResult); + List callbacks = agent.canonicalAfterToolCallbacks(); + if (callbacks.isEmpty()) { + return pluginResult; } + + Maybe> callbackResult = + Maybe.defer( + () -> + Flowable.fromIterable(callbacks) + .concatMapMaybe( + callback -> + callback.call( + invocationContext, + tool, + functionArgs, + toolContext, + functionResult)) + .firstElement()); + + return pluginResult.switchIfEmpty(callbackResult); } return Maybe.empty(); } diff --git a/core/src/main/java/com/google/adk/plugins/PluginManager.java b/core/src/main/java/com/google/adk/plugins/PluginManager.java index 4d90ca7b5..56dea936a 100644 --- a/core/src/main/java/com/google/adk/plugins/PluginManager.java +++ b/core/src/main/java/com/google/adk/plugins/PluginManager.java @@ -21,13 +21,11 @@ import com.google.adk.events.Event; import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; -import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.ToolContext; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; -import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -131,7 +129,6 @@ public Completable runAfterRunCallback(InvocationContext invocationContext) { @Override public Completable afterRunCallback(InvocationContext invocationContext) { - Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletable( plugin -> @@ -142,13 +139,11 @@ public Completable afterRunCallback(InvocationContext invocationContext) { logger.error( "[{}] Error during callback 'afterRunCallback'", plugin.getName(), - e)) - .compose(Tracing.withContext(capturedContext))); + e))); } @Override public Completable close() { - Context capturedContext = Context.current(); return Flowable.fromIterable(plugins) .concatMapCompletableDelayError( plugin -> @@ -156,8 +151,8 @@ public Completable close() { .close() .doOnError( e -> - logger.error("[{}] Error during callback 'close'", plugin.getName(), e)) - .compose(Tracing.withContext(capturedContext))); + logger.error( + "[{}] Error during callback 'close'", plugin.getName(), e))); } public Maybe runOnEventCallback(InvocationContext invocationContext, Event event) { @@ -280,7 +275,7 @@ public Maybe> onToolErrorCallback( */ private Maybe runMaybeCallbacks( Function> callbackExecutor, String callbackName) { - Context capturedContext = Context.current(); + return Flowable.fromIterable(this.plugins) .concatMapMaybe( plugin -> @@ -299,8 +294,7 @@ private Maybe runMaybeCallbacks( "[{}] Error during callback '{}'", plugin.getName(), callbackName, - e)) - .compose(Tracing.withContext(capturedContext))) + e))) .firstElement(); } } diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index e35f5c33d..4371300fb 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -52,7 +52,6 @@ import com.google.genai.types.Part; import io.opentelemetry.api.trace.Span; import io.opentelemetry.api.trace.StatusCode; -import io.opentelemetry.context.Context; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; @@ -376,25 +375,20 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return Flowable.defer( - () -> - this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession( - appName, userId, (Map) null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format( - "Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher( - session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta))) - .compose(Tracing.trace("invocation")); + Maybe maybeSession = + this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); + return maybeSession + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession(appName, userId, null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format("Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher(session -> this.runAsyncImpl(session, newMessage, runConfig, stateDelta)); } /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ @@ -447,8 +441,7 @@ public Flowable runAsync( Content newMessage, RunConfig runConfig, @Nullable Map stateDelta) { - return runAsyncImpl(session, newMessage, runConfig, stateDelta) - .compose(Tracing.trace("invocation")); + return runAsyncImpl(session, newMessage, runConfig, stateDelta); } /** @@ -467,7 +460,6 @@ protected Flowable runAsyncImpl( @Nullable Map stateDelta) { return Flowable.defer( () -> { - Context capturedContext = Context.current(); BaseAgent rootAgent = this.agent; String invocationId = InvocationContext.newInvocationContextId(); @@ -481,7 +473,6 @@ protected Flowable runAsyncImpl( return this.pluginManager .onUserMessageCallback(initialContext, newMessage) - .compose(Tracing.withContext(capturedContext)) .defaultIfEmpty(newMessage) .flatMap( content -> @@ -493,7 +484,6 @@ protected Flowable runAsyncImpl( runConfig.saveInputBlobsAsArtifacts(), stateDelta) : Single.just(null)) - .compose(Tracing.withContext(capturedContext)) .flatMapPublisher( event -> { if (event == null) { @@ -504,17 +494,15 @@ protected Flowable runAsyncImpl( return this.sessionService .getSession( session.appName(), session.userId(), session.id(), Optional.empty()) - .compose(Tracing.withContext(capturedContext)) .flatMapPublisher( updatedSession -> runAgentWithFreshSession( - session, - updatedSession, - event, - invocationId, - runConfig, - rootAgent) - .compose(Tracing.withContext(capturedContext))); + session, + updatedSession, + event, + invocationId, + runConfig, + rootAgent)); }); }) .doOnError( @@ -522,7 +510,8 @@ protected Flowable runAsyncImpl( Span span = Span.current(); span.setStatus(StatusCode.ERROR, "Error in runAsync Flowable execution"); span.recordException(throwable); - }); + }) + .compose(Tracing.trace("invocation")); } private Flowable runAgentWithFreshSession( @@ -579,7 +568,7 @@ private Flowable runAgentWithFreshSession( .toFlowable() .switchIfEmpty(agentEvents) .concatWith( - Completable.defer(() -> pluginManager.afterRunCallback(contextWithUpdatedSession))) + Completable.defer(() -> pluginManager.runAfterRunCallback(contextWithUpdatedSession))) .concatWith(Completable.defer(() -> compactEvents(updatedSession))); } @@ -652,51 +641,39 @@ private InvocationContext.Builder newInvocationContextBuilder(Session session) { */ public Flowable runLive( Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return runLiveImpl(session, liveRequestQueue, runConfig).compose(Tracing.trace("invocation")); - } - - /** - * Runs the agent in live mode, appending generated events to the session. - * - * @return stream of events from the agent. - */ - protected Flowable runLiveImpl( - Session session, @Nullable LiveRequestQueue liveRequestQueue, RunConfig runConfig) { return Flowable.defer( - () -> { - Context capturedContext = Context.current(); - InvocationContext invocationContext = - newInvocationContextForLive(session, liveRequestQueue, runConfig); - - Single invocationContextSingle; - if (invocationContext.agent() instanceof LlmAgent agent) { - invocationContextSingle = - agent - .tools() - .map( - tools -> { - this.addActiveStreamingTools(invocationContext, tools); - return invocationContext; - }); - } else { - invocationContextSingle = Single.just(invocationContext); - } - return invocationContextSingle - .compose(Tracing.withContext(capturedContext)) - .flatMapPublisher( - updatedInvocationContext -> - updatedInvocationContext - .agent() - .runLive(updatedInvocationContext) - .compose(Tracing.withContext(capturedContext)) - .doOnNext(event -> this.sessionService.appendEvent(session, event))) - .doOnError( - throwable -> { - Span span = Span.current(); - span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); - span.recordException(throwable); - }); - }); + () -> { + InvocationContext invocationContext = + newInvocationContextForLive(session, liveRequestQueue, runConfig); + + Single invocationContextSingle; + if (invocationContext.agent() instanceof LlmAgent agent) { + invocationContextSingle = + agent + .tools() + .map( + tools -> { + this.addActiveStreamingTools(invocationContext, tools); + return invocationContext; + }); + } else { + invocationContextSingle = Single.just(invocationContext); + } + return invocationContextSingle + .flatMapPublisher( + updatedInvocationContext -> + updatedInvocationContext + .agent() + .runLive(updatedInvocationContext) + .doOnNext(event -> this.sessionService.appendEvent(session, event))) + .doOnError( + throwable -> { + Span span = Span.current(); + span.setStatus(StatusCode.ERROR, "Error in runLive Flowable execution"); + span.recordException(throwable); + }); + }) + .compose(Tracing.trace("invocation")); } /** @@ -707,25 +684,19 @@ protected Flowable runLiveImpl( */ public Flowable runLive( String userId, String sessionId, LiveRequestQueue liveRequestQueue, RunConfig runConfig) { - return Flowable.defer( - () -> - this.sessionService - .getSession(appName, userId, sessionId, Optional.empty()) - .switchIfEmpty( - Single.defer( - () -> { - if (runConfig.autoCreateSession()) { - return this.sessionService.createSession( - appName, userId, (Map) null, sessionId); - } - return Single.error( - new IllegalArgumentException( - String.format( - "Session not found: %s for user %s", sessionId, userId))); - })) - .flatMapPublisher( - session -> this.runLiveImpl(session, liveRequestQueue, runConfig))) - .compose(Tracing.trace("invocation")); + return this.sessionService + .getSession(appName, userId, sessionId, Optional.empty()) + .switchIfEmpty( + Single.defer( + () -> { + if (runConfig.autoCreateSession()) { + return this.sessionService.createSession(appName, userId, null, sessionId); + } + return Single.error( + new IllegalArgumentException( + String.format("Session not found: %s for user %s", sessionId, userId))); + })) + .flatMapPublisher(session -> this.runLive(session, liveRequestQueue, runConfig)); } /** diff --git a/core/src/main/java/com/google/adk/telemetry/Tracing.java b/core/src/main/java/com/google/adk/telemetry/Tracing.java index 9fa68ee00..07a640c37 100644 --- a/core/src/main/java/com/google/adk/telemetry/Tracing.java +++ b/core/src/main/java/com/google/adk/telemetry/Tracing.java @@ -37,20 +37,16 @@ import io.opentelemetry.context.Context; import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Completable; -import io.reactivex.rxjava3.core.CompletableObserver; import io.reactivex.rxjava3.core.CompletableSource; import io.reactivex.rxjava3.core.CompletableTransformer; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.FlowableTransformer; import io.reactivex.rxjava3.core.Maybe; -import io.reactivex.rxjava3.core.MaybeObserver; import io.reactivex.rxjava3.core.MaybeSource; import io.reactivex.rxjava3.core.MaybeTransformer; import io.reactivex.rxjava3.core.Single; -import io.reactivex.rxjava3.core.SingleObserver; import io.reactivex.rxjava3.core.SingleSource; import io.reactivex.rxjava3.core.SingleTransformer; -import io.reactivex.rxjava3.disposables.Disposable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -58,12 +54,9 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Supplier; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -143,10 +136,6 @@ private static Optional getValidCurrentSpan(String methodName) { return Optional.of(span); } - private static void traceWithSpan(String methodName, Consumer action) { - getValidCurrentSpan(methodName).ifPresent(action); - } - private static void setInvocationAttributes( Span span, InvocationContext invocationContext, String eventId) { span.setAttribute(ADK_INVOCATION_ID, invocationContext.invocationId()); @@ -217,16 +206,16 @@ public static void traceAgentInvocation( */ public static void traceToolCall( String toolName, String toolDescription, String toolType, Map args) { - traceWithSpan( - "traceToolCall", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(GEN_AI_TOOL_NAME, toolName); - span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); - span.setAttribute(GEN_AI_TOOL_TYPE, toolType); - - setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); - }); + getValidCurrentSpan("traceToolCall") + .ifPresent( + span -> { + setToolExecutionAttributes(span); + span.setAttribute(GEN_AI_TOOL_NAME, toolName); + span.setAttribute(GEN_AI_TOOL_DESCRIPTION, toolDescription); + span.setAttribute(GEN_AI_TOOL_TYPE, toolType); + + setJsonAttribute(span, ADK_TOOL_CALL_ARGS, args); + }); } /** @@ -236,32 +225,33 @@ public static void traceToolCall( * @param functionResponseEvent The function response event. */ public static void traceToolResponse(String eventId, Event functionResponseEvent) { - traceWithSpan( - "traceToolResponse", - span -> { - setToolExecutionAttributes(span); - span.setAttribute(ADK_EVENT_ID, eventId); - - Optional functionResponse = - functionResponseEvent.functionResponses().stream().findFirst(); - - String toolCallId = - functionResponse.flatMap(FunctionResponse::id).orElse(""); - Object toolResponse = - functionResponse - .flatMap(FunctionResponse::response) - .map(Object.class::cast) - .orElse(""); - - span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); - - Object finalToolResponse = - (toolResponse instanceof Map) - ? toolResponse - : ImmutableMap.of("result", toolResponse); - - setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); - }); + getValidCurrentSpan("traceToolResponse") + .ifPresent( + span -> { + setToolExecutionAttributes(span); + span.setAttribute(ADK_EVENT_ID, eventId); + + FunctionResponse functionResponse = + functionResponseEvent.functionResponses().stream().findFirst().orElse(null); + + String toolCallId = ""; + Object toolResponse = ""; + if (functionResponse != null) { + toolCallId = functionResponse.id().orElse(toolCallId); + if (functionResponse.response().isPresent()) { + toolResponse = functionResponse.response().get(); + } + } + + span.setAttribute(GEN_AI_TOOL_CALL_ID, toolCallId); + + Object finalToolResponse = + (toolResponse instanceof Map) + ? toolResponse + : ImmutableMap.of("result", toolResponse); + + setJsonAttribute(span, ADK_TOOL_RESPONSE, finalToolResponse); + }); } /** @@ -306,63 +296,58 @@ public static void traceCallLlm( String eventId, LlmRequest llmRequest, LlmResponse llmResponse) { - traceWithSpan( - "traceCallLlm", - span -> traceCallLlm(span, invocationContext, eventId, llmRequest, llmResponse)); - } - - /** - * Traces a call to the LLM. - * - * @param span The span to end when the stream completes - * @param invocationContext The invocation context. - * @param eventId The ID of the event associated with this LLM call/response. - * @param llmRequest The LLM request object. - * @param llmResponse The LLM response object. - */ - public static void traceCallLlm( - Span span, - InvocationContext invocationContext, - String eventId, - LlmRequest llmRequest, - LlmResponse llmResponse) { - span.setAttribute(GEN_AI_OPERATION_NAME, "call_llm"); - span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); - llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - - setInvocationAttributes(span, invocationContext, eventId); - - setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); - setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); - - llmRequest - .config() - .flatMap(config -> config.topP()) - .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); - llmRequest - .config() - .flatMap(config -> config.maxOutputTokens()) + getValidCurrentSpan("traceCallLlm") .ifPresent( - maxTokens -> span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + span -> { + span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent"); + llmRequest + .model() + .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName)); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage - .promptTokenCount() - .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); - usage - .candidatesTokenCount() + setInvocationAttributes(span, invocationContext, eventId); + + setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest)); + setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse); + + llmRequest + .config() + .ifPresent( + config -> { + config + .topP() + .ifPresent( + topP -> + span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue())); + config + .maxOutputTokens() + .ifPresent( + maxTokens -> + span.setAttribute( + GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue())); + }); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage + .promptTokenCount() + .ifPresent( + tokens -> + span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens)); + usage + .candidatesTokenCount() + .ifPresent( + tokens -> + span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + }); + llmResponse + .finishReason() + .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) .ifPresent( - tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens)); + reason -> + span.setAttribute( + GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); }); - - llmResponse - .finishReason() - .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT)) - .ifPresent( - reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason))); } /** @@ -374,18 +359,17 @@ public static void traceCallLlm( */ public static void traceSendData( InvocationContext invocationContext, String eventId, List data) { - traceWithSpan( - "traceSendData", - span -> { - span.setAttribute(GEN_AI_OPERATION_NAME, "send_data"); - setInvocationAttributes(span, invocationContext, eventId); - - ImmutableList safeData = - Optional.ofNullable(data).orElse(ImmutableList.of()).stream() - .filter(Objects::nonNull) - .collect(toImmutableList()); - setJsonAttribute(span, ADK_DATA, safeData); - }); + getValidCurrentSpan("traceSendData") + .ifPresent( + span -> { + setInvocationAttributes(span, invocationContext, eventId); + + ImmutableList safeData = + Optional.ofNullable(data).orElse(ImmutableList.of()).stream() + .filter(Objects::nonNull) + .collect(toImmutableList()); + setJsonAttribute(span, ADK_DATA, safeData); + }); } /** @@ -421,17 +405,14 @@ public static Tracer getTracer() { @SuppressWarnings("MustBeClosedChecker") // Scope lifecycle managed by RxJava doFinally public static Flowable traceFlowable( Context spanContext, Span span, Supplier> flowableSupplier) { - return Flowable.defer( - () -> { - Scope scope = spanContext.makeCurrent(); - return flowableSupplier - .get() - .doFinally( - () -> { - scope.close(); - span.end(); - }); - }); + Scope scope = spanContext.makeCurrent(); + return flowableSupplier + .get() + .doFinally( + () -> { + scope.close(); + span.end(); + }); } /** @@ -469,66 +450,15 @@ public static TracerProvider trace(String spanName, Context parentContext * @return A TracerProvider configured for agent invocation. */ public static TracerProvider traceAgent( - Context parent, String spanName, String agentName, String agentDescription, InvocationContext invocationContext) { return new TracerProvider(spanName) - .setParent(parent) .configure( span -> traceAgentInvocation(span, agentName, agentDescription, invocationContext)); } - /** - * Returns a transformer that re-activates a given context for the duration of the stream's - * subscription. - * - * @param context The context to re-activate. - * @param The type of the stream. - * @return A transformer that re-activates the context. - */ - public static ContextTransformer withContext(Context context) { - return new ContextTransformer<>(context); - } - - /** - * A transformer that re-activates a given context for the duration of the stream's subscription. - * - * @param The type of the stream. - */ - public static final class ContextTransformer - implements FlowableTransformer, - SingleTransformer, - MaybeTransformer, - CompletableTransformer { - private final Context context; - - private ContextTransformer(Context context) { - this.context = context; - } - - @Override - public Publisher apply(Flowable upstream) { - return upstream.lift(subscriber -> TracingObserver.wrap(context, subscriber)); - } - - @Override - public SingleSource apply(Single upstream) { - return upstream.lift(observer -> TracingObserver.wrap(context, observer)); - } - - @Override - public MaybeSource apply(Maybe upstream) { - return upstream.lift(observer -> TracingObserver.wrap(context, observer)); - } - - @Override - public CompletableSource apply(Completable upstream) { - return upstream.lift(observer -> TracingObserver.wrap(context, observer)); - } - } - /** * A transformer that manages an OpenTelemetry span and scope for RxJava streams. * @@ -542,7 +472,6 @@ public static final class TracerProvider private final String spanName; private Context explicitParentContext; private final List> spanConfigurers = new ArrayList<>(); - private BiConsumer onSuccessConsumer; private TracerProvider(String spanName) { this.spanName = spanName; @@ -562,38 +491,27 @@ public TracerProvider setParent(Context parentContext) { return this; } - /** - * Registers a callback to be executed with the span and the result item when the stream emits a - * success value. - */ - @CanIgnoreReturnValue - public TracerProvider onSuccess(BiConsumer consumer) { - this.onSuccessConsumer = consumer; - return this; - } - private Context getParentContext() { return explicitParentContext != null ? explicitParentContext : Context.current(); } private final class TracingLifecycle { - private final Span span; - private final Context context; + private Span span; + private Scope scope; - TracingLifecycle() { - Context parentContext = getParentContext(); - span = tracer.spanBuilder(spanName).setParent(parentContext).startSpan(); + @SuppressWarnings("MustBeClosedChecker") + void start() { + span = tracer.spanBuilder(spanName).setParent(getParentContext()).startSpan(); spanConfigurers.forEach(c -> c.accept(span)); - context = parentContext.with(span); + scope = span.makeCurrent(); } void end() { - span.end(); - } - - void run(O observer, Consumer subscribeAction) { - try (Scope scope = context.makeCurrent()) { - subscribeAction.accept(observer); + if (scope != null) { + scope.close(); + } + if (span != null) { + span.end(); } } } @@ -603,18 +521,7 @@ public Publisher apply(Flowable upstream) { return Flowable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return Flowable.fromPublisher( - observer -> - lifecycle.run( - observer, - o -> { - Flowable chain = upstream.compose(withContext(lifecycle.context)); - if (onSuccessConsumer != null) { - chain = - chain.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t)); - } - chain.doFinally(lifecycle::end).subscribe(o); - })); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); }); } @@ -623,18 +530,7 @@ public SingleSource apply(Single upstream) { return Single.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return Single.wrap( - observer -> - lifecycle.run( - observer, - o -> { - Single chain = upstream.compose(withContext(lifecycle.context)); - if (onSuccessConsumer != null) { - chain = - chain.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); - } - chain.doFinally(lifecycle::end).subscribe(o); - })); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); }); } @@ -643,18 +539,7 @@ public MaybeSource apply(Maybe upstream) { return Maybe.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return Maybe.wrap( - observer -> - lifecycle.run( - observer, - o -> { - Maybe chain = upstream.compose(withContext(lifecycle.context)); - if (onSuccessConsumer != null) { - chain = - chain.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t)); - } - chain.doFinally(lifecycle::end).subscribe(o); - })); + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); }); } @@ -663,142 +548,7 @@ public CompletableSource apply(Completable upstream) { return Completable.defer( () -> { TracingLifecycle lifecycle = new TracingLifecycle(); - return Completable.wrap( - observer -> - lifecycle.run( - observer, - o -> { - Completable chain = upstream.compose(withContext(lifecycle.context)); - // Completable does not emit items, so onSuccessConsumer is not - // applicable. - chain.doFinally(lifecycle::end).subscribe(o); - })); - }); - } - } - - /** - * An observer that wraps another observer and ensures that the OpenTelemetry context is active - * during all callback methods. - * - * @param The type of the items emitted by the stream. - */ - private static final class TracingObserver - implements Subscriber, SingleObserver, MaybeObserver, CompletableObserver { - private final Context context; - private final Subscriber subscriber; - private final SingleObserver singleObserver; - private final MaybeObserver maybeObserver; - private final CompletableObserver completableObserver; - - private TracingObserver( - Context context, - Subscriber subscriber, - SingleObserver singleObserver, - MaybeObserver maybeObserver, - CompletableObserver completableObserver) { - this.context = context; - this.subscriber = subscriber; - this.singleObserver = singleObserver; - this.maybeObserver = maybeObserver; - this.completableObserver = completableObserver; - } - - static TracingObserver wrap(Context context, Subscriber subscriber) { - return new TracingObserver<>(context, subscriber, null, null, null); - } - - static TracingObserver wrap(Context context, SingleObserver observer) { - return new TracingObserver<>(context, null, observer, null, null); - } - - static TracingObserver wrap(Context context, MaybeObserver observer) { - return new TracingObserver<>(context, null, null, observer, null); - } - - static TracingObserver wrap(Context context, CompletableObserver observer) { - return new TracingObserver<>(context, null, null, null, observer); - } - - private void runInContext(Runnable action) { - try (Scope scope = context.makeCurrent()) { - action.run(); - } - } - - @Override - public void onSubscribe(Subscription s) { - runInContext( - () -> { - if (subscriber != null) { - subscriber.onSubscribe(s); - } - }); - } - - @Override - public void onSubscribe(Disposable d) { - runInContext( - () -> { - if (singleObserver != null) { - singleObserver.onSubscribe(d); - } else if (maybeObserver != null) { - maybeObserver.onSubscribe(d); - } else if (completableObserver != null) { - completableObserver.onSubscribe(d); - } - }); - } - - @Override - public void onNext(T t) { - runInContext( - () -> { - if (subscriber != null) { - subscriber.onNext(t); - } - }); - } - - @Override - public void onSuccess(T t) { - runInContext( - () -> { - if (singleObserver != null) { - singleObserver.onSuccess(t); - } else if (maybeObserver != null) { - maybeObserver.onSuccess(t); - } - }); - } - - @Override - public void onError(Throwable t) { - runInContext( - () -> { - if (subscriber != null) { - subscriber.onError(t); - } else if (singleObserver != null) { - singleObserver.onError(t); - } else if (maybeObserver != null) { - maybeObserver.onError(t); - } else if (completableObserver != null) { - completableObserver.onError(t); - } - }); - } - - @Override - public void onComplete() { - runInContext( - () -> { - if (subscriber != null) { - subscriber.onComplete(); - } else if (maybeObserver != null) { - maybeObserver.onComplete(); - } else if (completableObserver != null) { - completableObserver.onComplete(); - } + return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end); }); } }