diff --git a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java index c86fa0b1b6..c16e24b0dc 100644 --- a/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java +++ b/spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChain.java @@ -175,6 +175,7 @@ private DefaultAroundAdvisorChain copyAdvisorsAfter(List advi var remainingStreamAdvisors = advisors.subList(afterAdvisorIndex + 1, advisors.size()); return DefaultAroundAdvisorChain.builder(this.getObservationRegistry()) + .observationConvention(this.observationConvention) .pushAll(remainingStreamAdvisors) .build(); } @@ -194,6 +195,10 @@ public ObservationRegistry getObservationRegistry() { return this.observationRegistry; } + public AdvisorObservationConvention getObservationConvention() { + return this.observationConvention; + } + public static final class Builder { private final ObservationRegistry observationRegistry; diff --git a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java index 7202cff7dc..3bfb62ed53 100644 --- a/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java +++ b/spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/DefaultAroundAdvisorChainTests.java @@ -31,6 +31,8 @@ import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain; import org.springframework.ai.chat.client.advisor.api.StreamAdvisor; import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain; +import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention; +import org.springframework.ai.chat.client.advisor.observation.DefaultAdvisorObservationConvention; import org.springframework.ai.chat.prompt.Prompt; import static org.assertj.core.api.Assertions.assertThat; @@ -245,6 +247,71 @@ void whenCopyingChainThenObservationRegistryIsPreserved() { assertThat(newChain.getObservationRegistry()).isSameAs(customRegistry); } + @Test + void whenCopyingChainWithCustomObservationConventionThenConventionIsPreserved() { + CallAdvisor advisor1 = createMockAdvisor("advisor1", 1); + CallAdvisor advisor2 = createMockAdvisor("advisor2", 2); + CallAdvisor advisor3 = createMockAdvisor("advisor3", 3); + + ObservationRegistry observationRegistry = ObservationRegistry.create(); + AdvisorObservationConvention customConvention = new DefaultAdvisorObservationConvention("custom-convention") { + }; + + CallAdvisorChain chain = DefaultAroundAdvisorChain.builder(observationRegistry) + .observationConvention(customConvention) + .pushAll(List.of(advisor1, advisor2, advisor3)) + .build(); + + CallAdvisorChain newChain = chain.copy(advisor1); + + assertThat(newChain).isInstanceOf(DefaultAroundAdvisorChain.class); + assertThat(((DefaultAroundAdvisorChain) newChain).getObservationConvention().getName()) + .isEqualTo("custom-convention"); + } + + @Test + void whenCopyingChainWithDefaultObservationConventionThenDefaultIsPreserved() { + CallAdvisor advisor1 = createMockAdvisor("advisor1", 1); + CallAdvisor advisor2 = createMockAdvisor("advisor2", 2); + + ObservationRegistry observationRegistry = ObservationRegistry.create(); + CallAdvisorChain chain = DefaultAroundAdvisorChain.builder(observationRegistry) + .pushAll(List.of(advisor1, advisor2)) + .build(); + + CallAdvisorChain newChain = chain.copy(advisor1); + + assertThat(newChain).isInstanceOf(DefaultAroundAdvisorChain.class); + assertThat(((DefaultAroundAdvisorChain) newChain).getObservationConvention().getName()) + .isEqualTo("spring.ai.advisor"); + } + + @Test + void whenCopyingStreamChainThenObservationConventionIsPreserved() { + StreamAdvisor streamAdvisor1 = mock(StreamAdvisor.class); + when(streamAdvisor1.getName()).thenReturn("streamAdvisor1"); + when(streamAdvisor1.adviseStream(any(), any())).thenReturn(Flux.just(ChatClientResponse.builder().build())); + StreamAdvisor streamAdvisor2 = mock(StreamAdvisor.class); + when(streamAdvisor2.getName()).thenReturn("streamAdvisor2"); + when(streamAdvisor2.adviseStream(any(), any())).thenReturn(Flux.just(ChatClientResponse.builder().build())); + + ObservationRegistry observationRegistry = ObservationRegistry.create(); + AdvisorObservationConvention customConvention = new DefaultAdvisorObservationConvention( + "stream-custom-convention") { + }; + + StreamAdvisorChain chain = DefaultAroundAdvisorChain.builder(observationRegistry) + .observationConvention(customConvention) + .pushAll(List.of(streamAdvisor1, streamAdvisor2)) + .build(); + + StreamAdvisorChain newChain = chain.copy(streamAdvisor1); + + assertThat(newChain).isInstanceOf(DefaultAroundAdvisorChain.class); + assertThat(((DefaultAroundAdvisorChain) newChain).getObservationConvention().getName()) + .isEqualTo("stream-custom-convention"); + } + private CallAdvisor createMockAdvisor(String name, int order) { return new CallAdvisor() { @Override