Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.ChatClientCustomizer;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientCompletionObservationHandler;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
Expand Down Expand Up @@ -90,11 +91,12 @@ ChatClientBuilderConfigurer chatClientBuilderConfigurer(ObjectProvider<ChatClien
@ConditionalOnMissingBean
ChatClient.Builder chatClientBuilder(ChatClientBuilderConfigurer chatClientBuilderConfigurer, ChatModel chatModel,
ObjectProvider<ObservationRegistry> observationRegistry,
ObjectProvider<ChatClientObservationConvention> observationConvention) {

ObjectProvider<ChatClientObservationConvention> chatClientObservationConvention,
ObjectProvider<AdvisorObservationConvention> advisorObservationConvention) {
ChatClient.Builder builder = ChatClient.builder(chatModel,
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP),
observationConvention.getIfUnique(() -> null));
chatClientObservationConvention.getIfUnique(() -> null),
advisorObservationConvention.getIfUnique(() -> null));
return chatClientBuilderConfigurer.configure(builder);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatModel;
Expand Down Expand Up @@ -65,22 +66,46 @@ static ChatClient create(ChatModel chatModel, ObservationRegistry observationReg
return create(chatModel, observationRegistry, null);
}

/**
* @deprecated in favor of
* {@link #create(ChatModel, ObservationRegistry, ChatClientObservationConvention, AdvisorObservationConvention)}.
*/
@Deprecated(since = "1.1.0", forRemoval = true)
static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention chatClientObservationConvention) {
return create(chatModel, observationRegistry, chatClientObservationConvention, null);
}

static ChatClient create(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention observationConvention) {
@Nullable ChatClientObservationConvention chatClientObservationConvention,
@Nullable AdvisorObservationConvention advisorObservationConvention) {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
return builder(chatModel, observationRegistry, observationConvention).build();
return builder(chatModel, observationRegistry, chatClientObservationConvention, advisorObservationConvention)
.build();
}

static Builder builder(ChatModel chatModel) {
return builder(chatModel, ObservationRegistry.NOOP, null);
}

/**
* @deprecated in favor of
* {@link #builder(ChatModel, ObservationRegistry, ChatClientObservationConvention, AdvisorObservationConvention)}.
*/
@Deprecated(since = "1.1.0", forRemoval = true)
static Builder builder(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention chatClientObservationConvention) {
return builder(chatModel, observationRegistry, chatClientObservationConvention, null);
}

static Builder builder(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention customObservationConvention) {
@Nullable ChatClientObservationConvention chatClientObservationConvention,
@Nullable AdvisorObservationConvention advisorObservationConvention) {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
return new DefaultChatClientBuilder(chatModel, observationRegistry, customObservationConvention);
return new DefaultChatClientBuilder(chatModel, observationRegistry, chatClientObservationConvention,
advisorObservationConvention);
}

ChatClientRequestSpec prompt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisorChain;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationContext;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationDocumentation;
Expand Down Expand Up @@ -615,7 +616,10 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe

private final ObservationRegistry observationRegistry;

private final ChatClientObservationConvention observationConvention;
private final ChatClientObservationConvention chatClientObservationConvention;

@Nullable
private final AdvisorObservationConvention advisorObservationConvention;

private final ChatModel chatModel;

Expand Down Expand Up @@ -659,18 +663,36 @@ public static class DefaultChatClientRequestSpec implements ChatClientRequestSpe
this(ccr.chatModel, ccr.userText, ccr.userParams, ccr.userMetadata, ccr.systemText, ccr.systemParams,
ccr.systemMetadata, ccr.toolCallbacks, ccr.toolCallbackProviders, ccr.messages, ccr.toolNames,
ccr.media, ccr.chatOptions, ccr.advisors, ccr.advisorParams, ccr.observationRegistry,
ccr.observationConvention, ccr.toolContext, ccr.templateRenderer);
ccr.chatClientObservationConvention, ccr.toolContext, ccr.templateRenderer,
ccr.advisorObservationConvention);
}

