"""agentChat 接口:根据 AI 大模型、提示词、关联知识库输出结果。""" from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.orm import Session from loguru import logger from ...db.database import get_session from ...models.user import User from ...models.llm_config import LLMConfig from ...models.knowledge_base import KnowledgeBase from ...services.llm_service import LLMService from ...services.document_processor import get_document_processor from ...core.simple_permissions import require_authenticated_user from utils.util_exceptions import HxfResponse from utils.util_schemas import AgentChatRequest, AgentChatResponse router = APIRouter() @router.post("/", response_model=AgentChatResponse, summary="agentChat:按大模型、提示词、关联知识库输出结果") @router.post("", response_model=AgentChatResponse, include_in_schema=False) async def agent_chat( body: AgentChatRequest, current_user: User = Depends(require_authenticated_user), session: Session = Depends(get_session), ): """ 根据选择的大模型、关联的知识库和提示词,返回模型输出。 - model_id: AI 大模型配置 ID(需为对话型,非嵌入型) - prompt 或 message: 提示词(二选一) - knowledge_base_id 或 knowledge_base_ids: 关联知识库(单个或列表,如 [1, 2] 或 ['3']) """ prompt_text = (body.prompt or body.message or "").strip() # 解析知识库 ID 列表:优先 knowledge_base_ids,否则 [knowledge_base_id] kb_ids: list[int] = [] if body.knowledge_base_ids: try: kb_ids = [int(x) for x in body.knowledge_base_ids if x is not None and str(x).strip() != ""] kb_ids = [i for i in kb_ids if i >= 1] except (ValueError, TypeError): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="knowledge_base_ids 须为数字或数字字符串") elif body.knowledge_base_id is not None and body.knowledge_base_id >= 1: kb_ids = [body.knowledge_base_id] session.title = "agentChat" session.desc = f"START: agentChat model_id={body.model_id}, prompt_len={len(prompt_text)}, knowledge_base_ids={kb_ids}" # 1. 校验并获取大模型配置 stmt = select(LLMConfig).where(LLMConfig.id == body.model_id) llm_config = (await session.execute(stmt)).scalar_one_or_none() if not llm_config: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在") if not llm_config.is_active: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="该大模型配置未启用") if llm_config.is_embedding: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请选择对话型大模型,不能使用嵌入模型") # 2. 若指定知识库,校验并检索(支持多知识库,结果合并后按相似度取 top_k) knowledge_base_used = False references = None final_prompt = prompt_text first_kb_id_used: int | None = None if kb_ids: doc_processor = await get_document_processor(session) all_results: list = [] for kb_id in kb_ids: kb_stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id) kb = (await session.execute(kb_stmt)).scalar_one_or_none() if not kb: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"知识库不存在: id={kb_id}") if not kb.is_active: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"该知识库未启用: id={kb_id}") part = doc_processor.search_similar_documents( knowledge_base_id=kb_id, query=prompt_text, k=body.top_k, ) all_results.extend(part) def _score(r): return float(r.get("normalized_score") or r.get("similarity_score") or 0) all_results.sort(key=_score, reverse=True) results = all_results[: body.top_k] max_score = _score(results[0]) if results else 0.0 if results and max_score >= 0.5: knowledge_base_used = True first_kb_id_used = kb_ids[0] refs = [] for i, r in enumerate(results[:5], 1): content = (r.get("content") or "").strip() if content: if len(content) > 1000: content = content[:1000] + "..." refs.append({"index": i, "content": content, "score": r.get("normalized_score")}) references = refs context = "\n\n".join([f"【参考文档{ref['index']}】\n{ref['content']}" for ref in refs]) final_prompt = f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。 {context} 【用户问题】 {prompt_text} 【重要提示】 - 参考文档中包含了与用户问题相关的信息 - 请仔细阅读参考文档,提取相关信息来回答用户的问题 - 即使文档没有直接定义,也要基于文档中的相关内容进行解释和说明 - 如果文档中提到了相关概念、政策、法规等,请基于这些内容进行回答 - 回答要准确、详细、有条理,尽量引用文档中的具体内容""" logger.info(f"agentChat 使用 RAG,知识库 {kb_ids},检索 {len(results)} 条,最高相似度 {max_score:.3f}") else: logger.info(f"agentChat:知识库 {kb_ids} 检索结果相似度较低(最高 {max_score:.3f}),仅用提示词") # 3. 调用大模型 llm_service = LLMService() try: response_text = await llm_service.chat_completion( model_config=llm_config, messages=[{"role": "user", "content": final_prompt}], temperature=body.temperature if body.temperature is not None else llm_config.temperature, max_tokens=body.max_tokens if body.max_tokens is not None else llm_config.max_tokens, ) except Exception as e: logger.error(f"agentChat LLM 调用失败: {e}") raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"大模型调用失败: {str(e)}") session.desc = f"SUCCESS: agentChat 完毕 model_id={body.model_id}, model_name={llm_config.model_name}" return HxfResponse( AgentChatResponse( response=response_text, model_id=llm_config.id, model_name=llm_config.model_name, knowledge_base_id=first_kb_id_used if knowledge_base_used else None, knowledge_base_used=knowledge_base_used, references=references, ) ) @router.post( "/stream", summary="agentChat:按大模型、提示词、关联知识库流式输出结果", ) @router.post( "stream", include_in_schema=False, ) async def agent_chat_stream( body: AgentChatRequest, current_user: User = Depends(require_authenticated_user), session: Session = Depends(get_session), ): """ agentChat 流式接口。 根据选择的大模型、关联的知识库和提示词,实时流式返回模型输出文本。 """ prompt_text = (body.prompt or body.message or "").strip() # 解析知识库 ID 列表:优先 knowledge_base_ids,否则 [knowledge_base_id] kb_ids: list[int] = [] if body.knowledge_base_ids: try: kb_ids = [ int(x) for x in body.knowledge_base_ids if x is not None and str(x).strip() != "" ] kb_ids = [i for i in kb_ids if i >= 1] except (ValueError, TypeError): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="knowledge_base_ids 须为数字或数字字符串", ) elif body.knowledge_base_id is not None and body.knowledge_base_id >= 1: kb_ids = [body.knowledge_base_id] session.title = "agentChatStream" session.desc = ( f"START: agentChat/stream model_id={body.model_id}, " f"prompt_len={len(prompt_text)}, knowledge_base_ids={kb_ids}" ) # 1. 校验并获取大模型配置 stmt = select(LLMConfig).where(LLMConfig.id == body.model_id) llm_config = (await session.execute(stmt)).scalar_one_or_none() if not llm_config: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在" ) if not llm_config.is_active: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="该大模型配置未启用" ) if llm_config.is_embedding: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="请选择对话型大模型,不能使用嵌入模型", ) # 2. 若指定知识库,校验并检索(支持多知识库,结果合并后按相似度取 top_k) knowledge_base_used = False references = None final_prompt = prompt_text if kb_ids: doc_processor = await get_document_processor(session) all_results: list = [] for kb_id in kb_ids: kb_stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id) kb = (await session.execute(kb_stmt)).scalar_one_or_none() if not kb: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"知识库不存在: id={kb_id}", ) if not kb.is_active: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"该知识库未启用: id={kb_id}", ) part = doc_processor.search_similar_documents( knowledge_base_id=kb_id, query=prompt_text, k=body.top_k, ) all_results.extend(part) def _score(r): return float(r.get("normalized_score") or r.get("similarity_score") or 0) all_results.sort(key=_score, reverse=True) results = all_results[: body.top_k] max_score = _score(results[0]) if results else 0.0 if results and max_score >= 0.5: knowledge_base_used = True refs = [] for i, r in enumerate(results[:5], 1): content = (r.get("content") or "").strip() if content: if len(content) > 1000: content = content[:1000] + "..." refs.append( { "index": i, "content": content, "score": r.get("normalized_score"), } ) references = refs context = "\n\n".join( [f"【参考文档{ref['index']}】\n{ref['content']}" for ref in refs] ) final_prompt = f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。 {context} 【用户问题】 {prompt_text} 【重要提示】 - 参考文档中包含了与用户问题相关的信息 - 请仔细阅读参考文档,提取相关信息来回答用户的问题 - 即使文档没有直接定义,也要基于文档中的相关内容进行解释和说明 - 如果文档中提到了相关概念、政策、法规等,请基于这些内容进行回答 - 回答要准确、详细、有条理,尽量引用文档中的具体内容""" logger.info( f"agentChat/stream 使用 RAG,知识库 {kb_ids},检索 {len(results)} 条,最高相似度 {max_score:.3f}" ) else: logger.info( f"agentChat/stream:知识库 {kb_ids} 检索结果相似度较低(最高 {max_score:.3f}),仅用提示词" ) # 3. 调用大模型(流式) llm_service = LLMService() async def generate(): try: async for chunk in llm_service.chat_completion_stream( model_config=llm_config, messages=[{"role": "user", "content": final_prompt}], temperature=body.temperature if body.temperature is not None else llm_config.temperature, max_tokens=body.max_tokens if body.max_tokens is not None else llm_config.max_tokens, ): if not chunk: continue # 和 /chat/stream 一样,直接输出文本内容 yield chunk except Exception as e: logger.error(f"agentChat/stream LLM 流式调用失败: {e}") # 将错误信息也推到流里,方便前端展示 yield f"[ERROR] 大模型调用失败: {str(e)}" return StreamingResponse( generate(), media_type="text/stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", }, )