# Retrieval hits whose best score is below this are ignored; the raw prompt is sent instead.
_MIN_RAG_SCORE = 0.5
# Upper bound on how many retrieved chunks are quoted in the prompt / echoed as references.
_MAX_REFERENCES = 5
# Each quoted chunk is cut to this many characters, with "..." appended when truncated.
_MAX_REFERENCE_CHARS = 1000


def _parse_kb_ids(body: AgentChatRequest) -> list[int]:
    """Return the requested knowledge-base IDs as unique positive ints, in request order.

    ``knowledge_base_ids`` takes precedence over ``knowledge_base_id``; when the
    list is present but filters down to nothing, the single ID is NOT used as a
    fallback (same precedence as before).

    Raises:
        HTTPException: 400 when an entry of ``knowledge_base_ids`` cannot be
            converted with ``int()``.
    """
    if body.knowledge_base_ids:
        try:
            parsed = [
                int(raw)
                for raw in body.knowledge_base_ids
                if raw is not None and str(raw).strip() != ""
            ]
        except (ValueError, TypeError) as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="knowledge_base_ids 须为数字或数字字符串",
            ) from exc
        # dict.fromkeys drops repeated IDs while keeping first-seen order, so a
        # request like [1, 1] no longer searches the same knowledge base twice.
        return list(dict.fromkeys(kb_id for kb_id in parsed if kb_id >= 1))
    if body.knowledge_base_id is not None and body.knowledge_base_id >= 1:
        return [body.knowledge_base_id]
    return []


def _result_score(result) -> float:
    """Return the ranking score of one search hit (normalized score, else similarity, else 0)."""
    return float(result.get("normalized_score") or result.get("similarity_score") or 0)


def _build_references(results: list) -> list[dict]:
    """Turn the top search hits into reference dicts, skipping hits with empty content.

    ``index`` is the 1-based rank of the hit among ``results``, so indices may
    have gaps when an empty hit is skipped.
    """
    refs = []
    for index, hit in enumerate(results[:_MAX_REFERENCES], 1):
        content = (hit.get("content") or "").strip()
        if not content:
            continue
        if len(content) > _MAX_REFERENCE_CHARS:
            content = content[:_MAX_REFERENCE_CHARS] + "..."
        refs.append({"index": index, "content": content, "score": hit.get("normalized_score")})
    return refs


def _build_rag_prompt(prompt_text: str, refs: list[dict]) -> str:
    """Wrap the user's prompt and the reference chunks in the RAG instruction template."""
    context = "\n\n".join([f"【参考文档{ref['index']}】\n{ref['content']}" for ref in refs])
    return f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。

{context}

【用户问题】
{prompt_text}

