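AI Agent Basic Concepts
Perception module: gathers information from the external environment and accepts multimodal input (text, images, sensor data); for example, GPT-4o can process vision and speech end to end.
Planning (decision-making) module: uses an LLM/AIGC foundation model to reason, plan, and decide. It breaks complex goals down with techniques such as Chain-of-Thought (CoT) and Tree-of-Thoughts (ToT), and keeps refining its strategy based on feedback.
Action module: calls tools (APIs, databases, search engines, robot actuators, etc.) to carry out tasks.
Memory module: split into short-term and long-term memory. Short-term memory relies on the model's context window (e.g. 128K tokens) to hold the data and knowledge needed for the current task. Long-term memory stores and retrieves past experience through a vector database plus RAG.
Core Message Types
Mainstream AI Agent frameworks define three core message types, System Prompt (System Message), Assistant Prompt (Assistant Message), and User Prompt (User Message), each with a clearly separated role:
User Prompt: the user's direct input or question.
Assistant Prompt: the reply generated by the model.
System Prompt: core configuration such as the model's role and base instructions (identity, safety constraints).
The key difference between the System Prompt and the User Prompt is position and priority: the System Prompt is always placed at the start of the input sequence. Because of how attention tends to behave (information at the beginning and end of a sequence usually receives more attention), content in that position is more readily recognized and followed by the model. A complete multi-turn prompt is therefore usually assembled in this order: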
System Prompt -> User Prompt -> Assistant Prompt -> User Prompt ... -> Assistant Prompt
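The ordering rule above can be sketched in a few lines. This is a framework-free illustration, not Spring AI code: `ChatMessage`, `Turn`, and `PromptAssembler` are hypothetical names, and roles are plain strings rather than a message-type enum.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical minimal types for illustration; real frameworks (e.g. Spring AI's Message
// hierarchy) carry more metadata than a role and a string.
record ChatMessage(String role, String content) {}

// One past exchange: what the user said and what the assistant replied
record Turn(String user, String assistant) {}

class PromptAssembler {

    // Builds: System -> (User -> Assistant)* -> User
    static List<ChatMessage> assemble(String systemPrompt, List<Turn> history, String newUserInput) {
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(new ChatMessage("system", systemPrompt));      // always at the start
        for (Turn turn : history) {                                  // replay history in order
            messages.add(new ChatMessage("user", turn.user()));
            messages.add(new ChatMessage("assistant", turn.assistant()));
        }
        messages.add(new ChatMessage("user", newUserInput));        // the pending question comes last
        return messages;
    }
}
```

Because history is replayed with explicit roles, the model can tell its own earlier replies apart from what the user typed.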
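Core message types at a glance (including the Tool message type used for tool calling)

| Type | Purpose | Priority | Visibility | Typical content |
|---|---|---|---|---|
| System | Sets role / rules / safety boundaries | High | Hidden from the user | Role, behavioral guidelines, safety policy, output format |
| User | Task / question input | Medium | Visible | User questions, requirements, data |
| Assistant | Model reply / intermediate reasoning | Medium | Visible | Answers, tool-call intent, reasoning (hidden/optional) |
| Tool | Tool execution result (consumed by the model) | Medium | Visible (usually structured) | External system results, retrieval/computation output |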
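In this structure, Assistant Prompts mainly present the conversation history to the model and make clear which parts were the model's own earlier replies and which came from the user. A model pre-trained and fine-tuned on data in this format learns that these messages are history, not the current user input, which helps it track the context and answer follow-up questions more accurately.
So why not merge the System Prompt into the User Prompt? A key reason is security and controllability. Distinguishing message types during fine-tuning helps defend against attacks such as prompt injection. Specifically:
Core role definitions and rules go in the System Prompt.
User interaction content goes in the User Prompt.
This separation helps guard against some simple prompt attacks and information-leak risks. In real applications the System Prompt is usually invisible to the user, and models are typically trained to give the rules and role it defines the highest priority. This noticeably increases the likelihood that the model follows the developer's intent and lowers the risk that variations in user input push the output away from what is expected.
Of course, the System Prompt alone cannot fully prevent attacks (for example, GPT-4 has been induced to leak its System Prompt). Adding a secondary check on user input or model output is a more robust way to strengthen security.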
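A secondary check can wrap the model call on both sides. The sketch below is a hypothetical `PromptGuard` using naive keyword matching on the input and verbatim-overlap detection on the output; it only illustrates where such checks sit. Production systems typically use dedicated classifiers or moderation services instead of a phrase list.

```java
import java.util.List;
import java.util.Locale;

// Hypothetical guard illustrating input and output checks around a model call.
// Keyword matching is a toy heuristic and is easy to evade.
class PromptGuard {

    private static final List<String> SUSPICIOUS_PHRASES = List.of(
            "ignore previous instructions",
            "ignore all previous instructions",
            "reveal your system prompt");

    // Input check: flag user text that looks like a prompt-injection attempt
    static boolean looksLikeInjection(String userInput) {
        String normalized = userInput.toLowerCase(Locale.ROOT);
        return SUSPICIOUS_PHRASES.stream().anyMatch(normalized::contains);
    }

    // Output check: flag replies that repeat any minOverlap-character slice of the system prompt
    static boolean leaksSystemPrompt(String modelOutput, String systemPrompt, int minOverlap) {
        for (int start = 0; start + minOverlap <= systemPrompt.length(); start++) {
            if (modelOutput.contains(systemPrompt.substring(start, start + minOverlap))) {
                return true;
            }
        }
        return false;
    }
}
```

A flagged input can be rejected before it reaches the model; a flagged output can be replaced with a refusal before it reaches the user.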
Implementation in Spring AI

```java
public enum MessageType {
    USER("user"),
    ASSISTANT("assistant"),
    SYSTEM("system"),
    TOOL("tool");
    // ...
}
```

PromptTemplate

```java
public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {
    // body omitted
}
```

```java
String userText = """
        Tell me about three famous pirates from the Golden Age of Piracy and why they did.
        Write at least a sentence for each pirate.
        """;
Message userMessage = new UserMessage(userText);

String systemText = """
        You are a helpful AI assistant that helps people find information.
        Your name is {name}
        You should reply to the user's request with your name and also in the style of a {voice}.
        """;
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemText);
// name and voice are supplied by the caller and fill the {name} and {voice} placeholders
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));

// System message first, matching the System -> User -> Assistant order described above
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
List<Generation> response = chatModel.call(prompt).getResults();
```
AI Agent Design Patterns
Reflection pattern

```java
ReflectAgent reflectAgent = ReflectAgent.builder()
        .graph(assistantGraphNode)
        .reflection(judgeGraphNode)
        .maxIterations(3)
        .build();
```

```java
public StateGraph createReflectionGraph(NodeAction graph, NodeAction reflection, int maxIterations) {
    StateGraph stateGraph = new StateGraph(() -> {
        HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
        keyStrategyHashMap.put(MESSAGES, new ReplaceStrategy());
        keyStrategyHashMap.put(ITERATION_NUM, new ReplaceStrategy());
        return keyStrategyHashMap;
    })
            .addNode(GRAPH_NODE_ID, node_async(graph))
            .addNode(REFLECTION_NODE_ID, node_async(reflection))
            .addEdge(START, GRAPH_NODE_ID)
            // after the generator node: go to reflection, or stop once the iteration limit is reached
            .addConditionalEdges(GRAPH_NODE_ID, edge_async(this::graphCount),
                    Map.of(REFLECTION_NODE_ID, REFLECTION_NODE_ID, END, END))
            // after the reflection node: loop back to the generator or finish
            .addConditionalEdges(REFLECTION_NODE_ID, edge_async(this::apply),
                    Map.of(GRAPH_NODE_ID, GRAPH_NODE_ID, END, END));
    return stateGraph;
}

private String graphCount(OverAllState state) {
    int iterationNum = state.value(ITERATION_NUM, Integer.class).orElse(0);
    state.updateState(Map.of(ITERATION_NUM, iterationNum + 1));
    // maxIterations here refers to a class field, not the parameter of createReflectionGraph
    return iterationNum >= maxIterations ? END : REFLECTION_NODE_ID;
}

private String apply(OverAllState state) {
    List<Message> messages = state.value(MESSAGES, List.class).orElse(new ArrayList<>());
    if (messages.isEmpty()) return END;
    Message lastMessage = messages.get(messages.size() - 1);
    // a trailing UserMessage (the reflection's feedback) sends control back to the generator
    return lastMessage instanceof UserMessage ? GRAPH_NODE_ID : END;
}
```
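Stripped of the graph framework, the Reflection pattern is a bounded generate-critique loop. The sketch below is a plain-Java analogue, not the ReflectAgent implementation: the generator and judge are ordinary functions standing in for LLM calls, and `ReflectionLoop` and its signature are hypothetical.

```java
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.Function;

// Framework-free sketch of a Reflection loop: a generator drafts an answer, a judge either
// accepts it (empty Optional) or returns a critique, and the loop stops on acceptance
// or after maxIterations drafts.
class ReflectionLoop {

    record Result(String answer, int drafts) {}

    static Result run(String task,
                      BiFunction<String, String, String> generator,   // (task, critique or null) -> draft
                      Function<String, Optional<String>> judge,        // draft -> critique, empty if accepted
                      int maxIterations) {
        String critique = null;
        String draft = null;
        int drafts = 0;
        while (drafts < maxIterations) {
            draft = generator.apply(task, critique);
            drafts++;
            Optional<String> feedback = judge.apply(draft);
            if (feedback.isEmpty()) {
                break;                       // the judge accepted the draft
            }
            critique = feedback.get();       // feed the critique into the next draft
        }
        return new Result(draft, drafts);
    }
}
```

The iteration cap plays the same role as `maxIterations` above: it bounds cost when the judge is never satisfied.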
ReAct pattern

```java
ReactAgent reactAgent = new ReactAgent("weatherAgent", chatClient, toolCallbacks, 10);
CompiledGraph compiledGraph = reactAgent.getAndCompileGraph();
```
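When to use / when not to use (quick pattern guide)

| Pattern | When to use | When not to use |
|---|---|---|
| Reflection | Answers need self-review or rewriting; quality-sensitive tasks | Extremely latency-critical or highly cost-sensitive scenarios |
| ReAct | Need to alternate "reasoning + tool calls"; information retrieval or function orchestration | Few tools, or a single call is enough |
| Planning | Need medium/long-term planning and subtask decomposition | Simple tasks that finish in one step |
| Multi-agent | Need role division and collaboration | A single agent suffices; the communication overhead is not worth it |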
```java
private StateGraph initGraph() throws GraphStateException {
    StateGraph graph = new StateGraph(name, this.keyStrategyFactory);
    graph.addNode("llm", node_async(this.llmNode));
    graph.addNode("tool", node_async(this.toolNode));
    graph.addEdge(START, "llm")
            // after the LLM node: run tools if the model requested any, otherwise finish
            .addConditionalEdges("llm", edge_async(this::think),
                    Map.of("continue", "tool", "end", END))
            // tool results are fed back to the LLM
            .addEdge("tool", "llm");
    return graph;
}

private String think(OverAllState state) {
    if (iterations > max_iterations) {
        return "end";
    }
    List<Message> messages = (List<Message>) state.value("messages").orElseThrow();
    AssistantMessage message = (AssistantMessage) messages.get(messages.size() - 1);
    return message.hasToolCalls() ? "continue" : "end";
}
```
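The `llm -> tool -> llm` graph above reduces to a simple loop. Below is a framework-free sketch under stated assumptions: `ReActLoop`, `ModelReply`, and the string transcript are hypothetical stand-ins for the model call and the message list, and each reply carries at most one tool call.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Framework-free sketch of the ReAct loop: think ("llm"), act ("tool"), observe, repeat.
class ReActLoop {

    // A reply either carries a final answer or a tool call (toolName + argument)
    record ModelReply(String finalAnswer, String toolName, String toolArg) {
        boolean hasToolCall() {
            return toolName != null;
        }
    }

    static String run(String question,
                      Function<List<String>, ModelReply> model,     // transcript -> next reply
                      Map<String, Function<String, String>> tools,
                      int maxIterations) {
        List<String> transcript = new ArrayList<>();
        transcript.add("user: " + question);
        for (int i = 0; i < maxIterations; i++) {
            ModelReply reply = model.apply(transcript);            // "llm" node
            if (!reply.hasToolCall()) {
                return reply.finalAnswer();                        // "end": no tool requested
            }
            Function<String, String> tool = tools.get(reply.toolName());
            String observation = tool == null ? "unknown tool: " + reply.toolName() : tool.apply(reply.toolArg());
            transcript.add("tool[" + reply.toolName() + "]: " + observation); // "tool" -> back to "llm"
        }
        return "stopped: iteration limit reached";
    }
}
```

The iteration cap mirrors the `max_iterations` check in `think`: without it, a model that keeps requesting tools would never terminate.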
Planning pattern
Multi-agent pattern
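Agent Memory
Memory is the mechanism that lets an AI Agent store, retrieve, and use information obtained from past interactions. It extends beyond the limits of a single conversation or task and is the core component behind an agent's "continuity" and "personalization". It lets an agent:
Remember user preferences: for example, if a user said in an earlier conversation that they like science fiction, sci-fi can be prioritized the next time books are recommended.
Maintain conversational context: in multi-turn dialogue, resolve references such as "he", "she", "it", or "that one".
Learn from history: summarize past successes and failures to improve future decisions and actions.
Accumulate knowledge: store newly acquired information in a knowledge base for long-term use.
AI Agents currently add memory in three forms: short-term memory, long-term memory, and tool memory.
**Short-term / Conversation Memory** maintains the context of the current conversation or task. It usually has a window limit (only the most recent N turns are kept), which is determined by the context window of the underlying LLM.
**Long-term Memory** stores and recalls important information over a much longer time span. Its capacity is far larger than short-term memory, and relevant information has to be fetched through retrieval. It is usually combined with a **vector database (Vector Database)**: information is converted into vector embeddings (Embeddings), then stored and searched by similarity.
**Tool Memory**: records the history and results of the agent's tool calls (APIs, functions). This is essential for agents that execute complex multi-step tasks, because they can review earlier tool results to decide the next action.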
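Window-limited short-term memory can be sketched as a per-conversation queue that evicts its oldest entries. This `WindowMemory` class is hypothetical and simplified (plain strings, no special handling of message types); it is not Spring AI's `MessageWindowChatMemory` implementation.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Minimal sketch of window-based short-term memory: per conversation, keep only the
// most recent maxMessages entries and evict the oldest.
class WindowMemory {

    private final int maxMessages;
    private final Map<String, Deque<String>> store = new HashMap<>();

    WindowMemory(int maxMessages) {
        this.maxMessages = maxMessages;
    }

    void add(String conversationId, String message) {
        Deque<String> window = store.computeIfAbsent(conversationId, id -> new ArrayDeque<>());
        window.addLast(message);
        while (window.size() > maxMessages) {
            window.removeFirst();    // forget the oldest message
        }
    }

    List<String> get(String conversationId) {
        return new ArrayList<>(store.getOrDefault(conversationId, new ArrayDeque<>()));
    }
}
```

With a window of 2, the turn "My name is James Bond" falls out before "What is my name?" is answered, which is exactly the limitation that long-term memory is meant to address.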
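At its core, the memory mechanism comes down to two basic operations: Read and Write.
Write: after the agent completes an action (such as replying to the user or calling a tool), it persists valuable interaction information (user input, AI output, tool results, self-generated summaries) to storage (in memory, a database, files, etc.).
Read (Retrieve): before the agent processes new input or starts a new task, it looks up the most relevant pieces of memory based on the current query (such as the user's new question). For long-term memory this is usually vector similarity search: the query is converted into a vector as well, and the stored memory vectors most similar to it are retrieved from the vector database.
How memory is merged into the prompt (practical comparison)
Short-term memory: injected as a collection of messages (Message form); suited to preserving multi-turn context.
Prompt concatenation: the history summary is appended to the System Prompt; suited to "stable long rules + short context".
Long-term memory (vector store): relevant recalled snippets are concatenated into one long text; suited to knowledge-base / document Q&A.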
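The write/read cycle and the "concatenate recalled snippets" merge strategy can be shown with a toy in-memory store. Loud assumptions: a bag-of-words count vector stands in for a real embedding model, a list stands in for the vector database, and `ToyVectorMemory` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Toy long-term memory: write = embed + store, read = cosine-similarity top-K.
class ToyVectorMemory {

    private record Entry(String text, Map<String, Integer> vector) {}

    private final List<Entry> entries = new ArrayList<>();

    void write(String text) {
        entries.add(new Entry(text, embed(text)));
    }

    List<String> read(String query, int topK) {
        Map<String, Integer> q = embed(query);
        return entries.stream()
                .sorted(Comparator.comparingDouble((Entry e) -> cosine(q, e.vector())).reversed())
                .limit(topK)
                .map(Entry::text)
                .toList();
    }

    // Merge strategy: append the recalled snippets to the system prompt as plain text
    String buildSystemPrompt(String instructions, String query, int topK) {
        return instructions + "\nLong-term memory:\n- " + String.join("\n- ", read(query, topK));
    }

    // Stand-in "embedding": lower-cased word counts
    private static Map<String, Integer> embed(String text) {
        Map<String, Integer> counts = new HashMap<>();
        for (String token : text.toLowerCase().split("\\W+")) {
            if (!token.isEmpty()) counts.merge(token, 1, Integer::sum);
        }
        return counts;
    }

    private static double cosine(Map<String, Integer> a, Map<String, Integer> b) {
        double dot = 0;
        for (var e : a.entrySet()) dot += e.getValue() * b.getOrDefault(e.getKey(), 0);
        double na = Math.sqrt(a.values().stream().mapToDouble(v -> v * v).sum());
        double nb = Math.sqrt(b.values().stream().mapToDouble(v -> v * v).sum());
        return (na == 0 || nb == 0) ? 0 : dot / (na * nb);
    }
}
```

Swapping `embed` for a real embedding model and the list for a vector database keeps the same read/write shape.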
Example: MessageWindowChatMemory
Spring AI provides several built-in Advisors for configuring a ChatClient's memory behavior:
MessageChatMemoryAdvisor: manages conversation memory with the provided ChatMemory. On each interaction it retrieves the conversation history from memory and includes it in the prompt as a collection of messages.
PromptChatMemoryAdvisor: manages conversation memory with the provided ChatMemory. On each interaction it retrieves the conversation history from memory and appends it to the system prompt as plain text.
VectorStoreChatMemoryAdvisor: manages conversation memory with the provided VectorStore. On each interaction it retrieves conversation history from the vector store and appends it to the system message as plain text.

```java
ChatMemory chatMemory = MessageWindowChatMemory.builder().build();

ChatClient chatClient = ChatClient.builder(chatModel)
        .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
        .build();
```

When ChatClient is called, memory is managed automatically by MessageChatMemoryAdvisor. The conversation history is retrieved from memory based on the specified conversation ID:

```java
String conversationId = "007";

String response = chatClient.prompt()
        .user("Do I have license to code?")
        .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
        .call()
        .content();
```

PromptChatMemoryAdvisor lets you override its default behavior by supplying a custom PromptTemplate through the .promptTemplate() method:

```java
PromptTemplate customTemplate = new PromptTemplate("""
        System instructions: {instructions}
        Conversation memory: {memory}
        """);

PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
        .promptTemplate(customTemplate)
        .build();
```

Note: the template must contain the {instructions} and {memory} placeholders.
Custom template for VectorStoreChatMemoryAdvisor
Similarly, .promptTemplate() customizes how vector-store memory is merged:

```java
PromptTemplate customTemplate = new PromptTemplate("""
        System instructions: {instructions}
        Long-term memory: {long_term_memory}
        """);

VectorStoreChatMemoryAdvisor advisor = VectorStoreChatMemoryAdvisor.builder(vectorStore)
        .promptTemplate(customTemplate)
        .build();
```

Note: the template must contain the {instructions} and {long_term_memory} placeholders.
Managing memory explicitly with ChatModel

```java
ChatMemory chatMemory = MessageWindowChatMemory.builder().build();
String conversationId = "007";

// First turn: store the user message, call the model with the full history, store the reply
UserMessage userMessage1 = new UserMessage("My name is James Bond");
chatMemory.add(conversationId, userMessage1);
ChatResponse response1 = chatModel.call(new Prompt(chatMemory.get(conversationId)));
chatMemory.add(conversationId, response1.getResult().getOutput());

// Second turn: the model can answer because the first exchange is part of the prompt
UserMessage userMessage2 = new UserMessage("What is my name?");
chatMemory.add(conversationId, userMessage2);
ChatResponse response2 = chatModel.call(new Prompt(chatMemory.get(conversationId)));
chatMemory.add(conversationId, response2.getResult().getOutput());
```
MCP
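Spring AI implementation
Call flow: initializing the connection
Calling tools
Example
The chat request includes the tool definitions: tool name, description, and input schema.
When the AI model decides to call a tool, it sends a response containing the tool name and the input arguments (which the model extracts from the text and shapes according to the input schema).
The application uses the tool name to identify the tool and executes it with the supplied input arguments.
The tool computes the result and returns it to the application.
The application sends the result back to the model.
The model uses the tool result as additional context to generate the final response.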
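The application side of this flow can be sketched with the model stubbed out. `ToolCallingFlow`, `ToolSpec`, `ToolCall`, and `Model` are hypothetical types; a real client sends the definitions as JSON Schema and receives tool calls in the provider's own response format, and may loop over several rounds of tool calls.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Sketch of the application side of a single tool-calling round trip.
class ToolCallingFlow {

    record ToolSpec(String name, String description, String inputSchema,
                    Function<Map<String, Object>, Object> function) {}

    record ToolCall(String name, Map<String, Object> arguments) {}

    interface Model {
        // Returns a ToolCall when the model wants a tool, or null when it can answer directly
        ToolCall requestTool(String userText, List<ToolSpec> tools);

        String answer(String userText, String toolResult);
    }

    static String chat(Model model, String userText, List<ToolSpec> tools) {
        Map<String, ToolSpec> registry = new HashMap<>();
        tools.forEach(t -> registry.put(t.name(), t));
        ToolCall call = model.requestTool(userText, tools);          // send definitions, get the tool choice
        if (call == null) {
            return model.answer(userText, null);
        }
        ToolSpec spec = registry.get(call.name());                   // look the tool up by name
        if (spec == null) {
            throw new IllegalStateException("No tool named " + call.name());
        }
        Object result = spec.function().apply(call.arguments());      // execute with the model's arguments
        return model.answer(userText, String.valueOf(result));       // return the result for the final reply
    }
}
```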
Tool call flow
```java
public class DefaultChatClient implements ChatClient {

    // excerpt: the tools(...) overloads belong to the nested request-spec class
    public static class DefaultChatClientRequestSpec implements ChatClientRequestSpec {

        @Override
        public ChatClientRequestSpec tools(String... toolNames) {
            Assert.notNull(toolNames, "toolNames cannot be null");
            Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
            this.functionNames.addAll(List.of(toolNames));
            return this;
        }

        @Override
        public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) {
            Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
            Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
            this.functionCallbacks.addAll(List.of(toolCallbacks));
            return this;
        }

        @Override
        public ChatClientRequestSpec tools(List<ToolCallback> toolCallbacks) {
            Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
            Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
            this.functionCallbacks.addAll(toolCallbacks);
            return this;
        }

        @Override
        public ChatClientRequestSpec tools(Object... toolObjects) {
            Assert.notNull(toolObjects, "toolObjects cannot be null");
            Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
            // scans the objects for @Tool methods
            this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
            return this;
        }

        @Override
        public ChatClientRequestSpec tools(ToolCallbackProvider... toolCallbackProviders) {
            Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null");
            Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements");
            for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
                this.functionCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks()));
            }
            return this;
        }
    }
}
```
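Minimal runnable Spring AI example: the following shows the dependencies, initialization, one @Tool, and one call.
Dependencies (Maven shown; pick any provider implementation)

```xml
<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-bom</artifactId>
            <version>1.0.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

<dependencies>
    <!-- Spring AI 1.0.0 starter name; pre-1.0 milestones used spring-ai-openai-spring-boot-starter -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-openai</artifactId>
    </dependency>
    <!-- needed for @RestController; version managed by the Spring Boot parent -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>
```

Configuration (application.yml)

```yaml
spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY}
```

Define a Tool and call it

```java
// MathTools.java
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.stereotype.Component;

@Component
public class MathTools {

    @Tool(name = "addTwoNumbers", description = "Add two integers")
    public int addTwoNumbers(int a, int b) {
        return a + b;
    }
}

// DemoController.java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class DemoController {

    private final ChatClient chatClient;
    private final MathTools mathTools;

    // Spring AI auto-configures a ChatClient.Builder bean, not a ChatClient
    public DemoController(ChatClient.Builder chatClientBuilder, MathTools mathTools) {
        this.chatClient = chatClientBuilder.build();
        this.mathTools = mathTools;
    }

    @GetMapping("/demo")
    public String demo() {
        return chatClient
                .prompt()
                .system("You are a helpful assistant.")
                .user("Please calculate the sum of 12 and 30")
                .tools(mathTools) // the @Tool methods of this object are offered to the model
                .call()
                .content();
    }
}
```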
```java
@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE })
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Tool {

    String name() default "";

    String description() default "";

    boolean returnDirect() default false;

    Class<? extends ToolCallResultConverter> resultConverter() default DefaultToolCallResultConverter.class;
}
```

```java
public interface ToolDefinition {

    String name();

    String description();

    String inputSchema();

    static DefaultToolDefinition.Builder builder() {
        return DefaultToolDefinition.builder();
    }

    static DefaultToolDefinition.Builder builder(Method method) {
        Assert.notNull(method, "method cannot be null");
        return DefaultToolDefinition.builder()
                .name(ToolUtils.getToolName(method))
                .description(ToolUtils.getToolDescription(method))
                .inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
    }

    static ToolDefinition from(Method method) {
        return ToolDefinition.builder(method).build();
    }
}
```

```java
public record DefaultToolDefinition(String name, String description, String inputSchema) implements ToolDefinition {

    public DefaultToolDefinition {
        Assert.hasText(name, "name cannot be null or empty");
        Assert.hasText(description, "description cannot be null or empty");
        Assert.hasText(inputSchema, "inputSchema cannot be null or empty");
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {

        private String name;
        private String description;
        private String inputSchema;

        private Builder() {}

        public Builder name(String name) { this.name = name; return this; }

        public Builder description(String description) { this.description = description; return this; }

        public Builder inputSchema(String inputSchema) { this.inputSchema = inputSchema; return this; }

        public ToolDefinition build() {
            // fall back to a description derived from the tool name
            if (!StringUtils.hasText(description)) {
                description = ToolUtils.getToolDescriptionFromName(name);
            }
            return new DefaultToolDefinition(name, description, inputSchema);
        }
    }
}
```

ToolMetadata: currently used only to control whether the tool result is returned directly to the caller without going back through the model.

```java
public interface ToolMetadata {

    default boolean returnDirect() {
        return false;
    }

    static DefaultToolMetadata.Builder builder() {
        return DefaultToolMetadata.builder();
    }

    static ToolMetadata from(Method method) {
        Assert.notNull(method, "method cannot be null");
        return DefaultToolMetadata.builder().returnDirect(ToolUtils.getToolReturnDirect(method)).build();
    }
}
```

```java
public record DefaultToolMetadata(boolean returnDirect) implements ToolMetadata {

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {

        private boolean returnDirect = false;

        private Builder() {}

        public Builder returnDirect(boolean returnDirect) { this.returnDirect = returnDirect; return this; }

        public ToolMetadata build() {
            return new DefaultToolMetadata(returnDirect);
        }
    }
}
```

```java
public interface ToolCallback {

    ToolDefinition getToolDefinition();

    default ToolMetadata getToolMetadata() {
        return ToolMetadata.builder().build();
    }

    String call(String toolInput);

    default String call(String toolInput, @Nullable ToolContext toolContext) {
        // implementations that do not override this overload reject a non-empty context
        if (toolContext != null && !toolContext.getContext().isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }
}
```

In MethodToolCallback, the core method to focus on is call:
Convert the string produced by the model into arguments matching the input schema

```java
Map<String, Object> toolArguments = extractToolArguments(toolInput);
Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);
```

Invoke the tool method with those arguments to get the tool's output

```java
Object result = callMethod(methodArguments);
```

Convert the tool's output according to its return type

```java
Type returnType = toolMethod.getGenericReturnType();
return toolCallResultConverter.convert(result, returnType);
```
```java
public class MethodToolCallback implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class);

    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;

    private final ToolMetadata toolMetadata;

    private final Method toolMethod;

    @Nullable
    private final Object toolObject;

    private final ToolCallResultConverter toolCallResultConverter;

    public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
            @Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolMethod, "toolMethod cannot be null");
        Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
                "toolObject cannot be null for non-static methods");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolMethod = toolMethod;
        this.toolObject = toolObject;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
                : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");
        logger.debug("Starting execution of tool: {}", toolDefinition.name());
        validateToolContextSupport(toolContext);
        Map<String, Object> toolArguments = extractToolArguments(toolInput);
        Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);
        Object result = callMethod(methodArguments);
        logger.debug("Successful execution of tool: {}", toolDefinition.name());
        Type returnType = toolMethod.getGenericReturnType();
        return toolCallResultConverter.convert(result, returnType);
    }

    private void validateToolContextSupport(@Nullable ToolContext toolContext) {
        var isNonEmptyToolContextProvided = toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext());
        var isToolContextAcceptedByMethod = Stream.of(toolMethod.getParameterTypes())
                .anyMatch(type -> ClassUtils.isAssignable(type, ToolContext.class));
        if (isToolContextAcceptedByMethod && !isNonEmptyToolContextProvided) {
            throw new IllegalArgumentException("ToolContext is required by the method as an argument");
        }
    }

    private Map<String, Object> extractToolArguments(String toolInput) {
        return JsonParser.fromJson(toolInput, new TypeReference<>() {});
    }

    private Object[] buildMethodArguments(Map<String, Object> toolInputArguments, @Nullable ToolContext toolContext) {
        return Stream.of(toolMethod.getParameters()).map(parameter -> {
            if (parameter.getType().isAssignableFrom(ToolContext.class)) {
                return toolContext;
            }
            Object rawArgument = toolInputArguments.get(parameter.getName());
            return buildTypedArgument(rawArgument, parameter.getType());
        }).toArray();
    }

    @Nullable
    private Object buildTypedArgument(@Nullable Object value, Class<?> type) {
        if (value == null) {
            return null;
        }
        return JsonParser.toTypedObject(value, type);
    }

    @Nullable
    private Object callMethod(Object[] methodArguments) {
        if (isObjectNotPublic() || isMethodNotPublic()) {
            toolMethod.setAccessible(true);
        }
        Object result;
        try {
            result = toolMethod.invoke(toolObject, methodArguments);
        }
        catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        }
        catch (InvocationTargetException ex) {
            throw new ToolExecutionException(toolDefinition, ex.getCause());
        }
        return result;
    }

    private boolean isObjectNotPublic() {
        return toolObject != null && !Modifier.isPublic(toolObject.getClass().getModifiers());
    }

    private boolean isMethodNotPublic() {
        return !Modifier.isPublic(toolMethod.getModifiers());
    }

    @Override
    public String toString() {
        return "MethodToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {

        private ToolDefinition toolDefinition;
        private ToolMetadata toolMetadata;
        private Method toolMethod;
        private Object toolObject;
        private ToolCallResultConverter toolCallResultConverter;

        private Builder() {}

        public Builder toolDefinition(ToolDefinition toolDefinition) { this.toolDefinition = toolDefinition; return this; }

        public Builder toolMetadata(ToolMetadata toolMetadata) { this.toolMetadata = toolMetadata; return this; }

        public Builder toolMethod(Method toolMethod) { this.toolMethod = toolMethod; return this; }

        public Builder toolObject(Object toolObject) { this.toolObject = toolObject; return this; }

        public Builder toolCallResultConverter(ToolCallResultConverter toolCallResultConverter) {
            this.toolCallResultConverter = toolCallResultConverter;
            return this;
        }

        public MethodToolCallback build() {
            return new MethodToolCallback(toolDefinition, toolMetadata, toolMethod, toolObject,
                    toolCallResultConverter);
        }
    }
}
```
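The three steps of `MethodToolCallback.call(...)` can be reproduced with the JDK alone. This is a simplified, hypothetical `ReflectiveToolInvoker`: it assumes the JSON input is already parsed into a `Map`, does no type conversion, and matches arguments by parameter name. Real parameter names are only available when compiled with `-parameters`; otherwise the JDK reports `arg0`, `arg1`, and so on.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Map;

// Hypothetical sample tool class for the sketch
class SampleMathTools {
    public int add(int a, int b) {
        return a + b;
    }
}

class ReflectiveToolInvoker {

    static String call(Object target, Method method, Map<String, Object> toolArguments) {
        // 1. match arguments to parameters by name (the real code parses JSON and converts types first)
        Parameter[] parameters = method.getParameters();
        Object[] methodArguments = new Object[parameters.length];
        for (int i = 0; i < parameters.length; i++) {
            methodArguments[i] = toolArguments.get(parameters[i].getName());
        }
        // 2. invoke the method reflectively
        Object result;
        try {
            result = method.invoke(target, methodArguments);
        } catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        } catch (InvocationTargetException ex) {
            throw new IllegalStateException("Tool failed", ex.getCause());
        }
        // 3. turn the result into the string sent back to the model ("Done" for void methods)
        return method.getReturnType() == void.class ? "Done" : String.valueOf(result);
    }
}
```

Matching by name rather than position is what lets the model send arguments in any order, which is why the JSON keys must equal the Java parameter names.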
In FunctionToolCallback, the core method is also call:
The toolInput extracted by the model is a JSON string; it is first converted into the declared request type

```java
I request = JsonParser.fromJson(toolInput, toolInputType);
```

Call the tool and get the corresponding result

```java
O response = toolFunction.apply(request, toolContext);
```

```java
public class FunctionToolCallback<I, O> implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class);

    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;
    private final ToolMetadata toolMetadata;
    private final Type toolInputType;
    private final BiFunction<I, ToolContext, O> toolFunction;
    private final ToolCallResultConverter toolCallResultConverter;

    public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType,
            BiFunction<I, ToolContext, O> toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolInputType, "toolInputType cannot be null");
        Assert.notNull(toolFunction, "toolFunction cannot be null");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolFunction = toolFunction;
        this.toolInputType = toolInputType;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
                : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");
        logger.debug("Starting execution of tool: {}", toolDefinition.name());
        I request = JsonParser.fromJson(toolInput, toolInputType);
        O response = toolFunction.apply(request, toolContext);
        logger.debug("Successful execution of tool: {}", toolDefinition.name());
        return toolCallResultConverter.convert(response, null);
    }

    @Override
    public String toString() {
        return "FunctionToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
    }

    public static <I, O> Builder<I, O> builder(String name, BiFunction<I, ToolContext, O> function) {
        return new Builder<>(name, function);
    }

    public static <I, O> Builder<I, O> builder(String name, Function<I, O> function) {
        Assert.notNull(function, "function cannot be null");
        return new Builder<>(name, (request, context) -> function.apply(request));
    }

    public static <O> Builder<Void, O> builder(String name, Supplier<O> supplier) {
        Assert.notNull(supplier, "supplier cannot be null");
        Function<Void, O> function = input -> supplier.get();
        return builder(name, function).inputType(Void.class);
    }

    public static <I> Builder<I, Void> builder(String name, Consumer<I> consumer) {
        Assert.notNull(consumer, "consumer cannot be null");
        Function<I, Void> function = (I input) -> {
            consumer.accept(input);
            return null;
        };
        return builder(name, function);
    }

    public static class Builder<I, O> {

        private String name;
        private String description;
        private String inputSchema;
        private Type inputType;
        private ToolMetadata toolMetadata;
        private BiFunction<I, ToolContext, O> toolFunction;
        private ToolCallResultConverter toolCallResultConverter;

        private Builder(String name, BiFunction<I, ToolContext, O> toolFunction) {
            Assert.hasText(name, "name cannot be null or empty");
            Assert.notNull(toolFunction, "toolFunction cannot be null");
            this.name = name;
            this.toolFunction = toolFunction;
        }

        public Builder<I, O> description(String description) { this.description = description; return this; }

        public Builder<I, O> inputSchema(String inputSchema) { this.inputSchema = inputSchema; return this; }

        public Builder<I, O> inputType(Type inputType) { this.inputType = inputType; return this; }

        public Builder<I, O> inputType(ParameterizedTypeReference<?> inputType) {
            Assert.notNull(inputType, "inputType cannot be null");
            this.inputType = inputType.getType();
            return this;
        }

        public Builder<I, O> toolMetadata(ToolMetadata toolMetadata) { this.toolMetadata = toolMetadata; return this; }

        public Builder<I, O> toolCallResultConverter(ToolCallResultConverter toolCallResultConverter) {
            this.toolCallResultConverter = toolCallResultConverter;
            return this;
        }

        public FunctionToolCallback<I, O> build() {
            Assert.notNull(inputType, "inputType cannot be null");
            var toolDefinition = ToolDefinition.builder()
                    .name(name)
                    .description(StringUtils.hasText(description) ? description
                            : ToolUtils.getToolDescriptionFromName(name))
                    .inputSchema(StringUtils.hasText(inputSchema) ? inputSchema
                            : JsonSchemaGenerator.generateForType(inputType))
                    .build();
            return new FunctionToolCallback<>(toolDefinition, toolMetadata, inputType, toolFunction,
                    toolCallResultConverter);
        }
    }
}
```
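The `builder(...)` overloads above share one idea: every `Function`, `Supplier`, or `Consumer` is normalized into a single `BiFunction<I, Context, O>` shape so that `call(...)` has one code path. The sketch below isolates that idea; `FunctionAdapters` and `Ctx` are hypothetical (Ctx stands in for ToolContext), and the JSON-to-request step is omitted.

```java
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

// Adapter sketch: normalize functional shapes into BiFunction<I, Ctx, O>.
class FunctionAdapters {

    record Ctx(String conversationId) {}

    static <I, O> BiFunction<I, Ctx, O> fromFunction(Function<I, O> function) {
        return (request, ctx) -> function.apply(request);   // the context is ignored
    }

    static <O> BiFunction<Void, Ctx, O> fromSupplier(Supplier<O> supplier) {
        return (ignored, ctx) -> supplier.get();            // no input at all
    }

    static <I> BiFunction<I, Ctx, Void> fromConsumer(Consumer<I> consumer) {
        return (request, ctx) -> {
            consumer.accept(request);
            return null;                                    // no output; a result converter can map this to "Done"
        };
    }
}
```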
ToolCallbackProvider: mainly used to manage and supply tool callbacks centrally
getToolCallbacks: returns the array of tool callbacks
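```java
public interface ToolCallbackProvider {

    // FunctionCallback[] so that StaticToolCallbackProvider below compiles;
    // MethodToolCallbackProvider may still return ToolCallback[] (covariant array return)
    FunctionCallback[] getToolCallbacks();

    static ToolCallbackProvider from(List<? extends FunctionCallback> toolCallbacks) {
        return new StaticToolCallbackProvider(toolCallbacks);
    }

    static ToolCallbackProvider from(FunctionCallback... toolCallbacks) {
        return new StaticToolCallbackProvider(toolCallbacks);
    }
}
```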
MethodToolCallbackProvider: builds MethodToolCallback instances
```java
public class MethodToolCallbackProvider implements ToolCallbackProvider {

    private static final Logger logger = LoggerFactory.getLogger(MethodToolCallbackProvider.class);

    private final List<Object> toolObjects;

    private MethodToolCallbackProvider(List<Object> toolObjects) {
        Assert.notNull(toolObjects, "toolObjects cannot be null");
        Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
        this.toolObjects = toolObjects;
    }

    @Override
    public ToolCallback[] getToolCallbacks() {
        var toolCallbacks = toolObjects.stream()
                .map(toolObject -> Stream
                        .of(ReflectionUtils.getDeclaredMethods(
                                AopUtils.isAopProxy(toolObject) ? AopUtils.getTargetClass(toolObject)
                                        : toolObject.getClass()))
                        .filter(toolMethod -> toolMethod.isAnnotationPresent(Tool.class))
                        .filter(toolMethod -> !isFunctionalType(toolMethod))
                        .map(toolMethod -> MethodToolCallback.builder()
                                .toolDefinition(ToolDefinition.from(toolMethod))
                                .toolMetadata(ToolMetadata.from(toolMethod))
                                .toolMethod(toolMethod)
                                .toolObject(toolObject)
                                .toolCallResultConverter(ToolUtils.getToolCallResultConverter(toolMethod))
                                .build())
                        .toArray(ToolCallback[]::new))
                .flatMap(Stream::of)
                .toArray(ToolCallback[]::new);
        validateToolCallbacks(toolCallbacks);
        return toolCallbacks;
    }

    private boolean isFunctionalType(Method toolMethod) {
        var isFunction = ClassUtils.isAssignable(toolMethod.getReturnType(), Function.class)
                || ClassUtils.isAssignable(toolMethod.getReturnType(), Supplier.class)
                || ClassUtils.isAssignable(toolMethod.getReturnType(), Consumer.class);
        if (isFunction) {
            logger.warn("Method {} is annotated with @Tool but returns a functional type. "
                    + "This is not supported and the method will be ignored.", toolMethod.getName());
        }
        return isFunction;
    }

    private void validateToolCallbacks(ToolCallback[] toolCallbacks) {
        List<String> duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks);
        if (!duplicateToolNames.isEmpty()) {
            throw new IllegalStateException("Multiple tools with the same name (%s) found in sources: %s".formatted(
                    String.join(", ", duplicateToolNames),
                    toolObjects.stream().map(o -> o.getClass().getName()).collect(Collectors.joining(", "))));
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {

        private List<Object> toolObjects;

        private Builder() {}

        public Builder toolObjects(Object... toolObjects) {
            Assert.notNull(toolObjects, "toolObjects cannot be null");
            this.toolObjects = Arrays.asList(toolObjects);
            return this;
        }

        public MethodToolCallbackProvider build() {
            return new MethodToolCallbackProvider(toolObjects);
        }
    }
}
```
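The scanning logic above (find annotated methods, resolve a name, reject duplicates) can be exercised without Spring. `@SketchTool` is a hypothetical stand-in for `@Tool`, and `ToolScanner` only returns names rather than building callbacks; methods are sorted because the JDK does not guarantee the order of `getDeclaredMethods()`.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical stand-in for Spring AI's @Tool so the sketch runs without Spring
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface SketchTool {
    String name() default "";
}

class ToolScanner {

    static List<String> scanToolNames(Object... toolObjects) {
        List<String> names = new ArrayList<>();
        for (Object toolObject : toolObjects) {
            Arrays.stream(toolObject.getClass().getDeclaredMethods())
                    .filter(m -> m.isAnnotationPresent(SketchTool.class))
                    .sorted(Comparator.comparing(Method::getName))   // declared-method order is unspecified
                    .map(ToolScanner::toolName)
                    .forEach(names::add);
        }
        List<String> duplicates = names.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
                .entrySet().stream()
                .filter(e -> e.getValue() > 1)
                .map(Map.Entry::getKey)
                .toList();
        if (!duplicates.isEmpty()) {
            throw new IllegalStateException("Multiple tools with the same name: " + duplicates);
        }
        return names;
    }

    // annotation name if present, otherwise the method name
    private static String toolName(Method method) {
        String declared = method.getAnnotation(SketchTool.class).name();
        return declared.isEmpty() ? method.getName() : declared;
    }
}

class WeatherTools {
    @SketchTool(name = "currentWeather") public String weather(String city) { return "sunny"; }
    @SketchTool public String forecast(String city) { return "rain"; }
    public String helper() { return "not a tool"; }
}
```

Registering the same tool object twice triggers the duplicate-name check, matching the `IllegalStateException` thrown by `validateToolCallbacks` above.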
StaticToolCallbackProvider: meant to supply FunctionToolCallback instances, though at a glance this does not yet appear to be fully implemented.

```java
public class StaticToolCallbackProvider implements ToolCallbackProvider {

    private final FunctionCallback[] toolCallbacks;

    public StaticToolCallbackProvider(FunctionCallback... toolCallbacks) {
        Assert.notNull(toolCallbacks, "ToolCallbacks must not be null");
        this.toolCallbacks = toolCallbacks;
    }

    public StaticToolCallbackProvider(List<? extends FunctionCallback> toolCallbacks) {
        Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
        this.toolCallbacks = toolCallbacks.toArray(new FunctionCallback[0]);
    }

    @Override
    public FunctionCallback[] getToolCallbacks() {
        return this.toolCallbacks;
    }
}
```
```java
public interface ToolCallingManager {

    List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions);

    ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse);

    static DefaultToolCallingManager.Builder builder() {
        return DefaultToolCallingManager.builder();
    }
}
```

Its core responsibilities:
Resolve tool definitions (resolveToolDefinitions): resolves tool definitions from ToolCallingChatOptions so that the model can correctly recognize and use the tools
Execute tool calls (executeToolCalls): based on the model's response, runs the requested tool calls and returns their results
Build the tool context (buildToolContext): provides context for the tool call, including the history of Message records
Manage tool callbacks: resolves tool callbacks through ToolCallbackResolver, supporting dynamic tool calls

```java
public class DefaultToolCallingManager implements ToolCallingManager {

    // excerpt: the toolCallbackResolver field and the private helpers
    // (buildToolContext, executeToolCall, buildConversationHistoryAfterToolExecution) are omitted

    @Override
    public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
        Assert.notNull(chatOptions, "chatOptions cannot be null");

        List<FunctionCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
        for (String toolName : chatOptions.getToolNames()) {
            // skip names already provided as explicit callbacks
            if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) {
                continue;
            }
            // otherwise resolve the name through the ToolCallbackResolver
            FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }
            toolCallbacks.add(toolCallback);
        }

        return toolCallbacks.stream().map(functionCallback -> {
            if (functionCallback instanceof ToolCallback toolCallback) {
                return toolCallback.getToolDefinition();
            }
            else {
                return ToolDefinition.builder()
                        .name(functionCallback.getName())
                        .description(functionCallback.getDescription())
                        .inputSchema(functionCallback.getInputTypeSchema())
                        .build();
            }
        }).toList();
    }

    @Override
    public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
        Assert.notNull(prompt, "prompt cannot be null");
        Assert.notNull(chatResponse, "chatResponse cannot be null");

        // find the first generation whose output requests tool calls
        Optional<Generation> toolCallGeneration = chatResponse.getResults()
                .stream()
                .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
                .findFirst();
        if (toolCallGeneration.isEmpty()) {
            throw new IllegalStateException("No tool call requested by the chat model");
        }

        AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
        ToolContext toolContext = buildToolContext(prompt, assistantMessage);
        InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage,
                toolContext);
        List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
                assistantMessage, internalToolExecutionResult.toolResponseMessage());

        return ToolExecutionResult.builder()
                .conversationHistory(conversationHistory)
                .returnDirect(internalToolExecutionResult.returnDirect())
                .build();
    }
}
```
```java
@FunctionalInterface
public interface ToolCallResultConverter {

    String convert(@Nullable Object result, @Nullable Type returnType);
}
```

Currently the only implementation of the ToolCallResultConverter interface; it converts the result into a JSON string.

```java
public final class DefaultToolCallResultConverter implements ToolCallResultConverter {

    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);

    @Override
    public String convert(@Nullable Object result, @Nullable Type returnType) {
        if (returnType == Void.TYPE) {
            logger.debug("The tool has no return type. Converting to conventional response.");
            return "Done";
        }
        else {
            logger.debug("Converting tool result to JSON.");
            return JsonParser.toJson(result);
        }
    }
}
```
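To see the converter contract in isolation, here is a framework-free sketch that mirrors the logic above: `"Done"` for void tools, JSON otherwise. `ResultConverter` and `SketchResultConverter` are hypothetical names, and the hand-rolled serializer only covers null, String, Number, Boolean, Map, and List; the real class delegates to its JSON library.

```java
import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@FunctionalInterface
interface ResultConverter {
    String convert(Object result, Type returnType);
}

class SketchResultConverter implements ResultConverter {

    @Override
    public String convert(Object result, Type returnType) {
        if (returnType == Void.TYPE) {
            return "Done";                  // void tools still need a textual reply for the model
        }
        return toJson(result);
    }

    static String toJson(Object value) {
        if (value == null) return "null";
        if (value instanceof String s) return quote(s);
        if (value instanceof Number || value instanceof Boolean) return value.toString();
        if (value instanceof Map<?, ?> map) {
            return map.entrySet().stream()
                    .map(e -> quote(String.valueOf(e.getKey())) + ":" + toJson(e.getValue()))
                    .collect(Collectors.joining(",", "{", "}"));
        }
        if (value instanceof List<?> list) {
            return list.stream().map(SketchResultConverter::toJson).collect(Collectors.joining(",", "[", "]"));
        }
        return quote(value.toString());    // fallback for unsupported types
    }

    private static String quote(String s) {
        return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
    }
}
```

A custom converter plugged in via `@Tool(resultConverter = ...)` follows the same contract: whatever string it returns is what the model sees.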
ToolContext (tool context) is built by the tool calling manager
Purpose:
Encapsulates the context information for tool execution. The context map is wrapped with Collections.unmodifiableMap, so tools cannot modify it, which helps with thread safety
getContext returns the whole context; getToolCallHistory returns the Message history

```java
public class ToolContext {

    public static final String TOOL_CALL_HISTORY = "TOOL_CALL_HISTORY";

    private final Map<String, Object> context;

    public ToolContext(Map<String, Object> context) {
        this.context = Collections.unmodifiableMap(context);
    }

    public Map<String, Object> getContext() {
        return this.context;
    }

    @SuppressWarnings("unchecked")
    public List<Message> getToolCallHistory() {
        return (List<Message>) this.context.get(TOOL_CALL_HISTORY);
    }
}
```
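One nuance of `Collections.unmodifiableMap` is worth a quick demonstration: it returns a read-only view, so callers cannot write through it, but later changes to the original map remain visible. `Map.copyOf` takes an immutable snapshot instead (it rejects null keys and values). `ReadOnlyContextDemo` is a hypothetical helper used only to contrast the two.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class ReadOnlyContextDemo {

    // Read-only view: writes through it throw, but changes to `source` show through
    static Map<String, Object> view(Map<String, Object> source) {
        return Collections.unmodifiableMap(source);
    }

    // Immutable snapshot: later changes to `source` are not visible
    static Map<String, Object> snapshot(Map<String, Object> source) {
        return Map.copyOf(source);
    }
}
```

So the view protects against modification by the tool itself; full immutability additionally requires that nobody keeps mutating the original map.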
ToolUtils: extracts tool metadata from a method, mainly based on whether the method carries the @Tool annotation; if it does not (or the annotation leaves the name empty), the method name is used
getToolName: gets the tool name
getToolDescriptionFromName: generates a tool description from the tool name
getToolDescription: gets the tool description
getToolReturnDirect: checks whether the tool returns its result directly
getToolCallResultConverter: gets the tool's result converter
getDuplicateToolNames: checks the list of tool callbacks for duplicate tool names

```java
public final class ToolUtils {

    private ToolUtils() {}

    public static String getToolName(Method method) {
        Assert.notNull(method, "method cannot be null");
        var tool = method.getAnnotation(Tool.class);
        if (tool == null) {
            return method.getName();
        }
        return StringUtils.hasText(tool.name()) ? tool.name() : method.getName();
    }

    public static String getToolDescriptionFromName(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");
        return ParsingUtils.reConcatenateCamelCase(toolName, " ");
    }

    public static String getToolDescription(Method method) {
        Assert.notNull(method, "method cannot be null");
        var tool = method.getAnnotation(Tool.class);
        if (tool == null) {
            return ParsingUtils.reConcatenateCamelCase(method.getName(), " ");
        }
        return StringUtils.hasText(tool.description()) ? tool.description() : method.getName();
    }

    public static boolean getToolReturnDirect(Method method) {
        Assert.notNull(method, "method cannot be null");
        var tool = method.getAnnotation(Tool.class);
        return tool != null && tool.returnDirect();
    }

    public static ToolCallResultConverter getToolCallResultConverter(Method method) {
        Assert.notNull(method, "method cannot be null");
        var tool = method.getAnnotation(Tool.class);
        if (tool == null) {
            return new DefaultToolCallResultConverter();
        }
        var type = tool.resultConverter();
        try {
            return type.getDeclaredConstructor().newInstance();
        }
        catch (Exception e) {
            throw new IllegalArgumentException("Failed to instantiate ToolCallResultConverter: " + type, e);
        }
    }

    public static List<String> getDuplicateToolNames(List<FunctionCallback> toolCallbacks) {
        Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
        return toolCallbacks.stream()
                .collect(Collectors.groupingBy(FunctionCallback::getName, Collectors.counting()))
                .entrySet()
                .stream()
                .filter(entry -> entry.getValue() > 1)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    public static List<String> getDuplicateToolNames(FunctionCallback... toolCallbacks) {
        Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
        return getDuplicateToolNames(Arrays.asList(toolCallbacks));
    }
}
```
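The description fallback (`getToolDescriptionFromName` and `DefaultToolDefinition.Builder.build()`) turns a camelCase tool name into words when no description is given. Below is an independent re-implementation of that idea for illustration; `ToolNameUtils` is hypothetical, and the exact output format of `ParsingUtils.reConcatenateCamelCase` may differ from this sketch.

```java
// Hypothetical helper illustrating the "derive a description from the tool name" fallback.
class ToolNameUtils {

    // "addTwoNumbers" -> "add two numbers"
    static String camelCaseToWords(String name) {
        StringBuilder out = new StringBuilder();
        for (int i = 0; i < name.length(); i++) {
            char c = name.charAt(i);
            if (Character.isUpperCase(c) && i > 0) {
                out.append(' ');            // start a new word at each uppercase letter
            }
            out.append(Character.toLowerCase(c));
        }
        return out.toString();
    }

    // Use the explicit description when present, otherwise derive one from the name
    static String describe(String description, String toolName) {
        return (description == null || description.isBlank()) ? camelCaseToWords(toolName) : description;
    }
}
```

A derived description is only a last resort: the model chooses tools largely from their descriptions, so an explicit `description` on `@Tool` is usually worth writing.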