Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ private DefaultAroundAdvisorChain copyAdvisorsAfter(List<? extends Advisor> advi
var remainingStreamAdvisors = advisors.subList(afterAdvisorIndex + 1, advisors.size());

return DefaultAroundAdvisorChain.builder(this.getObservationRegistry())
.observationConvention(this.observationConvention)
.pushAll(remainingStreamAdvisors)
.build();
}
Expand All @@ -194,6 +195,10 @@ public ObservationRegistry getObservationRegistry() {
return this.observationRegistry;
}

public AdvisorObservationConvention getObservationConvention() {
return this.observationConvention;
}

public static final class Builder {

private final ObservationRegistry observationRegistry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import org.springframework.ai.chat.client.advisor.api.CallAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisor;
import org.springframework.ai.chat.client.advisor.api.StreamAdvisorChain;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.advisor.observation.DefaultAdvisorObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -245,6 +247,71 @@ void whenCopyingChainThenObservationRegistryIsPreserved() {
assertThat(newChain.getObservationRegistry()).isSameAs(customRegistry);
}

@Test
void whenCopyingChainWithCustomObservationConventionThenConventionIsPreserved() {
CallAdvisor advisor1 = createMockAdvisor("advisor1", 1);
CallAdvisor advisor2 = createMockAdvisor("advisor2", 2);
CallAdvisor advisor3 = createMockAdvisor("advisor3", 3);

ObservationRegistry observationRegistry = ObservationRegistry.create();
AdvisorObservationConvention customConvention = new DefaultAdvisorObservationConvention("custom-convention") {
};

CallAdvisorChain chain = DefaultAroundAdvisorChain.builder(observationRegistry)
.observationConvention(customConvention)
.pushAll(List.of(advisor1, advisor2, advisor3))
.build();

CallAdvisorChain newChain = chain.copy(advisor1);

assertThat(newChain).isInstanceOf(DefaultAroundAdvisorChain.class);
assertThat(((DefaultAroundAdvisorChain) newChain).getObservationConvention().getName())
.isEqualTo("custom-convention");
}

@Test
void whenCopyingChainWithDefaultObservationConventionThenDefaultIsPreserved() {
CallAdvisor advisor1 = createMockAdvisor("advisor1", 1);
CallAdvisor advisor2 = createMockAdvisor("advisor2", 2);

ObservationRegistry observationRegistry = ObservationRegistry.create();
CallAdvisorChain chain = DefaultAroundAdvisorChain.builder(observationRegistry)
.pushAll(List.of(advisor1, advisor2))
.build();

CallAdvisorChain newChain = chain.copy(advisor1);

assertThat(newChain).isInstanceOf(DefaultAroundAdvisorChain.class);
assertThat(((DefaultAroundAdvisorChain) newChain).getObservationConvention().getName())
.isEqualTo("spring.ai.advisor");
}

@Test
void whenCopyingStreamChainThenObservationConventionIsPreserved() {
StreamAdvisor streamAdvisor1 = mock(StreamAdvisor.class);
when(streamAdvisor1.getName()).thenReturn("streamAdvisor1");
when(streamAdvisor1.adviseStream(any(), any())).thenReturn(Flux.just(ChatClientResponse.builder().build()));
StreamAdvisor streamAdvisor2 = mock(StreamAdvisor.class);
when(streamAdvisor2.getName()).thenReturn("streamAdvisor2");
when(streamAdvisor2.adviseStream(any(), any())).thenReturn(Flux.just(ChatClientResponse.builder().build()));

ObservationRegistry observationRegistry = ObservationRegistry.create();
AdvisorObservationConvention customConvention = new DefaultAdvisorObservationConvention(
"stream-custom-convention") {
};

StreamAdvisorChain chain = DefaultAroundAdvisorChain.builder(observationRegistry)
.observationConvention(customConvention)
.pushAll(List.of(streamAdvisor1, streamAdvisor2))
.build();

StreamAdvisorChain newChain = chain.copy(streamAdvisor1);

assertThat(newChain).isInstanceOf(DefaultAroundAdvisorChain.class);
assertThat(((DefaultAroundAdvisorChain) newChain).getObservationConvention().getName())
.isEqualTo("stream-custom-convention");
}

private CallAdvisor createMockAdvisor(String name, int order) {
return new CallAdvisor() {
@Override
Expand Down