【重要提示】
- 参考文档中包含了与用户问题相关的信息
- 请仔细阅读参考文档,提取相关信息来回答用户的问题
- 即使文档没有直接定义,也要基于文档中的相关内容进行解释和说明
- 如果文档中提到了相关概念、政策、法规等,请基于这些内容进行回答
- 回答要准确、详细、有条理,尽量引用文档中的具体内容"""


@router.post("/", response_model=AgentChatResponse, summary="agentChat:按大模型、提示词、关联知识库输出结果")
@router.post("", response_model=AgentChatResponse, include_in_schema=False)
async def agent_chat(
    body: AgentChatRequest,
    current_user: User = Depends(require_authenticated_user),
    session: Session = Depends(get_session),
):
    """
    Return the model output for the selected LLM, prompt and linked knowledge bases.

    - model_id: LLM configuration ID (must be a chat model, not an embedding model)
    - prompt or message: the prompt text (either one)
    - knowledge_base_id or knowledge_base_ids: linked knowledge base(s), a single
      ID or a list such as [1, 2] or ['3']

    When knowledge bases are given, each one is searched with the prompt, the hits
    are merged and ranked by score, and the best ones are quoted in the prompt if
    the top score reaches the relevance threshold; otherwise the prompt is sent
    to the model unchanged.

    Raises:
        HTTPException: 400 for non-numeric knowledge_base_ids, an inactive or
            embedding-type model, or an inactive knowledge base; 404 for an
            unknown model or knowledge base; 502 when the LLM call fails.
    """
    # Strip each candidate separately: a whitespace-only `prompt` must not hide a
    # real `message` (the request schema accepts that combination).
    prompt_text = (body.prompt or "").strip() or (body.message or "").strip()
    kb_ids = _parse_kb_ids(body)

    # NOTE(review): title/desc are project-specific attributes on the session
    # object, presumably read by request logging - confirm.
    session.title = "agentChat"
    session.desc = f"START: agentChat model_id={body.model_id}, prompt_len={len(prompt_text)}, knowledge_base_ids={kb_ids}"

    # 1. Load and validate the LLM configuration.
    stmt = select(LLMConfig).where(LLMConfig.id == body.model_id)
    llm_config = (await session.execute(stmt)).scalar_one_or_none()
    if not llm_config:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在")
    if not llm_config.is_active:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="该大模型配置未启用")
    if llm_config.is_embedding:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="请选择对话型大模型,不能使用嵌入模型")

    # 2. If knowledge bases were requested, validate each one, search it, then
    #    merge all hits and keep the top_k by score.
    knowledge_base_used = False
    references = None
    final_prompt = prompt_text
    first_kb_id_used: int | None = None

    if kb_ids:
        doc_processor = await get_document_processor(session)
        all_results: list = []

        for kb_id in kb_ids:
            kb_stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id)
            kb = (await session.execute(kb_stmt)).scalar_one_or_none()
            if not kb:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"知识库不存在: id={kb_id}")
            if not kb.is_active:
                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"该知识库未启用: id={kb_id}")

            part = doc_processor.search_similar_documents(
                knowledge_base_id=kb_id,
                query=prompt_text,
                k=body.top_k,
            )
            all_results.extend(part)

        all_results.sort(key=_result_score, reverse=True)
        results = all_results[: body.top_k]
        max_score = _result_score(results[0]) if results else 0.0

        # Only build references when the best hit is relevant enough. Requiring a
        # non-empty reference list keeps us from sending the RAG template with no
        # documents (and from reporting knowledge_base_used=True) when every top
        # hit has blank content.
        refs = _build_references(results) if max_score >= _MIN_RAG_SCORE else []

        if refs:
            knowledge_base_used = True
            # Reported ID is the first requested knowledge base, not necessarily
            # the one the best hit came from.
            first_kb_id_used = kb_ids[0]
            references = refs
            final_prompt = _build_rag_prompt(prompt_text, refs)
            logger.info(f"agentChat 使用 RAG,知识库 {kb_ids},检索 {len(results)} 条,最高相似度 {max_score:.3f}")
        else:
            logger.info(f"agentChat:知识库 {kb_ids} 检索结果相似度较低(最高 {max_score:.3f}),仅用提示词")

    # 3. Call the LLM; per-request temperature / max_tokens override the stored config.
    llm_service = LLMService()
    try:
        response_text = await llm_service.chat_completion(
            model_config=llm_config,
            messages=[{"role": "user", "content": final_prompt}],
            temperature=body.temperature if body.temperature is not None else llm_config.temperature,
            max_tokens=body.max_tokens if body.max_tokens is not None else llm_config.max_tokens,
        )
    except Exception as e:
        # Boundary handler: any upstream failure is reported as 502, with the
        # original exception chained for debugging.
        logger.error(f"agentChat LLM 调用失败: {e}")
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"大模型调用失败: {str(e)}") from e

    session.desc = f"SUCCESS: agentChat 完毕 model_id={body.model_id}, model_name={llm_config.model_name}"

    return HxfResponse(
        AgentChatResponse(
            response=response_text,
            model_id=llm_config.id,
            model_name=llm_config.model_name,
            knowledge_base_id=first_kb_id_used if knowledge_base_used else None,
            knowledge_base_used=knowledge_base_used,
            references=references,
        )
    )
class AgentChatRequest(BaseModel):
    """Request body for agentChat: which LLM to call, the prompt, and optional knowledge bases."""

    # ID of the LLM configuration to use.
    model_id: int = Field(..., ge=1, description="AI大模型配置ID")
    # The prompt may arrive under either name; at least one must be non-blank.
    prompt: Optional[str] = Field(
        default=None,
        max_length=20000,
        description="提示词,与 message 二选一",
    )
    message: Optional[str] = Field(
        default=None,
        max_length=20000,
        description="提示词(与 prompt 等价,二选一)",
    )
    # Knowledge bases can be given as a single ID or as a list of ints / numeric strings.
    knowledge_base_id: Optional[int] = Field(
        default=None,
        ge=1,
        description="关联知识库ID(单个),与 knowledge_base_ids 二选一",
    )
    knowledge_base_ids: Optional[List[Union[int, str]]] = Field(
        default=None,
        description="关联知识库ID列表,如 [1, 2] 或 ['3']",
    )
    # Number of retrieval hits to keep.
    top_k: int = Field(default=5, ge=1, le=20, description="知识库检索返回条数")
    # Optional per-request sampling overrides.
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
    max_tokens: Optional[int] = Field(default=None, ge=1, le=32768)

    @model_validator(mode="after")
    def require_prompt_or_message(self):
        """Reject the request unless `prompt` or `message` contains non-whitespace text."""
        has_prompt = bool((self.prompt or "").strip())
        has_message = bool((self.message or "").strip())
        if has_prompt or has_message:
            return self
        raise ValueError("prompt 或 message 至少提供一个")


class AgentChatResponse(BaseModel):
    """Response body for agentChat."""

    # Text produced by the model.
    response: str = Field(..., description="模型输出结果")
    # Which LLM configuration produced it.
    model_id: int = Field(..., description="使用的大模型配置ID")
    model_name: str = Field(..., description="使用的大模型名称")
    # Knowledge-base details; populated only when retrieval was actually used.
    knowledge_base_id: Optional[int] = Field(
        default=None,
        description="关联的知识库ID(若使用)",
    )
    knowledge_base_used: bool = Field(
        default=False,
        description="是否使用了知识库RAG",
    )
    references: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="引用的知识库片段(若使用RAG)",
    )