/**
* @deprecated in favor of the other constructor.
*/
@Deprecated(since = "1.1.0", forRemoval = true)
public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
Map<String, Object> userParams, Map<String, Object> userMetadata, @Nullable String systemText,
Map<String, Object> systemParams, Map<String, Object> systemMetadata, List<ToolCallback> toolCallbacks,
List<ToolCallbackProvider> toolCallbackProviders, List<Message> messages, List<String> toolNames,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly related to this PR: during rebasing Thomas' changes, I found an extra param here: List<ToolCallbackProvider> toolCallbackProviders that wasn't there before, I guess that's a breaking change between 1.0 and 1.1. If that's ok, maybe the deprecated ctor can be removed too in this PR? (I prefer not breaking things though, that's why I kept it.)

List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors,
Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention observationConvention, Map<String, Object> toolContext,
@Nullable TemplateRenderer templateRenderer) {
@Nullable ChatClientObservationConvention chatClientObservationConvention,
Map<String, Object> toolContext, @Nullable TemplateRenderer templateRenderer) {
this(chatModel, userText, userParams, userMetadata, systemText, systemParams, systemMetadata, toolCallbacks,
toolCallbackProviders, messages, toolNames, media, chatOptions, advisors, advisorParams,
observationRegistry, chatClientObservationConvention, toolContext, templateRenderer, null);
}

public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userText,
Map<String, Object> userParams, Map<String, Object> userMetadata, @Nullable String systemText,
Map<String, Object> systemParams, Map<String, Object> systemMetadata, List<ToolCallback> toolCallbacks,
List<ToolCallbackProvider> toolCallbackProviders, List<Message> messages, List<String> toolNames,
List<Media> media, @Nullable ChatOptions chatOptions, List<Advisor> advisors,
Map<String, Object> advisorParams, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention chatClientObservationConvention,
Map<String, Object> toolContext, @Nullable TemplateRenderer templateRenderer,
@Nullable AdvisorObservationConvention advisorObservationConvention) {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.notNull(userParams, "userParams cannot be null");
Assert.notNull(userMetadata, "userMetadata cannot be null");
Expand Down Expand Up @@ -706,10 +728,11 @@ public DefaultChatClientRequestSpec(ChatModel chatModel, @Nullable String userTe
this.advisors.addAll(advisors);
this.advisorParams.putAll(advisorParams);
this.observationRegistry = observationRegistry;
this.observationConvention = observationConvention != null ? observationConvention
: DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
this.chatClientObservationConvention = chatClientObservationConvention != null
? chatClientObservationConvention : DEFAULT_CHAT_CLIENT_OBSERVATION_CONVENTION;
this.toolContext.putAll(toolContext);
this.templateRenderer = templateRenderer != null ? templateRenderer : DEFAULT_TEMPLATE_RENDERER;
this.advisorObservationConvention = advisorObservationConvention;
}

@Nullable
Expand Down Expand Up @@ -786,7 +809,8 @@ public TemplateRenderer getTemplateRenderer() {
@Override
public Builder mutate() {
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient
.builder(this.chatModel, this.observationRegistry, this.observationConvention)
.builder(this.chatModel, this.observationRegistry, this.chatClientObservationConvention,
this.advisorObservationConvention)
.defaultTemplateRenderer(this.templateRenderer)
.defaultToolCallbacks(this.toolCallbacks)
.defaultToolCallbacks(this.toolCallbackProviders.toArray(new ToolCallback[0]))
Expand Down Expand Up @@ -1005,14 +1029,14 @@ public ChatClientRequestSpec templateRenderer(TemplateRenderer templateRenderer)
public CallResponseSpec call() {
BaseAdvisorChain advisorChain = buildAdvisorChain();
return new DefaultCallResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
this.observationRegistry, this.observationConvention);
this.observationRegistry, this.chatClientObservationConvention);
}

@Override
public StreamResponseSpec stream() {
BaseAdvisorChain advisorChain = buildAdvisorChain();
return new DefaultStreamResponseSpec(DefaultChatClientUtils.toChatClientRequest(this), advisorChain,
this.observationRegistry, this.observationConvention);
this.observationRegistry, this.chatClientObservationConvention);
}

