Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,14 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
* @author Daniel Garnier-Moiroux
* @author Yanming Zhou
*/
class McpToolsConfigurationTests {

Expand Down Expand Up @@ -123,24 +127,16 @@ void toolCallbacksRegistered() {

// MCP toolcallback providers are never added to the resolver

// Bean graph setup
var injectedProviders = (List<ToolCallbackProvider>) ctx.getBean(
"org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded");
// Beans exposed as non-MCP
var toolCallbackProvider = (ToolCallbackProvider) ctx.getBean("toolCallbackProvider");
var customToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customToolCallbackProvider");
// This is injected in the resolver bean, because it's exposed as a
// ToolCallbackProvider, but it's not added to the resolver
var genericMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("genericMcpToolCallbackProvider");

// beans exposed as MCP
var mcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("mcpToolCallbackProvider");
var customMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customMcpToolCallbackProvider");

assertThat(injectedProviders)
.containsExactlyInAnyOrder(toolCallbackProvider, customToolCallbackProvider,
genericMcpToolCallbackProvider)
.doesNotContain(mcpToolCallbackProvider, customMcpToolCallbackProvider);
verify(genericMcpToolCallbackProvider, never()).getToolCallbacks();
verify(mcpToolCallbackProvider, never()).getToolCallbacks();
verify(customMcpToolCallbackProvider, never()).getToolCallbacks();

});
}
Expand Down Expand Up @@ -211,15 +207,15 @@ CustomToolCallbackProvider customToolCallbackProvider() {
// This bean depends on the resolver, to ensure there are no cyclic dependencies
@Bean
CustomMcpToolCallbackProvider customMcpToolCallbackProvider(ToolCallbackResolver resolver) {
return new CustomMcpToolCallbackProvider();
return spy(new CustomMcpToolCallbackProvider());
}

// This will be added to the resolver, because the visible type of the bean
// is ToolCallbackProvider ; we would need to actually instantiate the bean
// to find out that it is MCP-related
@Bean
ToolCallbackProvider genericMcpToolCallbackProvider() {
return new CustomMcpToolCallbackProvider();
CustomMcpToolCallbackProvider genericMcpToolCallbackProvider() {
return spy(new CustomMcpToolCallbackProvider());
}

static ToolCallback[] toolCallback(String name) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.springframework.ai.model.tool.autoconfigure;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import io.micrometer.observation.ObservationRegistry;
Expand All @@ -36,14 +35,7 @@
import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver;
import org.springframework.ai.tool.resolution.StaticToolCallbackResolver;
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
Expand All @@ -60,18 +52,16 @@
* @author Thomas Vitale
* @author Christian Tzolov
* @author Daniel Garnier-Moiroux
* @author Yanming Zhou
* @since 1.0.0
*/
@AutoConfiguration
@ConditionalOnClass(ChatModel.class)
@EnableConfigurationProperties(ToolCallingProperties.class)
public class ToolCallingAutoConfiguration implements BeanDefinitionRegistryPostProcessor {
public class ToolCallingAutoConfiguration {

private static final Logger logger = LoggerFactory.getLogger(ToolCallingAutoConfiguration.class);

// Marker qualifier to exclude MCP-related ToolCallbackProviders
private static final String EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER = "org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded";

/**
* The default {@link ToolCallbackResolver} resolves tools by name for methods,
* functions, and {@link ToolCallbackProvider} beans.
Expand All @@ -83,11 +73,10 @@ public class ToolCallingAutoConfiguration implements BeanDefinitionRegistryPostP
@Bean
@ConditionalOnMissingBean
ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext,
List<ToolCallback> toolCallbacks,
@Qualifier(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER) List<ToolCallbackProvider> tcbProviders) {
List<ToolCallback> toolCallbacks, ObjectProvider<ToolCallbackProvider> tcbProviders) {

List<ToolCallback> allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks);
tcbProviders.stream()
.filter(pr -> !isMcpToolCallbackProvider(ResolvableType.forInstance(pr)))
tcbProviders.stream(clazz -> !isMcpToolCallbackProvider(ResolvableType.forClass(clazz)))
.map(pr -> List.of(pr.getToolCallbacks()))
.forEach(allFunctionAndToolCallbacks::addAll);

Expand All @@ -100,41 +89,6 @@ ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationC
return new DelegatingToolCallbackResolver(List.of(staticToolCallbackResolver, springBeanToolCallbackResolver));
}

/**
* Wrap {@link ToolCallbackProvider} beans that are not MCP-related into a named bean,
* which will be picked up by the
* {@link ToolCallingAutoConfiguration#toolCallbackResolver}.
* <p>
* MCP providers must be excluded, because they may depend on a {@code ChatClient} to
* do sampling. The chat client, in turn, depends on a {@link ToolCallbackResolver}.
* To do the detection, we depend on the exposed bean type. If a bean uses a factory
* method which returns a {@link ToolCallbackProvider}, which is an MCP provider under
* the hood, it will be included in the list.
*/
@Override
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
if (!(registry instanceof DefaultListableBeanFactory beanFactory)) {
return;
}

var excludeMcpToolCallbackProviderBeanDefinition = BeanDefinitionBuilder
.genericBeanDefinition(List.class, () -> {
var providerNames = beanFactory.getBeanNamesForType(ToolCallbackProvider.class);
return Arrays.stream(providerNames)
.filter(name -> !isMcpToolCallbackProvider(beanFactory.getBeanDefinition(name).getResolvableType()))
.map(beanFactory::getBean)
.filter(ToolCallbackProvider.class::isInstance)
.map(ToolCallbackProvider.class::cast)
.toList();
})
.setScope(BeanDefinition.SCOPE_SINGLETON)
.setLazyInit(true)
.getBeanDefinition();

registry.registerBeanDefinition(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER,
excludeMcpToolCallbackProviderBeanDefinition);
}

private static boolean isMcpToolCallbackProvider(ResolvableType type) {
if (type.getType().getTypeName().equals("org.springframework.ai.mcp.SyncMcpToolCallbackProvider")
|| type.getType().getTypeName().equals("org.springframework.ai.mcp.AsyncMcpToolCallbackProvider")) {
Expand Down