Skip to content

Commit f73385f

Browse files
committed
Simplify filtering MCP ToolCallbackProviders for ToolCallingAutoConfiguration
See GH-4751 Signed-off-by: Yanming Zhou <zhouyanming@gmail.com>
1 parent 30af4e8 commit f73385f

File tree

2 files changed

+15
-65
lines changed

2 files changed

+15
-65
lines changed

auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,14 @@
4747

4848
import static org.assertj.core.api.Assertions.assertThat;
4949
import static org.mockito.Mockito.mock;
50+
import static org.mockito.Mockito.never;
51+
import static org.mockito.Mockito.spy;
52+
import static org.mockito.Mockito.verify;
5053
import static org.mockito.Mockito.when;
5154

5255
/**
5356
* @author Daniel Garnier-Moiroux
57+
* @author Yanming Zhou
5458
*/
5559
class McpToolsConfigurationTests {
5660

@@ -123,24 +127,16 @@ void toolCallbacksRegistered() {
123127

124128
// MCP toolcallback providers are never added to the resolver
125129

126-
// Bean graph setup
127-
var injectedProviders = (List<ToolCallbackProvider>) ctx.getBean(
128-
"org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded");
129-
// Beans exposed as non-MCP
130-
var toolCallbackProvider = (ToolCallbackProvider) ctx.getBean("toolCallbackProvider");
131-
var customToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customToolCallbackProvider");
132130
// This is injected in the resolver bean, because it's exposed as a
133131
// ToolCallbackProvider, but it's not added to the resolver
134132
var genericMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("genericMcpToolCallbackProvider");
135-
136133
// beans exposed as MCP
137134
var mcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("mcpToolCallbackProvider");
138135
var customMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customMcpToolCallbackProvider");
139136

140-
assertThat(injectedProviders)
141-
.containsExactlyInAnyOrder(toolCallbackProvider, customToolCallbackProvider,
142-
genericMcpToolCallbackProvider)
143-
.doesNotContain(mcpToolCallbackProvider, customMcpToolCallbackProvider);
137+
verify(genericMcpToolCallbackProvider, never()).getToolCallbacks();
138+
verify(mcpToolCallbackProvider, never()).getToolCallbacks();
139+
verify(customMcpToolCallbackProvider, never()).getToolCallbacks();
144140

145141
});
146142
}
@@ -211,15 +207,15 @@ CustomToolCallbackProvider customToolCallbackProvider() {
211207
// This bean depends on the resolver, to ensure there are no cyclic dependencies
212208
@Bean
213209
CustomMcpToolCallbackProvider customMcpToolCallbackProvider(ToolCallbackResolver resolver) {
214-
return new CustomMcpToolCallbackProvider();
210+
return spy(new CustomMcpToolCallbackProvider());
215211
}
216212

217213
// This will be added to the resolver, because the visible type of the bean
218214
// is ToolCallbackProvider ; we would need to actually instantiate the bean
219215
// to find out that it is MCP-related
220216
@Bean
221-
ToolCallbackProvider genericMcpToolCallbackProvider() {
222-
return new CustomMcpToolCallbackProvider();
217+
CustomMcpToolCallbackProvider genericMcpToolCallbackProvider() {
218+
return spy(new CustomMcpToolCallbackProvider());
223219
}
224220

225221
static ToolCallback[] toolCallback(String name) {

auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.springframework.ai.model.tool.autoconfigure;
1818

1919
import java.util.ArrayList;
20-
import java.util.Arrays;
2120
import java.util.List;
2221

2322
import io.micrometer.observation.ObservationRegistry;
@@ -36,14 +35,7 @@
3635
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
3736
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
3837
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
39-
import org.springframework.beans.BeansException;
4038
import org.springframework.beans.factory.ObjectProvider;
41-
import org.springframework.beans.factory.annotation.Qualifier;
42-
import org.springframework.beans.factory.config.BeanDefinition;
43-
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
44-
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
45-
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
46-
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
4739
import org.springframework.boot.autoconfigure.AutoConfiguration;
4840
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
4941
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
@@ -60,18 +52,16 @@
6052
* @author Thomas Vitale
6153
* @author Christian Tzolov
6254
* @author Daniel Garnier-Moiroux
55+
* @author Yanming Zhou
6356
* @since 1.0.0
6457
*/
6558
@AutoConfiguration
6659
@ConditionalOnClass(ChatModel.class)
6760
@EnableConfigurationProperties(ToolCallingProperties.class)
68-
public class ToolCallingAutoConfiguration implements BeanDefinitionRegistryPostProcessor {
61+
public class ToolCallingAutoConfiguration {
6962

7063
private static final Logger logger = LoggerFactory.getLogger(ToolCallingAutoConfiguration.class);
7164

72-
// Marker qualifier to exclude MCP-related ToolCallbackProviders
73-
private static final String EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER = "org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded";
74-
7565
/**
7666
* The default {@link ToolCallbackResolver} resolves tools by name for methods,
7767
* functions, and {@link ToolCallbackProvider} beans.
@@ -83,11 +73,10 @@ public class ToolCallingAutoConfiguration implements BeanDefinitionRegistryPostP
8373
@Bean
8474
@ConditionalOnMissingBean
8575
ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
86-
List<ToolCallback> toolCallbacks,
87-
@Qualifier(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER) List<ToolCallbackProvider> tcbProviders) {
76+
List<ToolCallback> toolCallbacks, ObjectProvider<ToolCallbackProvider> tcbProviders) {
77+
8878
List<ToolCallback> allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks);
89-
tcbProviders.stream()
90-
.filter(pr -> !isMcpToolCallbackProvider(ResolvableType.forInstance(pr)))
79+
tcbProviders.stream(clazz -> !isMcpToolCallbackProvider(ResolvableType.forClass(clazz)))
9180
.map(pr -> List.of(pr.getToolCallbacks()))
9281
.forEach(allFunctionAndToolCallbacks::addAll);
9382

@@ -100,41 +89,6 @@ ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationC
10089
return new DelegatingToolCallbackResolver(List.of(staticToolCallbackResolver, springBeanToolCallbackResolver));
10190
}
10291

103-
/**
104-
* Wrap {@link ToolCallbackProvider} beans that are not MCP-related into a named bean,
105-
* which will be picked up by the
106-
* {@link ToolCallingAutoConfiguration#toolCallbackResolver}.
107-
* <p>
108-
* MCP providers must be excluded, because they may depend on a {@code ChatClient} to
109-
* do sampling. The chat client, in turn, depends on a {@link ToolCallbackResolver}.
110-
* To do the detection, we depend on the exposed bean type. If a bean uses a factory
111-
* method which returns a {@link ToolCallbackProvider}, which is an MCP provider under
112-
* the hood, it will be included in the list.
113-
*/
114-
@Override
115-
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
116-
if (!(registry instanceof DefaultListableBeanFactory beanFactory)) {
117-
return;
118-
}
119-
120-
var excludeMcpToolCallbackProviderBeanDefinition = BeanDefinitionBuilder
121-
.genericBeanDefinition(List.class, () -> {
122-
var providerNames = beanFactory.getBeanNamesForType(ToolCallbackProvider.class);
123-
return Arrays.stream(providerNames)
124-
.filter(name -> !isMcpToolCallbackProvider(beanFactory.getBeanDefinition(name).getResolvableType()))
125-
.map(beanFactory::getBean)
126-
.filter(ToolCallbackProvider.class::isInstance)
127-
.map(ToolCallbackProvider.class::cast)
128-
.toList();
129-
})
130-
.setScope(BeanDefinition.SCOPE_SINGLETON)
131-
.setLazyInit(true)
132-
.getBeanDefinition();
133-
134-
registry.registerBeanDefinition(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER,
135-
excludeMcpToolCallbackProviderBeanDefinition);
136-
}
137-
13892
private static boolean isMcpToolCallbackProvider(ResolvableType type) {
13993
if (type.getType().getTypeName().equals("org.springframework.ai.mcp.SyncMcpToolCallbackProvider")
14094
|| type.getType().getTypeName().equals("org.springframework.ai.mcp.AsyncMcpToolCallbackProvider")) {

0 commit comments

Comments
 (0)