2026-01-26 11:18:51 +08:00
|
|
|
|
"""agentChat 接口:根据 AI 大模型、提示词、关联知识库输出结果。"""
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
2026-01-28 16:08:20 +08:00
|
|
|
|
from fastapi.responses import StreamingResponse
|
2026-01-26 11:18:51 +08:00
|
|
|
|
from sqlalchemy import select
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
from ...db.database import get_session
|
|
|
|
|
|
from ...models.user import User
|
|
|
|
|
|
from ...models.llm_config import LLMConfig
|
|
|
|
|
|
from ...models.knowledge_base import KnowledgeBase
|
|
|
|
|
|
from ...services.llm_service import LLMService
|
|
|
|
|
|
from ...services.document_processor import get_document_processor
|
|
|
|
|
|
from ...core.simple_permissions import require_authenticated_user
|
|
|
|
|
|
from utils.util_exceptions import HxfResponse
|
|
|
|
|
|
from utils.util_schemas import AgentChatRequest, AgentChatResponse
|
|
|
|
|
|
|
|
|
|
|
|
# Router for the agentChat endpoints; the URL prefix is applied by the parent
# package at include_router() time.
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/", response_model=AgentChatResponse, summary="agentChat:按大模型、提示词、关联知识库输出结果")
@router.post("", response_model=AgentChatResponse, include_in_schema=False)
async def agent_chat(
    body: AgentChatRequest,
    current_user: User = Depends(require_authenticated_user),
    session: Session = Depends(get_session),
):
    """Run a single (non-streaming) agent chat turn.

    Produces the model output for the selected LLM, the optional linked
    knowledge base(s) and the prompt.

    - model_id: LLM config id (must be a chat model, not an embedding model)
    - prompt or message: prompt text (either field may be used)
    - knowledge_base_id or knowledge_base_ids: linked knowledge base(s),
      a single id or a list such as [1, 2] or ['3']

    Raises:
        HTTPException: 400 for invalid ids / inactive or embedding-type model
            or inactive knowledge base; 404 for a missing model or knowledge
            base; 502 when the LLM call itself fails.
    """
    # Minimum top-1 similarity required before retrieved chunks are injected
    # into the prompt (RAG); below this, the plain prompt is sent unchanged.
    RAG_SCORE_THRESHOLD = 0.45

    prompt_text = (body.prompt or body.message or "").strip()

    # Resolve the knowledge-base id list: knowledge_base_ids takes precedence,
    # otherwise fall back to the single knowledge_base_id. Non-positive ids
    # are silently dropped.
    kb_ids: list[int] = []
    if body.knowledge_base_ids:
        try:
            kb_ids = [int(x) for x in body.knowledge_base_ids if x is not None and str(x).strip() != ""]
            kb_ids = [i for i in kb_ids if i >= 1]
        except (ValueError, TypeError):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="knowledge_base_ids 须为数字或数字字符串")
    elif body.knowledge_base_id is not None and body.knowledge_base_id >= 1:
        kb_ids = [body.knowledge_base_id]

    # NOTE(review): title/desc are set on the session object — presumably a
    # project audit/logging convention supported by get_session; confirm.
    session.title = "agentChat"
    session.desc = f"START: agentChat model_id={body.model_id}, prompt_len={len(prompt_text)}, knowledge_base_ids={kb_ids}"

    # 1. Validate and load the LLM configuration.
    #    NOTE(review): session.execute is awaited, so get_session apparently
    #    yields an async-capable session despite the sync `Session` annotation
    #    — TODO confirm and fix the annotation at the dependency level.
    stmt = select(LLMConfig).where(LLMConfig.id == body.model_id)
    llm_config = (await session.execute(stmt)).scalar_one_or_none()
    if not llm_config:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在")
    if not llm_config.is_active:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="该大模型配置未启用")
    if llm_config.is_embedding:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请选择对话型大模型,不能使用嵌入模型")

    # 2. If knowledge bases were given, validate each one and retrieve similar
    #    chunks. Results from all bases are merged, then the global top_k by
    #    similarity is kept.
    knowledge_base_used = False
    references = None
    final_prompt = prompt_text
    first_kb_id_used: int | None = None

    if kb_ids:
        doc_processor = await get_document_processor(session)
        all_results: list = []

        for kb_id in kb_ids:
            kb_stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
            kb = (await session.execute(kb_stmt)).scalar_one_or_none()
            if not kb:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"知识库不存在: id={kb_id}")
            if not kb.is_active:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"该知识库未启用: id={kb_id}")

            part = doc_processor.search_similar_documents(
                knowledge_base_id=kb_id,
                query=prompt_text,
                k=body.top_k,
            )
            all_results.extend(part)

        def _score(r):
            # Prefer the normalized score; fall back to the raw similarity.
            return float(r.get("normalized_score") or r.get("similarity_score") or 0)

        all_results.sort(key=_score, reverse=True)
        results = all_results[: body.top_k]
        max_score = _score(results[0]) if results else 0.0

        if results and max_score >= RAG_SCORE_THRESHOLD:
            # Keep at most 5 reference snippets, each truncated to 1000 chars.
            refs = []
            for i, r in enumerate(results[:5], 1):
                content = (r.get("content") or "").strip()
                if content:
                    if len(content) > 1000:
                        content = content[:1000] + "..."
                    refs.append({"index": i, "content": content, "score": r.get("normalized_score")})

            if refs:
                # Fix: only claim the knowledge base was used when at least one
                # retrieved chunk has usable text; previously an all-empty
                # result set still produced a RAG prompt with an empty
                # reference section.
                knowledge_base_used = True
                first_kb_id_used = kb_ids[0]
                references = refs

                context = "\n\n".join([f"【参考文档{ref['index']}】\n{ref['content']}" for ref in refs])
                final_prompt = f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。

