Skip to content

Commit 3e712d1

Browse files
committed
Simplify filtering MCP ToolCallbackProviders for ToolCallingAutoConfiguration
See GH-4751 Signed-off-by: Yanming Zhou <zhouyanming@gmail.com>
1 parent 30af4e8 commit 3e712d1

File tree

2 files changed

+16
-64
lines changed

2 files changed

+16
-64
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.springframework.ai.tool.definition.ToolDefinition;
4040
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
4141
import org.springframework.ai.util.json.schema.JsonSchemaGenerator;
42+
import org.springframework.beans.factory.ObjectProvider;
4243
import org.springframework.boot.autoconfigure.AutoConfigurations;
4344
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
4445
import org.springframework.context.annotation.Bean;
@@ -47,10 +48,14 @@
4748

4849
import static org.assertj.core.api.Assertions.assertThat;
4950
import static org.mockito.Mockito.mock;
51+
import static org.mockito.Mockito.never;
52+
import static org.mockito.Mockito.spy;
53+
import static org.mockito.Mockito.verify;
5054
import static org.mockito.Mockito.when;
5155

5256
/**
5357
* @author Daniel Garnier-Moiroux
58+
* @author Yanming Zhou
5459
*/
5560
class McpToolsConfigurationTests {
5661

@@ -123,24 +128,16 @@ void toolCallbacksRegistered() {
123128

124129
// MCP toolcallback providers are never added to the resolver
125130

126-
// Bean graph setup
127-
var injectedProviders = (List<ToolCallbackProvider>) ctx.getBean(
128-
"org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded");
129-
// Beans exposed as non-MCP
130-
var toolCallbackProvider = (ToolCallbackProvider) ctx.getBean("toolCallbackProvider");
131-
var customToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customToolCallbackProvider");
132131
// This is injected in the resolver bean, because it's exposed as a
133132
// ToolCallbackProvider, but it's not added to the resolver
134133
var genericMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("genericMcpToolCallbackProvider");
135-
136134
// beans exposed as MCP
137135
var mcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("mcpToolCallbackProvider");
138136
var customMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customMcpToolCallbackProvider");
139137

140-
assertThat(injectedProviders)
141-
.containsExactlyInAnyOrder(toolCallbackProvider, customToolCallbackProvider,
142-
genericMcpToolCallbackProvider)
143-
.doesNotContain(mcpToolCallbackProvider, customMcpToolCallbackProvider);
138+
verify(genericMcpToolCallbackProvider, never()).getToolCallbacks();
139+
verify(mcpToolCallbackProvider, never()).getToolCallbacks();
140+
verify(customMcpToolCallbackProvider, never()).getToolCallbacks();
144141

145142
});
146143
}
@@ -196,7 +193,7 @@ ToolCallbackProvider toolCallbackProvider() {
196193

197194
// This bean depends on the resolver, to ensure there are no cyclic dependencies
198195
@Bean
199-
SyncMcpToolCallbackProvider mcpToolCallbackProvider(ToolCallbackResolver resolver) {
196+
SyncMcpToolCallbackProvider mcpToolCallbackProvider(ObjectProvider<ToolCallbackResolver> resolver) {
200197
var tcp = mock(SyncMcpToolCallbackProvider.class);
201198
when(tcp.getToolCallbacks())
202199
.thenThrow(new RuntimeException("mcpToolCallbackProvider#getToolCallbacks should not be called"));
@@ -210,16 +207,16 @@ CustomToolCallbackProvider customToolCallbackProvider() {
210207

211208
// This bean depends on the resolver, to ensure there are no cyclic dependencies
212209
@Bean
213-
CustomMcpToolCallbackProvider customMcpToolCallbackProvider(ToolCallbackResolver resolver) {
214-
return new CustomMcpToolCallbackProvider();
210+
CustomMcpToolCallbackProvider customMcpToolCallbackProvider(ObjectProvider<ToolCallbackResolver> resolver) {
211+
return spy(new CustomMcpToolCallbackProvider());
215212
}
216213

217214
// This will be added to the resolver, because the visible type of the bean
218215
// is ToolCallbackProvider ; we would need to actually instantiate the bean
219216
// to find out that it is MCP-related
220217
@Bean
221218
ToolCallbackProvider genericMcpToolCallbackProvider() {
222-
return new CustomMcpToolCallbackProvider();
219+
return spy(new CustomMcpToolCallbackProvider());
223220
}
224221

225222
static ToolCallback[] toolCallback(String name) {

auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java

Lines changed: 4 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.springframework.ai.model.tool.autoconfigure;
1818

1919
import java.util.ArrayList;
20-
import java.util.Arrays;
2120
import java.util.List;
2221

2322
import io.micrometer.observation.ObservationRegistry;
@@ -36,14 +35,7 @@
3635
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
3736
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
3837
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
39-
import org.springframework.beans.BeansException;
4038
import org.springframework.beans.factory.ObjectProvider;
41-
import org.springframework.beans.factory.annotation.Qualifier;
42-
import org.springframework.beans.factory.config.BeanDefinition;
43-
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
44-
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
45-
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
46-
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
4739
import org.springframework.boot.autoconfigure.AutoConfiguration;
4840
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
4941
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -60,18 +52,16 @@
6052
* @author Thomas Vitale
6153
* @author Christian Tzolov
6254
* @author Daniel Garnier-Moiroux
55+
* @author Yanming Zhou
6356
* @since 1.0.0
6457
*/
6558
@AutoConfiguration
6659
@ConditionalOnClass(ChatModel.class)
6760
@EnableConfigurationProperties(ToolCallingProperties.class)
68-
public class ToolCallingAutoConfiguration implements BeanDefinitionRegistryPostProcessor {
61+
public class ToolCallingAutoConfiguration {
6962

7063
private static final Logger logger = LoggerFactory.getLogger(ToolCallingAutoConfiguration.class);
7164

72-
// Marker qualifier to exclude MCP-related ToolCallbackProviders
73-
private static final String EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER = "org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded";
74-
7565
/**
7666
* The default {@link ToolCallbackResolver} resolves tools by name for methods,
7767
* functions, and {@link ToolCallbackProvider} beans.
@@ -83,8 +73,8 @@ public class ToolCallingAutoConfiguration implements BeanDefinitionRegistryPostP
8373
@Bean
8474
@ConditionalOnMissingBean
8575
ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
86-
List<ToolCallback> toolCallbacks,
87-
@Qualifier(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER) List<ToolCallbackProvider> tcbProviders) {
76+
List<ToolCallback> toolCallbacks, List<ToolCallbackProvider> tcbProviders) {
77+
8878
List<ToolCallback> allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks);
8979
tcbProviders.stream()
9080
.filter(pr -> !isMcpToolCallbackProvider(ResolvableType.forInstance(pr)))
@@ -100,41 +90,6 @@ ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationC
10090
return new DelegatingToolCallbackResolver(List.of(staticToolCallbackResolver, springBeanToolCallbackResolver));
10191
}
10292

103-
/**
104-
* Wrap {@link ToolCallbackProvider} beans that are not MCP-related into a named bean,
105-
* which will be picked up by the
106-
* {@link ToolCallingAutoConfiguration#toolCallbackResolver}.
107-
* <p>
108-
* MCP providers must be excluded, because they may depend on a {@code ChatClient} to
109-
* do sampling. The chat client, in turn, depends on a {@link ToolCallbackResolver}.
110-
* To do the detection, we depend on the exposed bean type. If a bean uses a factory
111-
* method which returns a {@link ToolCallbackProvider}, which is an MCP provider under
112-
* the hood, it will be included in the list.
113-
*/
114-
@Override
115-
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
116-
if (!(registry instanceof DefaultListableBeanFactory beanFactory)) {
117-
return;
118-
}
119-
120-
var excludeMcpToolCallbackProviderBeanDefinition = BeanDefinitionBuilder
121-
.genericBeanDefinition(List.class, () -> {
122-
var providerNames = beanFactory.getBeanNamesForType(ToolCallbackProvider.class);
123-
return Arrays.stream(providerNames)
124-
.filter(name -> !isMcpToolCallbackProvider(beanFactory.getBeanDefinition(name).getResolvableType()))
125-
.map(beanFactory::getBean)
126-
.filter(ToolCallbackProvider.class::isInstance)
127-
.map(ToolCallbackProvider.class::cast)
128-
.toList();
129-
})
130-
.setScope(BeanDefinition.SCOPE_SINGLETON)
131-
.setLazyInit(true)
132-
.getBeanDefinition();
133-
134-
registry.registerBeanDefinition(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER,
135-
excludeMcpToolCallbackProviderBeanDefinition);
136-
}
137-
13893
private static boolean isMcpToolCallbackProvider(ResolvableType type) {
13994
if (type.getType().getTypeName().equals("org.springframework.ai.mcp.SyncMcpToolCallbackProvider")
14095
|| type.getType().getTypeName().equals("org.springframework.ai.mcp.AsyncMcpToolCallbackProvider")) {

0 commit comments

Comments
 (0)