private BaseAdvisorChain buildAdvisorChain() {
Expand All @@ -1021,7 +1045,10 @@ private BaseAdvisorChain buildAdvisorChain() {
this.advisors.add(ChatModelCallAdvisor.builder().chatModel(this.chatModel).build());
this.advisors.add(ChatModelStreamAdvisor.builder().chatModel(this.chatModel).build());

return DefaultAroundAdvisorChain.builder(this.observationRegistry).pushAll(this.advisors).build();
return DefaultAroundAdvisorChain.builder(this.observationRegistry)
.observationConvention(this.advisorObservationConvention)
.pushAll(this.advisors)
.build();
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.springframework.ai.chat.client.ChatClient.PromptUserSpec;
import org.springframework.ai.chat.client.DefaultChatClient.DefaultChatClientRequestSpec;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationConvention;
import org.springframework.ai.chat.client.observation.ChatClientObservationConvention;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.model.ChatModel;
Expand Down Expand Up @@ -60,13 +61,24 @@ public class DefaultChatClientBuilder implements Builder {
this(chatModel, ObservationRegistry.NOOP, null);
}

/**
* @deprecated in favor of
* {@link #DefaultChatClientBuilder(ChatModel, ObservationRegistry, ChatClientObservationConvention, AdvisorObservationConvention)}.
*/
@Deprecated(since = "1.1.0", forRemoval = true)
public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention customObservationConvention) {
@Nullable ChatClientObservationConvention chatClientObservationConvention) {
this(chatModel, observationRegistry, chatClientObservationConvention, null);
}

public DefaultChatClientBuilder(ChatModel chatModel, ObservationRegistry observationRegistry,
@Nullable ChatClientObservationConvention chatClientObservationConvention,
@Nullable AdvisorObservationConvention advisorObservationConvention) {
Assert.notNull(chatModel, "the " + ChatModel.class.getName() + " must be non-null");
Assert.notNull(observationRegistry, "the " + ObservationRegistry.class.getName() + " must be non-null");
this.defaultRequest = new DefaultChatClientRequestSpec(chatModel, null, Map.of(), Map.of(), null, Map.of(),
Map.of(), List.of(), List.of(), List.of(), List.of(), List.of(), null, List.of(), Map.of(),
observationRegistry, customObservationConvention, Map.of(), null);
observationRegistry, chatClientObservationConvention, Map.of(), null, advisorObservationConvention);
}