{context}

【用户问题】
{prompt_text}

【重要提示】
- 参考文档中包含了与用户问题相关的信息
- 请仔细阅读参考文档,提取相关信息来回答用户的问题
- 即使文档没有直接定义,也要基于文档中的相关内容进行解释和说明
- 如果文档中提到了相关概念、政策、法规等,请基于这些内容进行回答
- 回答要准确、详细、有条理,尽量引用文档中的具体内容"""
                logger.info(f"agentChat 使用 RAG,知识库 {kb_ids},检索 {len(results)} 条,最高相似度 {max_score:.3f}")
            else:
                logger.info(f"agentChat:知识库 {kb_ids} 检索到的文档内容为空,仅用提示词")
        else:
            logger.info(f"agentChat:知识库 {kb_ids} 检索结果相似度较低(最高 {max_score:.3f}),仅用提示词")

    # 3. Call the LLM; per-request temperature/max_tokens override the values
    #    stored in the model configuration.
    llm_service = LLMService()
    try:
        response_text = await llm_service.chat_completion(
            model_config=llm_config,
            messages=[{"role": "user", "content": final_prompt}],
            temperature=body.temperature if body.temperature is not None else llm_config.temperature,
            max_tokens=body.max_tokens if body.max_tokens is not None else llm_config.max_tokens,
        )
    except Exception as e:
        logger.error(f"agentChat LLM 调用失败: {e}")
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"大模型调用失败: {str(e)}")

    session.desc = f"SUCCESS: agentChat 完毕 model_id={body.model_id}, model_name={llm_config.model_name}"

    return HxfResponse(
        AgentChatResponse(
            response=response_text,
            model_id=llm_config.id,
            model_name=llm_config.model_name,
            knowledge_base_id=first_kb_id_used if knowledge_base_used else None,
            knowledge_base_used=knowledge_base_used,
            references=references,
        )
    )
|
2026-01-28 16:08:20 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post(
    "/stream",
    summary="agentChat:按大模型、提示词、关联知识库流式输出结果",
)
# NOTE(review): this second path has no leading slash — combined with a router
# prefix it registers e.g. "/agentChatstream", which looks unintended (compare
# the "" alias on the non-stream endpoint). Confirm against the frontend
# before removing or changing it.
@router.post(
    "stream",
    include_in_schema=False,
)
async def agent_chat_stream(
    body: AgentChatRequest,
    current_user: User = Depends(require_authenticated_user),
    session: Session = Depends(get_session),
):
    """Streaming variant of agentChat.

    Streams the model output text in real time for the selected LLM, the
    optional linked knowledge base(s) and the prompt. Retrieval/validation
    mirrors :func:`agent_chat`; errors raised before the stream starts become
    HTTP errors, while LLM failures mid-stream are pushed into the stream as
    an ``[ERROR] ...`` line so the client can display them.

    Raises:
        HTTPException: 400 for invalid ids / inactive or embedding-type model
            or inactive knowledge base; 404 for a missing model or knowledge
            base.
    """
    # Minimum top-1 similarity required before retrieved chunks are injected
    # into the prompt (RAG); below this, the plain prompt is sent unchanged.
    RAG_SCORE_THRESHOLD = 0.45

    prompt_text = (body.prompt or body.message or "").strip()

    # Resolve the knowledge-base id list: knowledge_base_ids takes precedence,
    # otherwise fall back to the single knowledge_base_id. Non-positive ids
    # are silently dropped.
    kb_ids: list[int] = []
    if body.knowledge_base_ids:
        try:
            kb_ids = [
                int(x)
                for x in body.knowledge_base_ids
                if x is not None and str(x).strip() != ""
            ]
            kb_ids = [i for i in kb_ids if i >= 1]
        except (ValueError, TypeError):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="knowledge_base_ids 须为数字或数字字符串",
            )
    elif body.knowledge_base_id is not None and body.knowledge_base_id >= 1:
        kb_ids = [body.knowledge_base_id]

    # NOTE(review): title/desc are set on the session object — presumably a
    # project audit/logging convention supported by get_session; confirm.
    session.title = "agentChatStream"
    session.desc = (
        f"START: agentChat/stream model_id={body.model_id}, "
        f"prompt_len={len(prompt_text)}, knowledge_base_ids={kb_ids}"
    )

    # 1. Validate and load the LLM configuration.
    #    NOTE(review): session.execute is awaited, so get_session apparently
    #    yields an async-capable session despite the sync `Session` annotation
    #    — TODO confirm and fix the annotation at the dependency level.
    stmt = select(LLMConfig).where(LLMConfig.id == body.model_id)
    llm_config = (await session.execute(stmt)).scalar_one_or_none()
    if not llm_config:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在"
        )
    if not llm_config.is_active:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="该大模型配置未启用"
        )
    if llm_config.is_embedding:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="请选择对话型大模型,不能使用嵌入模型",
        )

    # 2. If knowledge bases were given, validate each one and retrieve similar
    #    chunks. Results from all bases are merged, then the global top_k by
    #    similarity is kept.
    knowledge_base_used = False
    references = None
    final_prompt = prompt_text

    if kb_ids:
        doc_processor = await get_document_processor(session)
        all_results: list = []

        for kb_id in kb_ids:
            kb_stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
            kb = (await session.execute(kb_stmt)).scalar_one_or_none()
            if not kb:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"知识库不存在: id={kb_id}",
                )
            if not kb.is_active:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"该知识库未启用: id={kb_id}",
                )

            part = doc_processor.search_similar_documents(
                knowledge_base_id=kb_id,
                query=prompt_text,
                k=body.top_k,
            )
            all_results.extend(part)

        def _score(r):
            # Prefer the normalized score; fall back to the raw similarity.
            return float(r.get("normalized_score") or r.get("similarity_score") or 0)

        all_results.sort(key=_score, reverse=True)
        results = all_results[: body.top_k]
        max_score = _score(results[0]) if results else 0.0

        if results and max_score >= RAG_SCORE_THRESHOLD:
            # Keep at most 5 reference snippets, each truncated to 1000 chars.
            refs = []
            for i, r in enumerate(results[:5], 1):
                content = (r.get("content") or "").strip()
                if content:
                    if len(content) > 1000:
                        content = content[:1000] + "..."
                    refs.append(
                        {
                            "index": i,
                            "content": content,
                            "score": r.get("normalized_score"),
                        }
                    )

            if refs:
                # Fix: only claim the knowledge base was used when at least one
                # retrieved chunk has usable text; previously an all-empty
                # result set still produced a RAG prompt with an empty
                # reference section.
                knowledge_base_used = True
                references = refs

                context = "\n\n".join(
                    [f"【参考文档{ref['index']}】\n{ref['content']}" for ref in refs]
                )
                final_prompt = f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。

{context}

【用户问题】
{prompt_text}

【重要提示】
- 参考文档中包含了与用户问题相关的信息
- 请仔细阅读参考文档,提取相关信息来回答用户的问题
- 即使文档没有直接定义,也要基于文档中的相关内容进行解释和说明
- 如果文档中提到了相关概念、政策、法规等,请基于这些内容进行回答
- 回答要准确、详细、有条理,尽量引用文档中的具体内容"""
                logger.info(
                    f"agentChat/stream 使用 RAG,知识库 {kb_ids},检索 {len(results)} 条,最高相似度 {max_score:.3f}"
                )
            else:
                logger.info(
                    f"agentChat/stream:知识库 {kb_ids} 检索到的文档内容为空,仅用提示词"
                )
        else:
            logger.info(
                f"agentChat/stream:知识库 {kb_ids} 检索结果相似度较低(最高 {max_score:.3f}),仅用提示词"
            )

    # 3. Call the LLM (streaming); per-request temperature/max_tokens override
    #    the values stored in the model configuration.
    llm_service = LLMService()

    async def generate():
        """Yield raw text chunks from the LLM; push failures into the stream."""
        try:
            async for chunk in llm_service.chat_completion_stream(
                model_config=llm_config,
                messages=[{"role": "user", "content": final_prompt}],
                temperature=body.temperature
                if body.temperature is not None
                else llm_config.temperature,
                max_tokens=body.max_tokens
                if body.max_tokens is not None
                else llm_config.max_tokens,
            ):
                if not chunk:
                    continue
                # Same as /chat/stream: emit the text content directly.
                yield chunk
        except Exception as e:
            logger.error(f"agentChat/stream LLM 流式调用失败: {e}")
            # Push the error text into the stream so the frontend can show it.
            yield f"[ERROR] 大模型调用失败: {str(e)}"

    # NOTE(review): "text/stream" is not a registered MIME type — consider
    # "text/plain" (raw chunks) or "text/event-stream" (SSE); confirm what the
    # frontend expects before changing, as it may match this exact value.
    return StreamingResponse(
        generate(),
        media_type="text/stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
        },
    )
|
|
|
|
|
|
|