AI Agent Basic Concepts


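  • Perception module: gathers information from the external environment and accepts multimodal input (text, images, sensor data). GPT-4o, for example, can process vision and speech end to end.
  • Decision-making (Planning) module: uses an LLM/AIGC foundation model to reason, plan, and decide. It breaks complex goals down with techniques such as Chain-of-Thought (CoT) and Tree-of-Thoughts (ToT) and keeps refining its strategy based on feedback.
  • Action module: calls tools (APIs, databases, search engines, robotic actuators, etc.) to carry out the task.
  • Memory module: divided into short-term and long-term memory. Short-term memory relies on the model's context window (e.g., 128K tokens) to hold the data and knowledge needed for the current task. Long-term memory stores and retrieves past experience through a vector database plus RAG.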
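To make the division of labor concrete, here is a minimal sketch of one perceive, plan, act, remember cycle. The interfaces `Perception`, `Planner`, and `Action` and the class `SimpleAgent` are hypothetical names invented for illustration; a real agent would back `Planner` with an LLM call and `Action` with real tools, and long-term memory would live in a vector store rather than a list.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical module interfaces; not part of any framework.
interface Perception { String observe(); }
interface Planner { String plan(String observation, List<String> memory); }
interface Action { String act(String decision); }

class SimpleAgent {
    private final Perception perception;
    private final Planner planner;
    private final Action action;
    // Stand-in for short-term memory: a running list of past observations and results.
    private final List<String> memory = new ArrayList<>();

    SimpleAgent(Perception perception, Planner planner, Action action) {
        this.perception = perception;
        this.planner = planner;
        this.action = action;
    }

    /** Runs one perceive -> plan -> act -> remember cycle and returns the action result. */
    String step() {
        String observation = perception.observe();
        String decision = planner.plan(observation, List.copyOf(memory));
        String result = action.act(decision);
        memory.add(observation + " => " + result);
        return result;
    }

    List<String> memory() { return List.copyOf(memory); }
}
```

Keeping the modules behind separate interfaces means each one (for example, the planner's model) can be swapped without touching the loop.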

Core Message Types

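Mainstream AI Agent frameworks define three core message types with clearly separated functions: System Prompt (System Message), Assistant Prompt (Assistant Message), and User Prompt (User Message). Many frameworks, Spring AI included, add a fourth Tool type for tool results.

  • User Prompt: the question or input entered directly by the user.
  • Assistant Prompt: the reply generated by the model.
  • System Prompt: core configuration that sets the model's role and baseline instructions (such as identity and safety constraints).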

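The key difference between the System Prompt and the User Prompt lies in position and priority: the System Prompt is always placed at the very beginning of the input sequence. Because of how attention tends to behave (information at the start and end of a sequence usually receives more attention), content in this position is more readily recognized and followed by the model. A complete multi-turn prompt is therefore typically assembled in this order: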

System Prompt -> User Prompt -> Assistant Prompt -> User Prompt ... -> Assistant Prompt
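The assembly order can be sketched as a conversation object that pins the system message first and appends user and assistant turns after it. `Role`, `ChatMessage`, and `Conversation` are hypothetical names for illustration, not Spring AI types; on every turn the whole list is what gets sent to the model.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical minimal message model for the roles described in this section.
enum Role { SYSTEM, USER, ASSISTANT, TOOL }

record ChatMessage(Role role, String content) {}

class Conversation {
    private final List<ChatMessage> messages = new ArrayList<>();

    Conversation(String systemPrompt) {
        // The system message always comes first in the sequence.
        messages.add(new ChatMessage(Role.SYSTEM, systemPrompt));
    }

    void user(String text) { messages.add(new ChatMessage(Role.USER, text)); }

    void assistant(String text) { messages.add(new ChatMessage(Role.ASSISTANT, text)); }

    /** The full prompt for the next call: system message, then the alternating history. */
    List<ChatMessage> toPrompt() { return List.copyOf(messages); }
}
```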

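Core message types compared

| Type | Purpose | Priority | Visibility | Typical content |
|---|---|---|---|---|
| System | Sets the role, rules, and safety boundaries | Highest | Not visible to the user | Role, behavioral guidelines, safety policy, output format |
| User | Task or question input | - | Visible | User questions, requirements, data |
| Assistant | Model replies and intermediate reasoning | - | Visible | Answers, tool-call intents, reasoning (hidden or optional) |
| Tool | Tool execution results (consumed by the model) | - | Visible (usually structured) | Results from external systems, retrieval or computation output |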

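In this structure, Assistant messages replay the conversation history to the model and make explicit which turns the model produced and which came from the user. A model pre-trained and fine-tuned on data in this format understands that these turns are conversation history rather than fresh user input, which helps it track context and answer follow-up questions more accurately.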

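Why not simply merge the System Prompt into the User Prompt? A key consideration is security and controllability. Distinguishing message types during fine-tuning helps defend against attacks such as prompt injection. Specifically:

  1. Put the core role definition and rules in the System Prompt.
  2. Put user interaction content in the User Prompt.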

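This separation guards effectively against some simple prompt attacks and information-leakage risks. In real applications the System Prompt is usually invisible to users, and because models are trained to give the roles and rules it defines the highest priority, it significantly increases the likelihood that the model follows the developer's intent and reduces the risk that changes in user input push the output away from what was expected.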

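Of course, relying on the System Prompt alone cannot fully withstand attacks (GPT-4, for example, has had cases where its System Prompt was coaxed into leaking). Adding a second round of validation on user input or model output is therefore a more robust way to strengthen security.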
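A minimal sketch of such secondary validation, assuming a deliberately naive keyword filter on input and a verbatim-leak check on output. `PromptGuard` and its marker list are hypothetical; production systems typically combine classifiers, allow-lists, and moderation services rather than substring matching.

```java
import java.util.List;
import java.util.Locale;

// Hypothetical guard: naive keyword checks, for illustration only.
class PromptGuard {
    private static final List<String> INJECTION_MARKERS = List.of(
            "ignore previous instructions",
            "reveal your system prompt");

    private final String systemPrompt;

    PromptGuard(String systemPrompt) { this.systemPrompt = systemPrompt; }

    /** Pre-check: rejects user input containing a known injection phrase (case-insensitive). */
    boolean isInputAllowed(String userInput) {
        String normalized = userInput.toLowerCase(Locale.ROOT);
        return INJECTION_MARKERS.stream().noneMatch(normalized::contains);
    }

    /** Post-check: rejects model output that echoes the system prompt verbatim. */
    boolean isOutputAllowed(String modelOutput) {
        return !modelOutput.contains(systemPrompt);
    }
}
```

Checking both sides matters: the input check catches obvious attacks early, while the output check still catches a leak that slipped past it.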

Implementation in Spring AI

Spring AI Message API

public enum MessageType {
    USER("user"),
    ASSISTANT("assistant"),
    SYSTEM("system"),
    TOOL("tool");
    ...
}

PromptTemplate

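public class PromptTemplate implements PromptTemplateActions, PromptTemplateMessageActions {
    // Other methods to be discussed later
}

import java.util.List;
import java.util.Map;

import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;

String name = "Bob";     // example value for the {name} placeholder
String voice = "pirate"; // example value for the {voice} placeholder

String userText = """
    Tell me about three famous pirates from the Golden Age of Piracy and what they did.
    Write at least a sentence for each pirate.
    """;

Message userMessage = new UserMessage(userText);

String systemText = """
    You are a helpful AI assistant that helps people find information.
    Your name is {name}
    You should reply to the user's request with your name and also in the style of a {voice}.
    """;

SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemText);
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice));

// Put the system message first, matching the System -> User ordering described above
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));

// chatModel is an injected ChatModel bean
List<Generation> response = chatModel.call(prompt).getResults();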
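What `createMessage` does with `{name}` and `{voice}` can be approximated by a small placeholder renderer. `SimpleTemplate` is a hypothetical stand-in written for illustration; Spring AI ships its own template renderer, and this sketch does not reproduce its exact behavior. It fails fast on a missing variable, which is usually preferable to sending an unfilled `{voice}` to the model.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical renderer: replaces {key} placeholders with values from a map.
class SimpleTemplate {
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{(\\w+)\\}");

    static String render(String template, Map<String, Object> model) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuilder out = new StringBuilder();
        while (m.find()) {
            String key = m.group(1);
            if (!model.containsKey(key)) {
                throw new IllegalArgumentException("Missing value for placeholder: " + key);
            }
            // quoteReplacement stops '$' or '\' in values from being read as group references.
            m.appendReplacement(out, Matcher.quoteReplacement(String.valueOf(model.get(key))));
        }
        m.appendTail(out);
        return out.toString();
    }
}
```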

AI Agent Design Patterns


Reflection pattern


ReflectAgent reflectAgent = ReflectAgent.builder()
        .graph(assistantGraphNode)   // generation node
        .reflection(judgeGraphNode)  // judge (critique) node
        .maxIterations(3)
        .build();

public StateGraph createReflectionGraph(NodeAction graph, NodeAction reflection, int maxIterations)
        throws GraphStateException {
    StateGraph stateGraph = new StateGraph(() -> {
        HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
        keyStrategyHashMap.put(MESSAGES, new ReplaceStrategy());
        keyStrategyHashMap.put(ITERATION_NUM, new ReplaceStrategy());
        return keyStrategyHashMap;
    })
            .addNode(GRAPH_NODE_ID, node_async(graph))
            .addNode(REFLECTION_NODE_ID, node_async(reflection))
            .addEdge(START, GRAPH_NODE_ID)
            .addConditionalEdges(GRAPH_NODE_ID, edge_async(this::graphCount),
                    Map.of(REFLECTION_NODE_ID, REFLECTION_NODE_ID, END, END))
            .addConditionalEdges(REFLECTION_NODE_ID, edge_async(this::apply),
                    Map.of(GRAPH_NODE_ID, GRAPH_NODE_ID, END, END));

    return stateGraph;
}

// Iteration check: count each generation pass and stop once the count reaches maxIterations
private String graphCount(OverAllState state) {
    int iterationNum = state.value(ITERATION_NUM, Integer.class).orElse(0);
    state.updateState(Map.of(ITERATION_NUM, iterationNum + 1));

    return iterationNum >= maxIterations ? END : REFLECTION_NODE_ID;
}

// Message-type check: a trailing UserMessage (the judge's feedback) routes back to generation
private String apply(OverAllState state) {
    List<Message> messages = state.value(MESSAGES, List.class).orElse(new ArrayList<>());
    if (messages.isEmpty()) return END;

    Message lastMessage = messages.get(messages.size() - 1);
    return lastMessage instanceof UserMessage ? GRAPH_NODE_ID : END;
}
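Stripped of the graph framework, the reflection pattern is a bounded generate, critique, revise loop. `ReflectionLoop` and its `Review` record are hypothetical names for illustration; the `generate`, `revise`, and `judge` functions stand in for the generation and judge nodes, which would normally be LLM calls.

```java
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical generate -> critique loop; no graph framework involved.
class ReflectionLoop {
    /** The judge's verdict; feedback is only meaningful when accepted is false. */
    record Review(boolean accepted, String feedback) {}

    static String run(Function<String, String> generate,
                      BiFunction<String, String, String> revise,
                      Function<String, Review> judge,
                      String task,
                      int maxIterations) {
        String draft = generate.apply(task);
        for (int i = 0; i < maxIterations; i++) {
            Review review = judge.apply(draft);
            if (review.accepted()) {
                return draft; // the judge is satisfied: stop early
            }
            // Feed the critique back, like the judge node appending a UserMessage
            // that routes control back to the generation node.
            draft = revise.apply(draft, review.feedback());
        }
        return draft; // iteration budget exhausted: return the latest draft
    }
}
```

The iteration cap is what keeps the cost bounded; without it, a judge that never accepts would loop forever.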

Tool use pattern


ReAct pattern


ReactAgent reactAgent = new ReactAgent(
        "weatherAgent",
        chatClient,
        toolCallbacks,
        10 // maximum number of iterations
);

// Compile the graph, then use it
CompiledGraph compiledGraph = reactAgent.getAndCompileGraph();
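The loop that ReactAgent runs (model step, optional tool step, back to the model) can be sketched without any framework. `ReactLoop` and `ModelTurn` are hypothetical names; the `model` function stands in for an LLM that either requests a tool or returns a final answer, and the transcript plays the role of the message list in the graph state.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical ReAct loop: each model turn is either a tool call or a final answer.
class ReactLoop {
    /** If toolName is null, answer holds the final reply. */
    record ModelTurn(String toolName, String toolInput, String answer) {}

    static String run(Function<List<String>, ModelTurn> model,
                      Map<String, Function<String, String>> tools,
                      String question,
                      int maxIterations) {
        List<String> transcript = new ArrayList<>(List.of("user: " + question));
        for (int i = 0; i < maxIterations; i++) {
            ModelTurn turn = model.apply(List.copyOf(transcript)); // "llm" node
            if (turn.toolName() == null) {
                return turn.answer();                              // no tool call -> END
            }
            Function<String, String> tool = tools.get(turn.toolName());
            if (tool == null) {
                throw new IllegalStateException("Unknown tool: " + turn.toolName());
            }
            String observation = tool.apply(turn.toolInput());     // "tool" node
            transcript.add("tool " + turn.toolName() + ": " + observation); // tool -> llm edge
        }
        return "Stopped after " + maxIterations + " iterations without a final answer.";
    }
}
```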

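// ReactAgent internals: graph construction and routing
private StateGraph initGraph() throws GraphStateException {
    StateGraph graph = new StateGraph(name, this.keyStrategyFactory);

    // Add the core nodes
    graph.addNode("llm", node_async(this.llmNode));
    graph.addNode("tool", node_async(this.toolNode));

    // Wire the execution flow: START -> llm -> (tool -> llm)* -> END
    graph.addEdge(START, "llm")
            .addConditionalEdges("llm", edge_async(this::think),
                    Map.of("continue", "tool", "end", END))
            .addEdge("tool", "llm");

    return graph;
}

// Routing decision after each llm step
private String think(OverAllState state) {
    // Count this step and stop once the iteration budget is used up
    if (++iterations > max_iterations) {
        return "end";
    }

    List<Message> messages = (List<Message>) state.value("messages").orElseThrow();
    AssistantMessage message = (AssistantMessage) messages.get(messages.size() - 1);

    // Go to the tool node only if the model requested a tool call
    return message.hasToolCalls() ? "continue" : "end";
}

When to use / not use each pattern (quick guide)

| Pattern | When to use | When not to use |
|---|---|---|
| Reflection | Answers need self-review or rewriting; output quality matters | Strongly latency-sensitive scenarios, or when cost is a major concern |
| ReAct | Reasoning and tool calls must alternate: information retrieval, function orchestration | Few tools are involved, or a single call finishes the job |
| Planning | Medium- to long-horizon planning; breaking work into subtasks | The task is simple and can be done in one step |
| Multi-agent | Role specialization and collaboration are needed | One agent can do the job; coordination overhead is not worth it |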

Planning pattern


Multi-agent pattern


Agent Memory

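Memory is the mechanism that lets an AI Agent store, retrieve, and use information gained from past interactions. It reaches beyond the boundary of a single conversation or task and is the core component that gives an agent continuity and personalization.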

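  • Remembering user preferences: if a user said in an earlier conversation that they like science fiction, the agent can favor that genre the next time it recommends books.
  • Maintaining conversational context: across multiple turns, the agent can resolve references such as "he", "she", "it", or "that one".
  • Learning from history: summarizing past successes and failures to improve future decisions and actions.
  • Accumulating knowledge: storing newly acquired information in a knowledge base for long-term use.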

Agents currently add memory in three forms: short-term memory, long-term memory, and tool memory.

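  1. Short-term / Conversation Memory: maintains the context of the current conversation or task. It is usually bounded by a window (only the most recent N turns are kept), which is determined by the context window of the underlying LLM.
  2. Long-term Memory: stores and recalls important information over long time spans. Its capacity is far larger than short-term memory, so relevant information has to be retrieved rather than kept in context. Long-term memory is usually paired with a Vector Database: information is converted into vector Embeddings, stored, and searched by similarity.
  3. Tool Memory: records the history and results of the agent's tool calls (APIs, functions). This is essential for agents that carry out complex multi-step tasks, because they can review earlier tool results to decide on their next action.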

At its core, a memory mechanism reduces to two basic operations: Read and Write.

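  1. Write: after the agent completes an action (replying to the user, calling a tool), valuable interaction data (user input, model output, tool results, self-generated summaries) is persisted to storage (in memory, a database, files, etc.).
  2. Read (Retrieve): before the agent handles new input or starts a new task, it looks up the memory fragments most relevant to the current query (e.g., the user's new question). For long-term memory this is usually vector similarity search: the query is also converted into a vector, and the memory vectors most similar to it are fetched from the vector database.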
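Both operations can be sketched with an in-memory store and cosine similarity. `VectorMemory` is a hypothetical class, and the hand-written two-dimensional vectors stand in for embeddings that a real system would get from an embedding model and keep in a vector database. Zero vectors are not handled (cosine would be NaN).

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical in-memory store illustrating write and similarity-based read.
class VectorMemory {
    record Entry(String text, double[] vector) {}

    private final List<Entry> entries = new ArrayList<>();

    /** Write: persist a piece of interaction together with its embedding. */
    void write(String text, double[] embedding) {
        entries.add(new Entry(text, embedding.clone()));
    }

    /** Read: return the texts of the k entries most similar to the query embedding. */
    List<String> read(double[] queryEmbedding, int k) {
        return entries.stream()
                .sorted(Comparator.comparingDouble((Entry e) -> cosine(e.vector(), queryEmbedding)).reversed())
                .limit(k)
                .map(Entry::text)
                .toList();
    }

    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }
}
```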

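Ways to merge memory into the prompt (in practice)

  • Short-term memory: injected as a collection of messages (Message objects); suited to keeping multi-turn context.
  • Prompt concatenation: a summary of the history is appended to the System Prompt; suited to "stable long rules plus short context".
  • Long-term memory (vector store): relevant fragments are recalled and concatenated into one long text; suited to knowledge-base or document Q&A.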

Examples

MessageWindowChatMemory

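Spring AI provides several built-in Advisors for configuring ChatClient's memory behavior:

  • MessageChatMemoryAdvisor: manages conversation memory with the provided ChatMemory. On each interaction it retrieves the conversation history from memory and includes it in the prompt as a collection of messages.
  • PromptChatMemoryAdvisor: manages conversation memory with the provided ChatMemory. On each interaction it retrieves the conversation history from memory and appends it to the system prompt as plain text.
  • VectorStoreChatMemoryAdvisor: manages conversation memory with the provided VectorStore. On each interaction it retrieves the conversation history from the vector store and appends it to the system message as plain text.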

ChatMemory chatMemory = MessageWindowChatMemory.builder().build();

ChatClient chatClient = ChatClient.builder(chatModel)
        .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
        .build();

When ChatClient is called, MessageChatMemoryAdvisor manages the memory automatically, retrieving the conversation history from memory by the specified conversation ID:

String conversationId = "007";
String response = chatClient.prompt()
        .user("Do I have license to code?")
        .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
        .call()
        .content();

PromptChatMemoryAdvisor

Override the default behavior by supplying a custom PromptTemplate through the .promptTemplate() method:

PromptTemplate customTemplate = new PromptTemplate("""
        System instructions: {instructions}
        Conversation memory: {memory}
        """);

PromptChatMemoryAdvisor advisor = PromptChatMemoryAdvisor.builder(chatMemory)
        .promptTemplate(customTemplate)
        .build();

Note: the template must contain both the {instructions} and {memory} placeholders.
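Appending history as plain text can be sketched as rendering each stored turn as a `ROLE: text` line and substituting it into the template, with the same placeholder requirement enforced up front. `PromptMemoryRenderer` and its `Turn` record are hypothetical and only approximate what PromptChatMemoryAdvisor does; the exact text Spring AI produces may differ.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical helper: renders history as plain text into a system-prompt template.
class PromptMemoryRenderer {
    record Turn(String role, String text) {}

    static String render(String template, String instructions, List<Turn> history) {
        if (!template.contains("{instructions}") || !template.contains("{memory}")) {
            throw new IllegalArgumentException("Template must contain {instructions} and {memory}");
        }
        // One "ROLE: text" line per stored turn, oldest first.
        String memory = history.stream()
                .map(t -> t.role() + ": " + t.text())
                .collect(Collectors.joining("\n"));
        return template.replace("{instructions}", instructions).replace("{memory}", memory);
    }
}
```

Compared with injecting Message objects, this flattens the history into the system prompt, which keeps the message list short but loses the explicit role boundaries the model was trained on.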

VectorStoreChatMemoryAdvisor

Custom template

Similarly, customize how vector-store memory is merged via .promptTemplate():

PromptTemplate customTemplate = new PromptTemplate("""
        System instructions: {instructions}
        Long-term memory: {long_term_memory}
        """);

VectorStoreChatMemoryAdvisor advisor = VectorStoreChatMemoryAdvisor.builder(vectorStore)
        .promptTemplate(customTemplate)
        .build();

Note: the template must contain both the {instructions} and {long_term_memory} placeholders.

Managing memory explicitly with ChatModel

// Create a memory instance
ChatMemory chatMemory = MessageWindowChatMemory.builder().build();
String conversationId = "007";

// First interaction
UserMessage userMessage1 = new UserMessage("My name is James Bond");
chatMemory.add(conversationId, userMessage1);
ChatResponse response1 = chatModel.call(new Prompt(chatMemory.get(conversationId)));
chatMemory.add(conversationId, response1.getResult().getOutput());

// Second interaction
UserMessage userMessage2 = new UserMessage("What is my name?");
chatMemory.add(conversationId, userMessage2);
ChatResponse response2 = chatModel.call(new Prompt(chatMemory.get(conversationId)));
chatMemory.add(conversationId, response2.getResult().getOutput());

// The response will contain "James Bond"
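The "window" in MessageWindowChatMemory means only the most recent messages per conversation are kept. A minimal sketch of that idea, assuming plain strings instead of Message objects; `WindowMemory` is a hypothetical class, not Spring AI's implementation, which has its own eviction rules and defaults.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sliding-window memory keyed by conversation id.
class WindowMemory {
    private final int maxMessages;
    private final Map<String, Deque<String>> store = new HashMap<>();

    WindowMemory(int maxMessages) {
        if (maxMessages <= 0) throw new IllegalArgumentException("maxMessages must be positive");
        this.maxMessages = maxMessages;
    }

    void add(String conversationId, String message) {
        Deque<String> window = store.computeIfAbsent(conversationId, id -> new ArrayDeque<>());
        window.addLast(message);
        while (window.size() > maxMessages) {
            window.removeFirst(); // evict the oldest message once the window is full
        }
    }

    List<String> get(String conversationId) {
        return List.copyOf(store.getOrDefault(conversationId, new ArrayDeque<>()));
    }
}
```

With a window of 3, the fourth message evicts "My name is James Bond", so a later "What is my name?" could no longer be answered from memory: the window size trades context against token cost.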

MCP


Spring AI implementation


Call flow

Connection initialization


Tool invocation


Examples

Tool Calling


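  1. Include the tool definitions (name, description, input schema) in the chat request.
  2. When the model decides to call a tool, it responds with the tool name and input arguments (extracted from the text by the model and shaped to match the input schema).
  3. The application uses the tool name to locate the tool and invokes it with the supplied arguments.
  4. The tool computes its result and returns it to the application.
  5. The application sends the result back to the model.
  6. The model uses the tool result as additional context to generate the final response.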
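The application side of steps 1, 3, and 4 amounts to a registry keyed by tool name. This sketch is framework-free and every name in it (`ToolCallingRoundTrip`, `ToolSpec`, `ToolCall`) is hypothetical; the model's step-2 request is simulated with a hand-built `ToolCall`, and steps 5 and 6 would append the returned string to the conversation and call the model again.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, framework-free sketch of the application side of tool calling.
class ToolCallingRoundTrip {
    record ToolSpec(String name, String description, Function<Map<String, Object>, Object> impl) {}
    record ToolCall(String name, Map<String, Object> arguments) {}

    private final Map<String, ToolSpec> registry = new LinkedHashMap<>();

    void register(ToolSpec spec) { registry.put(spec.name(), spec); }

    /** Step 1: the definitions (name + description) that accompany the chat request. */
    List<String> definitions() {
        return registry.values().stream().map(s -> s.name() + ": " + s.description()).toList();
    }

    /** Steps 3-4: look up the requested tool by name and run it with the model-supplied arguments. */
    String execute(ToolCall call) {
        ToolSpec spec = registry.get(call.name());
        if (spec == null) {
            throw new IllegalArgumentException("No tool named " + call.name());
        }
        return String.valueOf(spec.impl().apply(call.arguments()));
    }
}
```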

Tool-calling flow


public class DefaultChatClient implements ChatClient {

    // ...

    // The tools(...) overloads live on the nested request spec, so "return this" is a ChatClientRequestSpec
    public static class DefaultChatClientRequestSpec implements ChatClientRequestSpec {

        // ...

        @Override
        public ChatClientRequestSpec tools(String... toolNames) {
            Assert.notNull(toolNames, "toolNames cannot be null");
            Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
            this.functionNames.addAll(List.of(toolNames));
            return this;
        }

        @Override
        public ChatClientRequestSpec tools(FunctionCallback... toolCallbacks) {
            Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
            Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
            this.functionCallbacks.addAll(List.of(toolCallbacks));
            return this;
        }

        @Override
        public ChatClientRequestSpec tools(List<ToolCallback> toolCallbacks) {
            Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
            Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
            this.functionCallbacks.addAll(toolCallbacks);
            return this;
        }

        @Override
        public ChatClientRequestSpec tools(Object... toolObjects) {
            Assert.notNull(toolObjects, "toolObjects cannot be null");
            Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
            this.functionCallbacks.addAll(Arrays.asList(ToolCallbacks.from(toolObjects)));
            return this;
        }

        @Override
        public ChatClientRequestSpec tools(ToolCallbackProvider... toolCallbackProviders) {
            Assert.notNull(toolCallbackProviders, "toolCallbackProviders cannot be null");
            Assert.noNullElements(toolCallbackProviders, "toolCallbackProviders cannot contain null elements");
            for (ToolCallbackProvider toolCallbackProvider : toolCallbackProviders) {
                this.functionCallbacks.addAll(List.of(toolCallbackProvider.getToolCallbacks()));
            }
            return this;
        }
    }
}

Minimal runnable Spring AI example

The example below shows the dependencies, the configuration, one @Tool, and a single call.

  1. Dependencies (Maven shown; pick any provider implementation)
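<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-bom</artifactId>
            <version>1.0.0</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

<dependencies>
    <!-- Spring AI 1.0.0 artifact name (older milestones used spring-ai-openai-spring-boot-starter) -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-openai</artifactId>
    </dependency>
    <!-- Needed for @RestController; version managed by the Spring Boot parent -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
</dependencies>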
  2. Configuration (application.yml)
spring:
  ai:
    openai:
      api-key: ${OPENAI_API_KEY}
  3. Define a tool and call it
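// MathTools.java
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.stereotype.Component;

@Component
public class MathTools {
    @Tool(name = "addTwoNumbers", description = "Add two integers")
    public int addTwoNumbers(int a, int b) {
        return a + b;
    }
}

// DemoController.java
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class DemoController {
    private final ChatClient chatClient;
    private final MathTools mathTools;

    // Spring AI auto-configuration exposes a ChatClient.Builder bean; build the client from it
    public DemoController(ChatClient.Builder chatClientBuilder, MathTools mathTools) {
        this.chatClient = chatClientBuilder.build();
        this.mathTools = mathTools;
    }

    @GetMapping("/demo")
    public String demo() {
        return chatClient
                .prompt()
                .system("You are a helpful assistant.")
                .user("What is the sum of 12 and 30?")
                .tools(mathTools) // registers every @Tool method on this object
                .call()
                .content();
    }
}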

Tool (tool annotation)

@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE })
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Tool {
    /**
     * The name of the tool. If not provided, the method name will be used.
     */
    String name() default "";
    /**
     * The description of the tool. If not provided, the method name will be used.
     */
    String description() default "";
    /**
     * Whether the tool result should be returned directly or passed back to the model.
     */
    boolean returnDirect() default false;
    /**
     * The class to use to convert the tool call result to a String.
     */
    Class<? extends ToolCallResultConverter> resultConverter() default DefaultToolCallResultConverter.class;
}

ToolDefinition (tool definition)

public interface ToolDefinition {
    // The tool name; must uniquely identify the tool within the set provided to a model
    String name();

    // The tool description
    String description();

    // The tool's input schema
    String inputSchema();

    static DefaultToolDefinition.Builder builder() {
        return DefaultToolDefinition.builder();
    }

    // Derive the tool name, description, and input schema from a method
    static DefaultToolDefinition.Builder builder(Method method) {
        Assert.notNull(method, "method cannot be null");
        return DefaultToolDefinition.builder()
                .name(ToolUtils.getToolName(method))
                .description(ToolUtils.getToolDescription(method))
                .inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
    }

    static ToolDefinition from(Method method) {
        return ToolDefinition.builder(method).build();
    }
}

DefaultToolDefinition

public record DefaultToolDefinition(String name, String description, String inputSchema) implements ToolDefinition {

    public DefaultToolDefinition {
        Assert.hasText(name, "name cannot be null or empty");
        Assert.hasText(description, "description cannot be null or empty");
        Assert.hasText(inputSchema, "inputSchema cannot be null or empty");
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private String name;
        private String description;
        private String inputSchema;

        private Builder() {}

        public Builder name(String name) {
            this.name = name;
            return this;
        }

        public Builder description(String description) {
            this.description = description;
            return this;
        }

        public Builder inputSchema(String inputSchema) {
            this.inputSchema = inputSchema;
            return this;
        }

        public ToolDefinition build() {
            if (!StringUtils.hasText(description)) {
                description = ToolUtils.getToolDescriptionFromName(name);
            }
            return new DefaultToolDefinition(name, description, inputSchema);
        }
    }
}

ToolMetadata (tool metadata)

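Currently it only controls whether the tool result is returned directly to the caller instead of being passed back to the model for a response.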

public interface ToolMetadata {
    default boolean returnDirect() {
        return false;
    }

    static DefaultToolMetadata.Builder builder() {
        return DefaultToolMetadata.builder();
    }

    static ToolMetadata from(Method method) {
        Assert.notNull(method, "method cannot be null");
        return DefaultToolMetadata.builder().returnDirect(ToolUtils.getToolReturnDirect(method)).build();
    }
}

DefaultToolMetadata

public record DefaultToolMetadata(boolean returnDirect) implements ToolMetadata {

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private boolean returnDirect = false;

        private Builder() {}

        public Builder returnDirect(boolean returnDirect) {
            this.returnDirect = returnDirect;
            return this;
        }

        public ToolMetadata build() {
            return new DefaultToolMetadata(returnDirect);
        }
    }
}

ToolCallback (tool callback)

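public interface ToolCallback extends FunctionCallback {
    // The definition the AI model uses to decide when and how to call the tool
    ToolDefinition getToolDefinition();

    // Metadata providing additional information on how to handle the tool
    default ToolMetadata getToolMetadata() {
        return ToolMetadata.builder().build();
    }

    // Executes the tool with toolInput and returns the tool's result
    String call(String toolInput);

    // toolInput is the tool's input; toolContext carries additional context information
    default String call(String toolInput, @Nullable ToolContext toolContext) {
        if (toolContext != null && !toolContext.getContext().isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }

    // (bridge methods inherited from the deprecated FunctionCallback are omitted)
}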

MethodToolCallback

The core method to focus on is call:

  1. Convert the text produced by the model (a JSON string) into arguments that match the input schema:
Map<String, Object> toolArguments = extractToolArguments(toolInput);
Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);
  2. Invoke the tool method with those arguments to get the tool's output:
Object result = callMethod(methodArguments);
  3. Convert the tool's output according to its return type:
Type returnType = toolMethod.getGenericReturnType();
return toolCallResultConverter.convert(result, returnType);

public class MethodToolCallback implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class);
    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();
    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;
    private final ToolMetadata toolMetadata;
    private final Method toolMethod;
    @Nullable
    private final Object toolObject;
    private final ToolCallResultConverter toolCallResultConverter;

    public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
            @Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolMethod, "toolMethod cannot be null");
        Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
                "toolObject cannot be null for non-static methods");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolMethod = toolMethod;
        this.toolObject = toolObject;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");
        logger.debug("Starting execution of tool: {}", toolDefinition.name());
        validateToolContextSupport(toolContext);
        Map<String, Object> toolArguments = extractToolArguments(toolInput);
        Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);
        Object result = callMethod(methodArguments);
        logger.debug("Successful execution of tool: {}", toolDefinition.name());
        Type returnType = toolMethod.getGenericReturnType();
        return toolCallResultConverter.convert(result, returnType);
    }

    private void validateToolContextSupport(@Nullable ToolContext toolContext) {
        var isNonEmptyToolContextProvided = toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext());
        var isToolContextAcceptedByMethod = Stream.of(toolMethod.getParameterTypes())
                .anyMatch(type -> ClassUtils.isAssignable(type, ToolContext.class));
        if (isToolContextAcceptedByMethod && !isNonEmptyToolContextProvided) {
            throw new IllegalArgumentException("ToolContext is required by the method as an argument");
        }
    }

    private Map<String, Object> extractToolArguments(String toolInput) {
        return JsonParser.fromJson(toolInput, new TypeReference<>() {});
    }

    // Based on the implementation in MethodInvokingFunctionCallback.
    private Object[] buildMethodArguments(Map<String, Object> toolInputArguments, @Nullable ToolContext toolContext) {
        return Stream.of(toolMethod.getParameters()).map(parameter -> {
            if (parameter.getType().isAssignableFrom(ToolContext.class)) {
                return toolContext;
            }
            Object rawArgument = toolInputArguments.get(parameter.getName());
            return buildTypedArgument(rawArgument, parameter.getType());
        }).toArray();
    }

    @Nullable
    private Object buildTypedArgument(@Nullable Object value, Class<?> type) {
        if (value == null) {
            return null;
        }
        return JsonParser.toTypedObject(value, type);
    }

    @Nullable
    private Object callMethod(Object[] methodArguments) {
        if (isObjectNotPublic() || isMethodNotPublic()) {
            toolMethod.setAccessible(true);
        }
        Object result;
        try {
            result = toolMethod.invoke(toolObject, methodArguments);
        }
        catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        }
        catch (InvocationTargetException ex) {
            throw new ToolExecutionException(toolDefinition, ex.getCause());
        }
        return result;
    }

    private boolean isObjectNotPublic() {
        return toolObject != null && !Modifier.isPublic(toolObject.getClass().getModifiers());
    }

    private boolean isMethodNotPublic() {
        return !Modifier.isPublic(toolMethod.getModifiers());
    }

    @Override
    public String toString() {
        return "MethodToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private ToolDefinition toolDefinition;
        private ToolMetadata toolMetadata;
        private Method toolMethod;
        private Object toolObject;
        private ToolCallResultConverter toolCallResultConverter;

        private Builder() {}

        public Builder toolDefinition(ToolDefinition toolDefinition) {
            this.toolDefinition = toolDefinition;
            return this;
        }

        public Builder toolMetadata(ToolMetadata toolMetadata) {
            this.toolMetadata = toolMetadata;
            return this;
        }

        public Builder toolMethod(Method toolMethod) {
            this.toolMethod = toolMethod;
            return this;
        }

        public Builder toolObject(Object toolObject) {
            this.toolObject = toolObject;
            return this;
        }

        public Builder toolCallResultConverter(ToolCallResultConverter toolCallResultConverter) {
            this.toolCallResultConverter = toolCallResultConverter;
            return this;
        }

        public MethodToolCallback build() {
            return new MethodToolCallback(toolDefinition, toolMetadata, toolMethod, toolObject, toolCallResultConverter);
        }
    }
}

FunctionToolCallback

The core method to focus on is call:

  1. The toolInput extracted by the model is a JSON string; first convert it into the declared request type:
I request = JsonParser.fromJson(toolInput, toolInputType);
  2. Call the tool function and get its result:
O response = toolFunction.apply(request, toolContext);

public class FunctionToolCallback<I, O> implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class);
    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();
    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;
    private final ToolMetadata toolMetadata;
    private final Type toolInputType;
    private final BiFunction<I, ToolContext, O> toolFunction;
    private final ToolCallResultConverter toolCallResultConverter;

    public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType,
            BiFunction<I, ToolContext, O> toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolInputType, "toolInputType cannot be null");
        Assert.notNull(toolFunction, "toolFunction cannot be null");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolFunction = toolFunction;
        this.toolInputType = toolInputType;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");
        logger.debug("Starting execution of tool: {}", toolDefinition.name());
        I request = JsonParser.fromJson(toolInput, toolInputType);
        O response = toolFunction.apply(request, toolContext);
        logger.debug("Successful execution of tool: {}", toolDefinition.name());
        return toolCallResultConverter.convert(response, null);
    }

    @Override
    public String toString() {
        return "FunctionToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
    }

    public static <I, O> Builder<I, O> builder(String name, BiFunction<I, ToolContext, O> function) {
        return new Builder<>(name, function);
    }

    public static <I, O> Builder<I, O> builder(String name, Function<I, O> function) {
        Assert.notNull(function, "function cannot be null");
        return new Builder<>(name, (request, context) -> function.apply(request));
    }

    public static <O> Builder<Void, O> builder(String name, Supplier<O> supplier) {
        Assert.notNull(supplier, "supplier cannot be null");
        Function<Void, O> function = input -> supplier.get();
        return builder(name, function).inputType(Void.class);
    }

    public static <I> Builder<I, Void> builder(String name, Consumer<I> consumer) {
        Assert.notNull(consumer, "consumer cannot be null");
        Function<I, Void> function = (I input) -> {
            consumer.accept(input);
            return null;
        };
        return builder(name, function);
    }

    public static class Builder<I, O> {
        private String name;
        private String description;
        private String inputSchema;
        private Type inputType;
        private ToolMetadata toolMetadata;
        private BiFunction<I, ToolContext, O> toolFunction;
        private ToolCallResultConverter toolCallResultConverter;

        private Builder(String name, BiFunction<I, ToolContext, O> toolFunction) {
            Assert.hasText(name, "name cannot be null or empty");
            Assert.notNull(toolFunction, "toolFunction cannot be null");
            this.name = name;
            this.toolFunction = toolFunction;
        }

        public Builder<I, O> description(String description) {
            this.description = description;
            return this;
        }

        public Builder<I, O> inputSchema(String inputSchema) {
            this.inputSchema = inputSchema;
            return this;
        }

        public Builder<I, O> inputType(Type inputType) {
            this.inputType = inputType;
            return this;
        }

        public Builder<I, O> inputType(ParameterizedTypeReference<?> inputType) {
            Assert.notNull(inputType, "inputType cannot be null");
            this.inputType = inputType.getType();
            return this;
        }

        public Builder<I, O> toolMetadata(ToolMetadata toolMetadata) {
            this.toolMetadata = toolMetadata;
            return this;
        }

        public Builder<I, O> toolCallResultConverter(ToolCallResultConverter toolCallResultConverter) {
            this.toolCallResultConverter = toolCallResultConverter;
            return this;
        }

        public FunctionToolCallback<I, O> build() {
            Assert.notNull(inputType, "inputType cannot be null");
            var toolDefinition = ToolDefinition.builder()
                    .name(name)
                    .description(StringUtils.hasText(description) ? description : ToolUtils.getToolDescriptionFromName(name))
                    .inputSchema(StringUtils.hasText(inputSchema) ? inputSchema : JsonSchemaGenerator.generateForType(inputType))
                    .build();
            return new FunctionToolCallback<>(toolDefinition, toolMetadata, inputType, toolFunction, toolCallResultConverter);
        }
    }
}
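The `builder(...)` overloads above all funnel a `Function`, `Supplier`, or `Consumer` into the single `BiFunction<I, ToolContext, O>` shape. The sketch below reproduces that adaptation idea with plain JDK types; `ToolFunctions` is hypothetical, `Map<String, Object>` stands in for `ToolContext`, and the methods get distinct names because a single overloaded name can make lambda arguments ambiguous between `Function` and `Consumer`.

```java
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

// Hypothetical adapters; Map<String, Object> stands in for ToolContext.
final class ToolFunctions {
    private ToolFunctions() {}

    static <I, O> BiFunction<I, Map<String, Object>, O> fromFunction(Function<I, O> fn) {
        return (input, context) -> fn.apply(input); // the context is ignored
    }

    static <O> BiFunction<Void, Map<String, Object>, O> fromSupplier(Supplier<O> supplier) {
        return (input, context) -> supplier.get(); // no input at all
    }

    static <I> BiFunction<I, Map<String, Object>, Void> fromConsumer(Consumer<I> consumer) {
        return (input, context) -> {
            consumer.accept(input);
            return null; // Void result; a result converter can map this to a fixed acknowledgement
        };
    }
}
```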

ToolCallbackProvider (supplies tool callback instances)

Centralizes the management and provision of tool callbacks.

  • getToolCallbacks: returns the array of tool callbacks

public interface ToolCallbackProvider {
    // FunctionCallback[] so that StaticToolCallbackProvider's override below type-checks
    FunctionCallback[] getToolCallbacks();

    static ToolCallbackProvider from(List<? extends FunctionCallback> toolCallbacks) {
        return new StaticToolCallbackProvider(toolCallbacks);
    }

    static ToolCallbackProvider from(FunctionCallback... toolCallbacks) {
        return new StaticToolCallbackProvider(toolCallbacks);
    }
}

MethodToolCallbackProvider

Builds MethodToolCallback instances from the @Tool-annotated methods of the given objects.

public class MethodToolCallbackProvider implements ToolCallbackProvider {

    private static final Logger logger = LoggerFactory.getLogger(MethodToolCallbackProvider.class);

    private final List<Object> toolObjects;

    private MethodToolCallbackProvider(List<Object> toolObjects) {
        Assert.notNull(toolObjects, "toolObjects cannot be null");
        Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements");
        this.toolObjects = toolObjects;
    }

    @Override
    public ToolCallback[] getToolCallbacks() {
        var toolCallbacks = toolObjects.stream()
                .map(toolObject -> Stream
                        .of(ReflectionUtils.getDeclaredMethods(
                                AopUtils.isAopProxy(toolObject) ? AopUtils.getTargetClass(toolObject) : toolObject.getClass()))
                        .filter(toolMethod -> toolMethod.isAnnotationPresent(Tool.class))
                        .filter(toolMethod -> !isFunctionalType(toolMethod))
                        .map(toolMethod -> MethodToolCallback.builder()
                                .toolDefinition(ToolDefinition.from(toolMethod))
                                .toolMetadata(ToolMetadata.from(toolMethod))
                                .toolMethod(toolMethod)
                                .toolObject(toolObject)
                                .toolCallResultConverter(ToolUtils.getToolCallResultConverter(toolMethod))
                                .build())
                        .toArray(ToolCallback[]::new))
                .flatMap(Stream::of)
                .toArray(ToolCallback[]::new);
        validateToolCallbacks(toolCallbacks);
        return toolCallbacks;
    }

    private boolean isFunctionalType(Method toolMethod) {
        var isFunction = ClassUtils.isAssignable(toolMethod.getReturnType(), Function.class)
                || ClassUtils.isAssignable(toolMethod.getReturnType(), Supplier.class)
                || ClassUtils.isAssignable(toolMethod.getReturnType(), Consumer.class);
        if (isFunction) {
            logger.warn("Method {} is annotated with @Tool but returns a functional type. "
                    + "This is not supported and the method will be ignored.", toolMethod.getName());
        }
        return isFunction;
    }

    private void validateToolCallbacks(ToolCallback[] toolCallbacks) {
        List<String> duplicateToolNames = ToolUtils.getDuplicateToolNames(toolCallbacks);
        if (!duplicateToolNames.isEmpty()) {
            throw new IllegalStateException("Multiple tools with the same name (%s) found in sources: %s".formatted(
                    String.join(", ", duplicateToolNames),
                    toolObjects.stream().map(o -> o.getClass().getName()).collect(Collectors.joining(", "))));
        }
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private List<Object> toolObjects;

        private Builder() {}

        public Builder toolObjects(Object... toolObjects) {
            Assert.notNull(toolObjects, "toolObjects cannot be null");
            this.toolObjects = Arrays.asList(toolObjects);
            return this;
        }

        public MethodToolCallbackProvider build() {
            return new MethodToolCallbackProvider(toolObjects);
        }
    }
}
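The scan-and-validate logic in getToolCallbacks can be reproduced with plain reflection. `@DemoTool` and `ToolScanner` are hypothetical stand-ins for `@Tool` and the provider; the sketch falls back to the method name when name or description is empty (as the @Tool Javadoc describes) and rejects duplicate names the way validateToolCallbacks does. It skips AOP proxy unwrapping and the functional-return-type filter.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical annotation standing in for Spring AI's @Tool.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface DemoTool {
    String name() default "";
    String description() default "";
}

class ToolScanner {
    record Definition(String name, String description, Method method) {}

    /** Collects @DemoTool methods; an empty name or description falls back to the method name. */
    static List<Definition> scan(Object toolObject) {
        List<Definition> result = new ArrayList<>();
        Set<String> seen = new HashSet<>();
        for (Method m : toolObject.getClass().getDeclaredMethods()) {
            DemoTool ann = m.getAnnotation(DemoTool.class);
            if (ann == null) continue;
            String name = ann.name().isEmpty() ? m.getName() : ann.name();
            String description = ann.description().isEmpty() ? m.getName() : ann.description();
            if (!seen.add(name)) {
                throw new IllegalStateException("Multiple tools with the same name: " + name);
            }
            result.add(new Definition(name, description, m));
        }
        return result;
    }
}
```

Note that getDeclaredMethods returns methods in no particular order, so callers should not rely on the order of the returned definitions.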

StaticToolCallbackProvider

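Intended to supply FunctionToolCallback instances, but at a glance this does not appear to be implemented yet: it simply wraps the callbacks it is given.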

public class StaticToolCallbackProvider implements ToolCallbackProvider {

    private final FunctionCallback[] toolCallbacks;

    public StaticToolCallbackProvider(FunctionCallback... toolCallbacks) {
        Assert.notNull(toolCallbacks, "ToolCallbacks must not be null");
        this.toolCallbacks = toolCallbacks;
    }

    public StaticToolCallbackProvider(List<? extends FunctionCallback> toolCallbacks) {
        Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
        this.toolCallbacks = toolCallbacks.toArray(new FunctionCallback[0]);
    }

    @Override
    public FunctionCallback[] getToolCallbacks() {
        return this.toolCallbacks;
    }
}

ToolCallingManager (tool calling manager)

public interface ToolCallingManager {
    // Extract the tool definitions from the options so the model can recognize and use the tools
    List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions);

    // Execute the tool calls requested in the model's response and return the results
    ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse);

    // Create a builder for the tool calling manager
    static DefaultToolCallingManager.Builder builder() {
        return DefaultToolCallingManager.builder();
    }
}

DefaultToolCallingManager

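Its core responsibilities:

  1. Resolve tool definitions (resolveToolDefinitions): parse the tool definitions out of ToolCallingChatOptions so the model can recognize and use the tools.
  2. Execute tool calls (executeToolCalls): run the tool calls requested in the model's response and return the tools' results.
  3. Build the tool context (buildToolContext): supply context information for the tool call, including the message history.
  4. Manage tool callbacks: resolve tool callbacks by name through a ToolCallbackResolver, which enables dynamic tool calling.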
public class DefaultToolCallingManager implements ToolCallingManager {
    // Excerpt: fields such as toolCallbackResolver, the builder, and the private helpers
    // buildToolContext, executeToolCall and buildConversationHistoryAfterToolExecution are omitted here.
    @Override
    public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
        Assert.notNull(chatOptions, "chatOptions cannot be null");
        List<FunctionCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
        for (String toolName : chatOptions.getToolNames()) {
            // Skip the tool if it is already present in the request toolCallbacks.
            // That might happen if a tool is defined in the options
            // both as a ToolCallback and as a tool name.
            if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) {
                continue;
            }
            FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }
            toolCallbacks.add(toolCallback);
        }
        return toolCallbacks.stream().map(functionCallback -> {
            if (functionCallback instanceof ToolCallback toolCallback) {
                return toolCallback.getToolDefinition();
            } else {
                return ToolDefinition.builder()
                    .name(functionCallback.getName())
                    .description(functionCallback.getDescription())
                    .inputSchema(functionCallback.getInputTypeSchema())
                    .build();
            }
        }).toList();
    }
    @Override
    public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
        Assert.notNull(prompt, "prompt cannot be null");
        Assert.notNull(chatResponse, "chatResponse cannot be null");
        Optional<Generation> toolCallGeneration = chatResponse.getResults()
            .stream()
            .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
            .findFirst();
        if (toolCallGeneration.isEmpty()) {
            throw new IllegalStateException("No tool call requested by the chat model");
        }
        AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
        ToolContext toolContext = buildToolContext(prompt, assistantMessage);
        InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage, toolContext);
        List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
            assistantMessage, internalToolExecutionResult.toolResponseMessage());
        return ToolExecutionResult.builder()
            .conversationHistory(conversationHistory)
            .returnDirect(internalToolExecutionResult.returnDirect())
            .build();
    }
}
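The resolution rule in resolveToolDefinitions (explicit callbacks first, remaining names looked up through the resolver, an error when a name cannot be resolved) can be reproduced without Spring AI. In this sketch, Tool, resolveToolDefinitions and the "name: description" strings are hypothetical simplifications; the real method returns ToolDefinition objects that also carry an input schema.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-in for a tool callback: only the name and description matter here.
public class ToolResolutionSketch {

    public record Tool(String name, String description) {}

    // Explicit callbacks are kept; remaining names are looked up through the resolver.
    public static List<String> resolveToolDefinitions(List<Tool> callbacks, List<String> toolNames,
                                                      Function<String, Tool> resolver) {
        List<Tool> all = new ArrayList<>(callbacks);
        for (String name : toolNames) {
            // Skip names already supplied as a callback, as DefaultToolCallingManager does.
            if (callbacks.stream().anyMatch(t -> t.name().equals(name))) {
                continue;
            }
            Tool resolved = resolver.apply(name);
            if (resolved == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + name);
            }
            all.add(resolved);
        }
        // Map each callback to a simplified "definition" string.
        return all.stream().map(t -> t.name() + ": " + t.description()).toList();
    }

    public static void main(String[] args) {
        Map<String, Tool> registry = Map.of("weather", new Tool("weather", "Get the weather"));
        List<String> defs = resolveToolDefinitions(
                List.of(new Tool("clock", "Get the time")),
                List.of("clock", "weather"),
                registry::get);
        System.out.println(defs); // prints [clock: Get the time, weather: Get the weather]
    }
}
```

Because "clock" is passed both as a callback and as a name, it appears only once; failing fast on an unknown name surfaces configuration mistakes before the model is ever called.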

ToolCallResultConverter (tool result converter)

@FunctionalInterface
public interface ToolCallResultConverter {
    // result: the value returned by the tool; returnType: the tool's declared return type
    String convert(@Nullable Object result, @Nullable Type returnType);
}

DefaultToolCallResultConverter

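At the time of writing, this is the only implementation of the ToolCallResultConverter interface. It serializes the tool result to a JSON string, or returns "Done" when the tool has no return value.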

public final class DefaultToolCallResultConverter implements ToolCallResultConverter {
    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);
    @Override
    public String convert(@Nullable Object result, @Nullable Type returnType) {
        if (returnType == Void.TYPE) {
            logger.debug("The tool has no return type. Converting to conventional response.");
            return "Done";
        } else {
            logger.debug("Converting tool result to JSON.");
            return JsonParser.toJson(result);
        }
    }
}
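The converter's decision rule is easy to reproduce with the standard library alone. In this sketch, ResultConverter is a hypothetical local mirror of ToolCallResultConverter, and toJson is a deliberately simplified, hypothetical serializer standing in for Spring AI's JsonParser (which is Jackson-based); it only handles null, strings, numbers, booleans, maps and lists.

```java
import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

// Stdlib-only sketch of DefaultToolCallResultConverter's behavior; not the real implementation.
public class ResultConverterSketch {

    @FunctionalInterface
    public interface ResultConverter { // hypothetical local mirror of ToolCallResultConverter
        String convert(Object result, Type returnType);
    }

    // Same rule as above: primitive void -> "Done", anything else -> JSON.
    public static final ResultConverter DEFAULT = (result, returnType) ->
            returnType == Void.TYPE ? "Done" : toJson(result);

    // Simplified JSON: escapes only backslashes and double quotes in strings.
    public static String toJson(Object value) {
        if (value == null) return "null";
        if (value instanceof String s) return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
        if (value instanceof Number || value instanceof Boolean) return value.toString();
        if (value instanceof Map<?, ?> map) {
            StringJoiner j = new StringJoiner(",", "{", "}");
            map.forEach((k, v) -> j.add(toJson(String.valueOf(k)) + ":" + toJson(v)));
            return j.toString();
        }
        if (value instanceof List<?> list) {
            StringJoiner j = new StringJoiner(",", "[", "]");
            list.forEach(v -> j.add(toJson(v)));
            return j.toString();
        }
        throw new IllegalArgumentException("Unsupported type: " + value.getClass());
    }

    public static void main(String[] args) {
        System.out.println(DEFAULT.convert(null, void.class));            // prints Done
        System.out.println(DEFAULT.convert(List.of(1, "a"), List.class)); // prints [1,"a"]
    }
}
```

Note that the check compares against Void.TYPE (that is, void.class). In this sketch a tool declared to return the boxed Void therefore falls through to JSON serialization and yields "null"; the same comparison in the code above suggests the real converter behaves alike.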

ToolContext (tool context)

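Created by the tool-calling manager (see buildToolContext in DefaultToolCallingManager).

Purpose:

  1. Encapsulates the context information for a tool execution. The context is wrapped with Collections.unmodifiableMap so tools cannot modify it, which is meant to make it safe to share across threads; note that this is a read-only view rather than a copy, so changes made through the original map remain visible.
  2. getContext returns the whole context, and getToolCallHistory returns the message history stored under the TOOL_CALL_HISTORY key.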
public class ToolContext {
    public static final String TOOL_CALL_HISTORY = "TOOL_CALL_HISTORY";
    private final Map<String, Object> context;
    public ToolContext(Map<String, Object> context) {
        this.context = Collections.unmodifiableMap(context);
    }
    public Map<String, Object> getContext() {
        return this.context;
    }
    @SuppressWarnings("unchecked")
    public List<Message> getToolCallHistory() {
        return (List<Message>) this.context.get(TOOL_CALL_HISTORY);
    }
}
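The read-only behavior, and its limits, can be checked with a small stdlib-only mirror of the class. ToolContextSketch uses String messages instead of Spring AI's Message type, and the "tenantId" entry is a hypothetical piece of application context.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical mirror of ToolContext, with String messages instead of Spring AI's Message type.
public class ToolContextSketch {

    public static final String TOOL_CALL_HISTORY = "TOOL_CALL_HISTORY";

    private final Map<String, Object> context;

    public ToolContextSketch(Map<String, Object> context) {
        // Same approach as ToolContext: a read-only view over the caller's map.
        this.context = Collections.unmodifiableMap(context);
    }

    public Map<String, Object> getContext() {
        return context;
    }

    @SuppressWarnings("unchecked")
    public List<String> getToolCallHistory() {
        return (List<String>) context.get(TOOL_CALL_HISTORY);
    }

    public static void main(String[] args) {
        Map<String, Object> source = new HashMap<>();
        source.put("tenantId", "acme"); // hypothetical application context entry
        source.put(TOOL_CALL_HISTORY, List.of("user: hi", "assistant: calling tool"));
        ToolContextSketch ctx = new ToolContextSketch(source);

        System.out.println(ctx.getToolCallHistory().size()); // prints 2
        try {
            ctx.getContext().put("tenantId", "other");
        } catch (UnsupportedOperationException e) {
            System.out.println("context is read-only");
        }
        // The view still reflects later changes made through the original map.
        source.put("tenantId", "changed");
        System.out.println(ctx.getContext().get("tenantId")); // prints changed
    }
}
```

A true snapshot would need a copy such as Map.copyOf, but Map.copyOf rejects null keys and values, which is one plausible reason to prefer the view.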

ToolUtils (common tool helper methods)

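Helpers for extracting tool metadata from a Method. Most of them check whether the method carries the @Tool annotation and fall back to defaults when it does not; for example, the tool name falls back to the method name.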

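  1. getToolName: returns the tool name (the @Tool name if set, otherwise the method name)
  2. getToolDescriptionFromName: derives a description from a tool name by splitting its camelCase form into words
  3. getToolDescription: returns the tool description
  4. getToolReturnDirect: checks whether the tool result should be returned directly to the caller instead of being sent back to the model
  5. getToolCallResultConverter: returns the tool's result converter
  6. getDuplicateToolNames: checks whether a list of tool callbacks contains duplicate tool names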
public final class ToolUtils {
    private ToolUtils() {}
    public static String getToolName(Method method) {
        Assert.notNull(method, "method cannot be null");
        var tool = method.getAnnotation(Tool.class);
        if (tool == null) {
            return method.getName();
        }
        return StringUtils.hasText(tool.name()) ? tool.name() : method.getName();
    }
    public static String getToolDescriptionFromName(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");
        return ParsingUtils.reConcatenateCamelCase(toolName, " ");
    }
    public static String getToolDescription(Method method) {
        Assert.notNull(method, "method cannot be null");
        var tool = method.getAnnotation(Tool.class);
        if (tool == null) {
            return ParsingUtils.reConcatenateCamelCase(method.getName(), " ");
        }
        return StringUtils.hasText(tool.description()) ? tool.description() : method.getName();
    }
    public static boolean getToolReturnDirect(Method method) {
        Assert.notNull(method, "method cannot be null");
        var tool = method.getAnnotation(Tool.class);
        return tool != null && tool.returnDirect();
    }
    public static ToolCallResultConverter getToolCallResultConverter(Method method) {
        Assert.notNull(method, "method cannot be null");
        var tool = method.getAnnotation(Tool.class);
        if (tool == null) {
            return new DefaultToolCallResultConverter();
        }
        var type = tool.resultConverter();
        try {
            return type.getDeclaredConstructor().newInstance();
        }
        catch (Exception e) {
            throw new IllegalArgumentException("Failed to instantiate ToolCallResultConverter: " + type, e);
        }
    }
    public static List<String> getDuplicateToolNames(List<FunctionCallback> toolCallbacks) {
        Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
        return toolCallbacks.stream()
            .collect(Collectors.groupingBy(FunctionCallback::getName, Collectors.counting()))
            .entrySet()
            .stream()
            .filter(entry -> entry.getValue() > 1)
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    }
    public static List<String> getDuplicateToolNames(FunctionCallback... toolCallbacks) {
        Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
        return getDuplicateToolNames(Arrays.asList(toolCallbacks));
    }
}
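The name and description fallbacks in getToolName and getToolDescription can be reproduced with plain reflection. In this sketch, @MyTool is a hypothetical stand-in for Spring AI's @Tool, the Tools class is made up for illustration, and splitCamelCase is only an approximation of ParsingUtils.reConcatenateCamelCase (it splits before every upper-case letter and lower-cases everything, so acronyms may be handled differently).

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

// Stdlib-only sketch of ToolUtils' name/description fallbacks; @MyTool is hypothetical.
public class ToolUtilsSketch {

    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    public @interface MyTool {
        String name() default "";
        String description() default "";
    }

    // No annotation or blank name -> method name.
    public static String toolName(Method m) {
        MyTool t = m.getAnnotation(MyTool.class);
        return (t == null || t.name().isBlank()) ? m.getName() : t.name();
    }

    // No annotation -> camelCase split of the method name; annotation with blank description -> method name.
    public static String toolDescription(Method m) {
        MyTool t = m.getAnnotation(MyTool.class);
        if (t == null) {
            return splitCamelCase(m.getName());
        }
        return t.description().isBlank() ? m.getName() : t.description();
    }

    // Approximation: "getStockPrice" -> "get stock price".
    public static String splitCamelCase(String s) {
        List<String> parts = new ArrayList<>();
        StringBuilder cur = new StringBuilder();
        for (char c : s.toCharArray()) {
            if (Character.isUpperCase(c) && cur.length() > 0) {
                parts.add(cur.toString());
                cur.setLength(0);
            }
            cur.append(Character.toLowerCase(c));
        }
        if (cur.length() > 0) {
            parts.add(cur.toString());
        }
        return String.join(" ", parts);
    }

    // Hypothetical sample tools covering the three cases.
    public static class Tools {
        @MyTool(name = "weather", description = "Get the weather for a city")
        public String currentWeather(String city) { return "sunny"; }

        @MyTool
        public String currentTime() { return "12:00"; }

        public String getStockPrice(String symbol) { return "42"; }
    }

    public static void main(String[] args) throws NoSuchMethodException {
        for (Method m : List.of(
                Tools.class.getMethod("currentWeather", String.class),
                Tools.class.getMethod("currentTime"),
                Tools.class.getMethod("getStockPrice", String.class))) {
            System.out.println(toolName(m) + " | " + toolDescription(m));
        }
        // prints:
        // weather | Get the weather for a city
        // currentTime | currentTime
        // getStockPrice | get stock price
    }
}
```

The output shows an asymmetry that follows from the code above: an un-annotated method gets a readable, split description, while an annotated method without a description falls back to the raw method name, so it is worth always filling in the description attribute.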