public ChatClient build() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.client.ChatClientMessageAggregator;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.Advisor;
Expand All @@ -37,6 +38,7 @@
import org.springframework.ai.chat.client.advisor.observation.AdvisorObservationDocumentation;
import org.springframework.ai.chat.client.advisor.observation.DefaultAdvisorObservationConvention;
import org.springframework.core.OrderComparator;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand All @@ -54,6 +56,8 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {

public static final AdvisorObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultAdvisorObservationConvention();

private static final ChatClientMessageAggregator CHAT_CLIENT_MESSAGE_AGGREGATOR = new ChatClientMessageAggregator();

private final List<CallAdvisor> originalCallAdvisors;

private final List<StreamAdvisor> originalStreamAdvisors;
Expand All @@ -64,8 +68,10 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {

private final ObservationRegistry observationRegistry;

private final AdvisorObservationConvention observationConvention;

DefaultAroundAdvisorChain(ObservationRegistry observationRegistry, Deque<CallAdvisor> callAdvisors,
Deque<StreamAdvisor> streamAdvisors) {
Deque<StreamAdvisor> streamAdvisors, @Nullable AdvisorObservationConvention observationConvention) {

Assert.notNull(observationRegistry, "the observationRegistry must be non-null");
Assert.notNull(callAdvisors, "the callAdvisors must be non-null");
Expand All @@ -76,6 +82,8 @@ public class DefaultAroundAdvisorChain implements BaseAdvisorChain {
this.streamAdvisors = streamAdvisors;
this.originalCallAdvisors = List.copyOf(callAdvisors);
this.originalStreamAdvisors = List.copyOf(streamAdvisors);
this.observationConvention = observationConvention != null ? observationConvention
: DEFAULT_OBSERVATION_CONVENTION;
}

public static Builder builder(ObservationRegistry observationRegistry) {
Expand All @@ -99,8 +107,13 @@ public ChatClientResponse nextCall(ChatClientRequest chatClientRequest) {
.build();

return AdvisorObservationDocumentation.AI_ADVISOR
.observation(null, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry)
.observe(() -> advisor.adviseCall(chatClientRequest, this));
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
var chatClientResponse = advisor.adviseCall(chatClientRequest, this);
observationContext.setChatClientResponse(chatClientResponse);
return chatClientResponse;
});
}

@Override
Expand All @@ -120,17 +133,19 @@ public Flux<ChatClientResponse> nextStream(ChatClientRequest chatClientRequest)
.order(advisor.getOrder())
.build();

var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(null,
var observation = AdvisorObservationDocumentation.AI_ADVISOR.observation(this.observationConvention,
DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry);

observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

// @formatter:off
return Flux.defer(() -> advisor.adviseStream(chatClientRequest, this)
Flux<ChatClientResponse> chatClientResponse = Flux.defer(() -> advisor.adviseStream(chatClientRequest, this)
.doOnError(observation::error)
.doFinally(s -> observation.stop())
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)));
// @formatter:on
return CHAT_CLIENT_MESSAGE_AGGREGATOR.aggregateChatClientResponse(chatClientResponse,
observationContext::setChatClientResponse);
});
}

Expand Down Expand Up @@ -175,12 +190,20 @@ public static final class Builder {

private final Deque<StreamAdvisor> streamAdvisors;

@Nullable
private AdvisorObservationConvention observationConvention;

public Builder(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
this.callAdvisors = new ConcurrentLinkedDeque<>();
this.streamAdvisors = new ConcurrentLinkedDeque<>();
}

public Builder observationConvention(@Nullable AdvisorObservationConvention observationConvention) {
this.observationConvention = observationConvention;
return this;
}

public Builder push(Advisor advisor) {
Assert.notNull(advisor, "the advisor must be non-null");
return this.pushAll(List.of(advisor));
Expand Down Expand Up @@ -229,7 +252,8 @@ private void reOrder() {
}

public DefaultAroundAdvisorChain build() {
return new DefaultAroundAdvisorChain(this.observationRegistry, this.callAdvisors, this.streamAdvisors);
return new DefaultAroundAdvisorChain(this.observationRegistry, this.callAdvisors, this.streamAdvisors,
this.observationConvention);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,6 +18,7 @@

import java.nio.charset.Charset;

import io.micrometer.observation.ObservationRegistry;
import org.junit.jupiter.api.Test;

import org.springframework.ai.chat.model.ChatModel;
Expand Down Expand Up @@ -63,6 +64,12 @@ void whenObservationRegistryIsNullThenThrows() {
.hasMessage("the io.micrometer.observation.ObservationRegistry must be non-null");
}

@Test
void whenAdvisorObservationConventionIsNullThenReturn() {
var builder = new DefaultChatClientBuilder(mock(ChatModel.class), mock(ObservationRegistry.class), null, null);
assertThat(builder).isNotNull();
}

@Test
void whenUserResourceIsNullThenThrows() {
DefaultChatClientBuilder builder = new DefaultChatClientBuilder(mock(ChatModel.class));
Expand Down
Loading