# hyf-backend/th_agenter/api/endpoints/agent_chat.py

"""agentChat 接口:根据 AI 大模型、提示词、关联知识库输出结果。"""
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import StreamingResponse
from sqlalchemy import select
from sqlalchemy.orm import Session
from loguru import logger
from ...db.database import get_session
from ...models.user import User
from ...models.llm_config import LLMConfig
from ...models.knowledge_base import KnowledgeBase
from ...services.llm_service import LLMService
from ...services.document_processor import get_document_processor
from ...core.simple_permissions import require_authenticated_user
from utils.util_exceptions import HxfResponse
from utils.util_schemas import AgentChatRequest, AgentChatResponse
router = APIRouter()
@router.post("/", response_model=AgentChatResponse, summary="agentChat按大模型、提示词、关联知识库输出结果")
@router.post("", response_model=AgentChatResponse, include_in_schema=False)
async def agent_chat(
    body: AgentChatRequest,
    current_user: User = Depends(require_authenticated_user),  # auth only; value unused
    session: Session = Depends(get_session),
):
    """Answer a prompt with the selected LLM, optionally grounded in linked knowledge bases.

    Request fields (see AgentChatRequest):
    - model_id: LLM config ID; must be an active chat model, not an embedding model.
    - prompt / message: the prompt text (either field; `prompt` wins).
    - knowledge_base_ids / knowledge_base_id: linked KB(s) as a list (e.g. [1, 2]
      or ['3']) or a single ID; the list takes precedence.
    - top_k / temperature / max_tokens: retrieval and generation tuning.

    Raises:
    - 400: empty prompt, non-numeric knowledge_base_ids, inactive model/KB,
      or an embedding model selected.
    - 404: unknown model config or knowledge base.
    - 502: the downstream LLM call failed.
    """
    prompt_text = (body.prompt or body.message or "").strip()
    # Fix: reject an empty prompt up front instead of sending a blank message to
    # the LLM (and using "" as the vector-search query).
    if not prompt_text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="提示词不能为空")

    # Resolve the KB ID list: knowledge_base_ids takes precedence over knowledge_base_id.
    kb_ids: list[int] = []
    if body.knowledge_base_ids:
        try:
            kb_ids = [int(x) for x in body.knowledge_base_ids if x is not None and str(x).strip() != ""]
        except (ValueError, TypeError):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="knowledge_base_ids 须为数字或数字字符串")
        # Keep only positive IDs; fix: de-duplicate (order-preserving) so the same
        # KB is not searched twice and its chunks double-counted in the ranking.
        kb_ids = list(dict.fromkeys(i for i in kb_ids if i >= 1))
    elif body.knowledge_base_id is not None and body.knowledge_base_id >= 1:
        kb_ids = [body.knowledge_base_id]

    # NOTE(review): title/desc on the session look like audit/log metadata consumed
    # elsewhere (HxfResponse layer?) — confirm the convention before removing.
    session.title = "agentChat"
    session.desc = f"START: agentChat model_id={body.model_id}, prompt_len={len(prompt_text)}, knowledge_base_ids={kb_ids}"

    # 1. Validate and load the LLM configuration.
    stmt = select(LLMConfig).where(LLMConfig.id == body.model_id)
    llm_config = (await session.execute(stmt)).scalar_one_or_none()
    if not llm_config:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在")
    if not llm_config.is_active:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="该大模型配置未启用")
    if llm_config.is_embedding:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请选择对话型大模型,不能使用嵌入模型")

    # 2. If KBs are linked, validate each one and retrieve similar chunks; results
    #    from all KBs are merged and re-ranked by score, keeping the global top_k.
    knowledge_base_used = False
    references = None
    final_prompt = prompt_text
    first_kb_id_used: int | None = None
    if kb_ids:
        doc_processor = await get_document_processor(session)
        all_results: list = []
        for kb_id in kb_ids:
            kb_stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
            kb = (await session.execute(kb_stmt)).scalar_one_or_none()
            if not kb:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"知识库不存在: id={kb_id}")
            if not kb.is_active:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"该知识库未启用: id={kb_id}")
            part = doc_processor.search_similar_documents(
                knowledge_base_id=kb_id,
                query=prompt_text,
                k=body.top_k,
            )
            all_results.extend(part)

        def _score(r):
            # Prefer the normalized score; fall back to the raw similarity score.
            return float(r.get("normalized_score") or r.get("similarity_score") or 0)

        all_results.sort(key=_score, reverse=True)
        results = all_results[: body.top_k]
        max_score = _score(results[0]) if results else 0.0
        # Inject RAG context only when the best hit clears the relevance threshold;
        # otherwise fall back to the plain prompt.
        if results and max_score >= 0.45:
            knowledge_base_used = True
            first_kb_id_used = kb_ids[0]
            refs = []
            for i, r in enumerate(results[:5], 1):
                content = (r.get("content") or "").strip()
                if content:
                    if len(content) > 1000:
                        content = content[:1000] + "..."
                    refs.append({"index": i, "content": content, "score": r.get("normalized_score")})
            references = refs
            context = "\n\n".join([f"【参考文档{ref['index']}\n{ref['content']}" for ref in refs])
            final_prompt = f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。
{context}
【用户问题】
{prompt_text}
【重要提示】
- 参考文档中包含了与用户问题相关的信息
- 请仔细阅读参考文档,提取相关信息来回答用户的问题
- 即使文档没有直接定义,也要基于文档中的相关内容进行解释和说明
- 如果文档中提到了相关概念、政策、法规等,请基于这些内容进行回答
- 回答要准确、详细、有条理,尽量引用文档中的具体内容"""
            logger.info(f"agentChat 使用 RAG知识库 {kb_ids},检索 {len(results)} 条,最高相似度 {max_score:.3f}")
        else:
            logger.info(f"agentChat知识库 {kb_ids} 检索结果相似度较低(最高 {max_score:.3f}),仅用提示词")

    # 3. Call the LLM; request overrides win over the config's defaults.
    llm_service = LLMService()
    try:
        response_text = await llm_service.chat_completion(
            model_config=llm_config,
            messages=[{"role": "user", "content": final_prompt}],
            temperature=body.temperature if body.temperature is not None else llm_config.temperature,
            max_tokens=body.max_tokens if body.max_tokens is not None else llm_config.max_tokens,
        )
    except Exception as e:
        logger.error(f"agentChat LLM 调用失败: {e}")
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"大模型调用失败: {str(e)}")

    session.desc = f"SUCCESS: agentChat 完毕 model_id={body.model_id}, model_name={llm_config.model_name}"
    return HxfResponse(
        AgentChatResponse(
            response=response_text,
            model_id=llm_config.id,
            model_name=llm_config.model_name,
            knowledge_base_id=first_kb_id_used if knowledge_base_used else None,
            knowledge_base_used=knowledge_base_used,
            references=references,
        )
    )
@router.post(
    "/stream",
    summary="agentChat按大模型、提示词、关联知识库流式输出结果",
)
@router.post(
    # Fix: the alias route was registered as "stream" (no leading slash), which
    # FastAPI concatenates onto the router prefix as "<prefix>stream". Use a
    # trailing-slash alias instead, mirroring the "/" + "" pair above.
    "/stream/",
    include_in_schema=False,
)
async def agent_chat_stream(
    body: AgentChatRequest,
    current_user: User = Depends(require_authenticated_user),  # auth only; value unused
    session: Session = Depends(get_session),
):
    """Streaming variant of agentChat.

    Same validation and RAG pipeline as `agent_chat`, but the LLM output is
    streamed back as plain text chunks. LLM errors raised mid-stream are
    emitted into the stream as an "[ERROR] ..." line (the HTTP status is
    already 200 by then).

    Raises (before streaming starts):
    - 400: empty prompt, non-numeric knowledge_base_ids, inactive model/KB,
      or an embedding model selected.
    - 404: unknown model config or knowledge base.
    """
    prompt_text = (body.prompt or body.message or "").strip()
    # Fix: reject an empty prompt up front instead of sending a blank message to
    # the LLM (and using "" as the vector-search query).
    if not prompt_text:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="提示词不能为空")

    # Resolve the KB ID list: knowledge_base_ids takes precedence over knowledge_base_id.
    kb_ids: list[int] = []
    if body.knowledge_base_ids:
        try:
            kb_ids = [
                int(x)
                for x in body.knowledge_base_ids
                if x is not None and str(x).strip() != ""
            ]
        except (ValueError, TypeError):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="knowledge_base_ids 须为数字或数字字符串",
            )
        # Keep only positive IDs; fix: de-duplicate (order-preserving) so the same
        # KB is not searched twice and its chunks double-counted in the ranking.
        kb_ids = list(dict.fromkeys(i for i in kb_ids if i >= 1))
    elif body.knowledge_base_id is not None and body.knowledge_base_id >= 1:
        kb_ids = [body.knowledge_base_id]

    # NOTE(review): title/desc on the session look like audit/log metadata consumed
    # elsewhere (HxfResponse layer?) — confirm the convention before removing.
    session.title = "agentChatStream"
    session.desc = (
        f"START: agentChat/stream model_id={body.model_id}, "
        f"prompt_len={len(prompt_text)}, knowledge_base_ids={kb_ids}"
    )

    # 1. Validate and load the LLM configuration.
    stmt = select(LLMConfig).where(LLMConfig.id == body.model_id)
    llm_config = (await session.execute(stmt)).scalar_one_or_none()
    if not llm_config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在"
        )
    if not llm_config.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="该大模型配置未启用"
        )
    if llm_config.is_embedding:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="请选择对话型大模型,不能使用嵌入模型",
        )

    # 2. If KBs are linked, validate each one and retrieve similar chunks; results
    #    from all KBs are merged and re-ranked by score, keeping the global top_k.
    knowledge_base_used = False
    references = None
    final_prompt = prompt_text
    if kb_ids:
        doc_processor = await get_document_processor(session)
        all_results: list = []
        for kb_id in kb_ids:
            kb_stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
            kb = (await session.execute(kb_stmt)).scalar_one_or_none()
            if not kb:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"知识库不存在: id={kb_id}",
                )
            if not kb.is_active:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"该知识库未启用: id={kb_id}",
                )
            part = doc_processor.search_similar_documents(
                knowledge_base_id=kb_id,
                query=prompt_text,
                k=body.top_k,
            )
            all_results.extend(part)

        def _score(r):
            # Prefer the normalized score; fall back to the raw similarity score.
            return float(r.get("normalized_score") or r.get("similarity_score") or 0)

        all_results.sort(key=_score, reverse=True)
        results = all_results[: body.top_k]
        max_score = _score(results[0]) if results else 0.0
        # Inject RAG context only when the best hit clears the relevance threshold;
        # otherwise fall back to the plain prompt.
        if results and max_score >= 0.45:
            knowledge_base_used = True
            refs = []
            for i, r in enumerate(results[:5], 1):
                content = (r.get("content") or "").strip()
                if content:
                    if len(content) > 1000:
                        content = content[:1000] + "..."
                    refs.append(
                        {
                            "index": i,
                            "content": content,
                            "score": r.get("normalized_score"),
                        }
                    )
            references = refs
            context = "\n\n".join(
                [f"【参考文档{ref['index']}\n{ref['content']}" for ref in refs]
            )
            final_prompt = f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。
{context}
【用户问题】
{prompt_text}
【重要提示】
- 参考文档中包含了与用户问题相关的信息
- 请仔细阅读参考文档,提取相关信息来回答用户的问题
- 即使文档没有直接定义,也要基于文档中的相关内容进行解释和说明
- 如果文档中提到了相关概念、政策、法规等,请基于这些内容进行回答
- 回答要准确、详细、有条理,尽量引用文档中的具体内容"""
            logger.info(
                f"agentChat/stream 使用 RAG知识库 {kb_ids},检索 {len(results)} 条,最高相似度 {max_score:.3f}"
            )
        else:
            logger.info(
                f"agentChat/stream知识库 {kb_ids} 检索结果相似度较低(最高 {max_score:.3f}),仅用提示词"
            )

    # 3. Call the LLM (streaming); request overrides win over the config's defaults.
    llm_service = LLMService()

    async def generate():
        try:
            async for chunk in llm_service.chat_completion_stream(
                model_config=llm_config,
                messages=[{"role": "user", "content": final_prompt}],
                temperature=body.temperature
                if body.temperature is not None
                else llm_config.temperature,
                max_tokens=body.max_tokens
                if body.max_tokens is not None
                else llm_config.max_tokens,
            ):
                if not chunk:
                    continue
                # Like /chat/stream: emit raw text content, no SSE framing.
                yield chunk
        except Exception as e:
            logger.error(f"agentChat/stream LLM 流式调用失败: {e}")
            # Push the error into the stream so the frontend can display it.
            yield f"[ERROR] 大模型调用失败: {str(e)}"

    return StreamingResponse(
        generate(),
        # NOTE(review): "text/stream" is not a registered MIME type — the frontend
        # may be matching on it; "text/plain" or "text/event-stream" would be
        # standard. Confirm with the client before changing.
        media_type="text/stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )