# NOTE: extraction metadata removed (was: "155 lines / 6.2 KiB / Python", repeated)
import os
|
||
import asyncio
|
||
from datetime import datetime
|
||
from deepagents import create_deep_agent
|
||
from openai import OpenAI
|
||
from langchain.chat_models import init_chat_model
|
||
from langchain.agents import create_agent
|
||
from langgraph.checkpoint.memory import InMemorySaver, MemorySaver # 导入检查点工具
|
||
from deepagents.backends import StoreBackend
|
||
from loguru import logger
|
||
def internet_search_tool(query: str) -> str:
    """Run a web search for *query* via DashScope's Qwen model.

    Args:
        query: The natural-language search query.

    Returns:
        The assistant's answer text, produced with live web search enabled
        (``enable_search`` is a DashScope-specific ``extra_body`` extension).
    """
    # loguru supports brace-style deferred formatting; avoid f-strings with
    # no placeholders (the original used f"" on constant messages).
    logger.info("Running internet search for query: {}", query)
    # NOTE(review): a new client is constructed on every call; fine for a
    # demo, but could be hoisted to module level if this tool becomes hot.
    client = OpenAI(
        api_key=os.getenv('DASHSCOPE_API_KEY'),
        base_url=os.getenv('DASHSCOPE_BASE_URL'),
    )
    logger.info("create OpenAI")
    completion = client.chat.completions.create(
        model="qwen-plus",
        messages=[
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': query}
        ],
        # DashScope extension: let the model consult web search results.
        extra_body={
            "enable_search": True
        }
    )
    logger.info("create completions")
    # Evaluate the response path once instead of twice (log + return).
    answer = completion.choices[0].message.content
    logger.info("OpenAI response: {}", answer)
    return answer
|
||
|
||
|
||
|
||
# System prompt to steer the agent to be an expert researcher.
# The date is rendered once at import time, in Chinese format.
today = datetime.now().strftime("%Y年%m月%d日")

# FIX: the prompt previously referred to a tool named `internet_search`,
# but the tool registered with the agent is the function
# `internet_search_tool` — tools are addressed by function name, so the
# model was being pointed at a tool that does not exist.
# NOTE(review): the description still claims the tool accepts max-results /
# topic / raw-content options, but the function only takes `query` —
# consider tightening the wording; left unchanged here to limit the diff.
research_instructions = f"""你是一个智能助手。你的任务是帮助用户完成各种任务。

你可以使用互联网搜索工具来获取信息。
## `internet_search_tool`
使用此工具对给定查询进行互联网搜索。你可以指定返回结果的最大数量、主题以及是否包含原始内容。

今天的日期是:{today}
"""
|
||
|
||
# Create the deep agent with memory.
# The chat model is resolved via langchain's init_chat_model factory;
# credentials and endpoint come from the environment.
model = init_chat_model(
    model="gpt-4.1-mini",
    model_provider='openai',
    api_key=os.getenv('OPENAI_API_KEY'),
    base_url=os.getenv('OPENAI_BASE_URL'),
)
# In-memory checkpointer: persists per-thread conversation state so the
# agent automatically remembers earlier turns in the same thread.
checkpointer = InMemorySaver()

# state is scoped per thread (one conversation session)
agent = create_deep_agent(
    tools=[internet_search_tool],
    system_prompt=research_instructions,
    model=model,
    checkpointer=checkpointer,  # attach the checkpointer to enable automatic memory
    # Pause for human approval before executing this tool.
    # NOTE(review): assumes interrupt_on keys are tool names mapped to a
    # bool — confirm the exact schema for this deepagents version.
    interrupt_on={'internet_search_tool':True}
)
|
||
|
||
# Multi-turn chat loop: the checkpointer gives the agent memory across turns.
printed_msg_ids = set()          # message ids already echoed, so each prints only once
thread_id = "user_session_001"   # session id separating different users/conversations
config = {
    "configurable": {"thread_id": thread_id},
    "metastore": {"assistant_id": "owenliang"},
}
|
||
|
||
print("开始对话(输入 'exit' 退出):")
while True:
    user_input = input("\nHUMAN: ").strip()
    if user_input.lower() == 'exit':
        break

    # stream_mode="values" re-emits the full state several times per turn,
    # so we dedupe by message.id and print each new message by type.
    pending_resume = None
    while True:
        if pending_resume is None:
            # Fresh pass of this user turn: submit the user's message.
            request = {"messages": [{"role": "user", "content": user_input}]}
        else:
            # A previous pass hit a human-in-the-loop interrupt: resume the
            # run with the collected approval decisions.
            from langgraph.types import Command as _Command

            request = _Command(resume=pending_resume)
            pending_resume = None

        for item in agent.stream(
            request,
            config=config,
            stream_mode="values",
        ):
            # Some stream shapes yield (state, metadata) tuples; unwrap defensively.
            state = item[0] if isinstance(item, tuple) and len(item) == 2 else item

            # First, check whether a Human-In-The-Loop interrupt was raised.
            if isinstance(state, dict) and "__interrupt__" in state:
                interrupts = state["__interrupt__"] or []
                if interrupts:
                    hitl_payload = interrupts[0].value
                    # assumes the interrupt payload carries "action_requests",
                    # each with "name"/"args" keys — TODO confirm against the
                    # deepagents interrupt schema in use.
                    action_requests = hitl_payload.get("action_requests", [])

                    print("\n=== 需要人工审批的工具调用 ===")
                    decisions: list[dict[str, str]] = []
                    for idx, ar in enumerate(action_requests):
                        name = ar.get("name")
                        args = ar.get("args")
                        print(f"[{idx}] 工具 {name} 参数: {args}")
                        # Keep prompting until a valid single-letter choice arrives.
                        while True:
                            choice = input(" 决策 (a=approve, r=reject): ").strip().lower()
                            if choice in ("a", "r"):
                                break
                        decisions.append({"type": "approve" if choice == "a" else "reject"})

                    # Next iteration of the inner loop resumes this same user turn.
                    pending_resume = {"decisions": decisions}
                    break

            # Works with both dict states and AgentState dataclasses.
            messages = state.get("messages", []) if isinstance(state, dict) else getattr(state, "messages", [])
            for msg in messages:
                msg_id = getattr(msg, "id", None)
                if msg_id is not None and msg_id in printed_msg_ids:
                    continue  # already printed during an earlier state emission
                if msg_id is not None:
                    printed_msg_ids.add(msg_id)

                msg_type = getattr(msg, "type", None)

                if msg_type == "human":
                    # The user's input is already visible on the console; skip.
                    continue

                if msg_type == "ai":
                    tool_calls = getattr(msg, "tool_calls", None) or []
                    if tool_calls:
                        # AI message that initiates tool calls (TOOL CALL).
                        for tc in tool_calls:
                            tool_name = tc.get("name")
                            args = tc.get("args")
                            print(f"TOOL CALL [{tool_name}]: {args}")
                    # If the AI message also carries natural-language content, print it.
                    if getattr(msg, "content", None):
                        print(f"AI: {msg.content}")
                    continue

                if msg_type == "tool":
                    # Tool execution result (TOOL RESPONSE).
                    tool_name = getattr(msg, "name", None) or "tool"
                    print(f"TOOL RESPONSE [{tool_name}]: {msg.content}")
                    continue

                # Fallback: print any other message type to aid debugging.
                print(f"[{msg_type}]: {getattr(msg, 'content', None)}")

        # No new interrupt to resume: this user turn is done; await next input.
        if pending_resume is None:
            break