diff --git a/data/chroma/kb_1/add53ead-7f8c-45e1-9851-b11e93ad0dfb/data_level0.bin b/data/chroma/kb_1/add53ead-7f8c-45e1-9851-b11e93ad0dfb/data_level0.bin index c7ab88d..819795a 100644 Binary files a/data/chroma/kb_1/add53ead-7f8c-45e1-9851-b11e93ad0dfb/data_level0.bin and b/data/chroma/kb_1/add53ead-7f8c-45e1-9851-b11e93ad0dfb/data_level0.bin differ diff --git a/data/chroma/kb_1/chroma.sqlite3 b/data/chroma/kb_1/chroma.sqlite3 index 97dfd8f..1fa2a7d 100644 Binary files a/data/chroma/kb_1/chroma.sqlite3 and b/data/chroma/kb_1/chroma.sqlite3 differ diff --git a/th_agenter/api/endpoints/agent_chat.py b/th_agenter/api/endpoints/agent_chat.py index 4172d0f..3d14beb 100644 --- a/th_agenter/api/endpoints/agent_chat.py +++ b/th_agenter/api/endpoints/agent_chat.py @@ -1,6 +1,7 @@ """agentChat 接口:根据 AI 大模型、提示词、关联知识库输出结果。""" from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import StreamingResponse from sqlalchemy import select from sqlalchemy.orm import Session from loguru import logger @@ -143,3 +144,176 @@ async def agent_chat( references=references, ) ) + + +@router.post( + "/stream", + summary="agentChat:按大模型、提示词、关联知识库流式输出结果", +) +@router.post( + "stream", + include_in_schema=False, +) +async def agent_chat_stream( + body: AgentChatRequest, + current_user: User = Depends(require_authenticated_user), + session: Session = Depends(get_session), +): + """ + agentChat 流式接口。 + + 根据选择的大模型、关联的知识库和提示词,实时流式返回模型输出文本。 + """ + prompt_text = (body.prompt or body.message or "").strip() + # 解析知识库 ID 列表:优先 knowledge_base_ids,否则 [knowledge_base_id] + kb_ids: list[int] = [] + if body.knowledge_base_ids: + try: + kb_ids = [ + int(x) + for x in body.knowledge_base_ids + if x is not None and str(x).strip() != "" + ] + kb_ids = [i for i in kb_ids if i >= 1] + except (ValueError, TypeError): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="knowledge_base_ids 须为数字或数字字符串", + ) + elif body.knowledge_base_id is not None and 
body.knowledge_base_id >= 1: + kb_ids = [body.knowledge_base_id] + + session.title = "agentChatStream" + session.desc = ( + f"START: agentChat/stream model_id={body.model_id}, " + f"prompt_len={len(prompt_text)}, knowledge_base_ids={kb_ids}" + ) + + # 1. 校验并获取大模型配置 + stmt = select(LLMConfig).where(LLMConfig.id == body.model_id) + llm_config = (await session.execute(stmt)).scalar_one_or_none() + if not llm_config: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="大模型配置不存在" + ) + if not llm_config.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="该大模型配置未启用" + ) + if llm_config.is_embedding: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="请选择对话型大模型,不能使用嵌入模型", + ) + + # 2. 若指定知识库,校验并检索(支持多知识库,结果合并后按相似度取 top_k) + knowledge_base_used = False + references = None + final_prompt = prompt_text + + if kb_ids: + doc_processor = await get_document_processor(session) + all_results: list = [] + + for kb_id in kb_ids: + kb_stmt = select(KnowledgeBase).where(KnowledgeBase.id == kb_id) + kb = (await session.execute(kb_stmt)).scalar_one_or_none() + if not kb: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"知识库不存在: id={kb_id}", + ) + if not kb.is_active: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"该知识库未启用: id={kb_id}", + ) + + part = doc_processor.search_similar_documents( + knowledge_base_id=kb_id, + query=prompt_text, + k=body.top_k, + ) + all_results.extend(part) + + def _score(r): + return float(r.get("normalized_score") or r.get("similarity_score") or 0) + + all_results.sort(key=_score, reverse=True) + results = all_results[: body.top_k] + max_score = _score(results[0]) if results else 0.0 + + if results and max_score >= 0.5: + knowledge_base_used = True + refs = [] + for i, r in enumerate(results[:5], 1): + content = (r.get("content") or "").strip() + if content: + if len(content) > 1000: + content = content[:1000] + "..." 
+ refs.append( + { + "index": i, + "content": content, + "score": r.get("normalized_score"), + } + ) + references = refs + + context = "\n\n".join( + [f"【参考文档{ref['index']}】\n{ref['content']}" for ref in refs] + ) + final_prompt = f"""你是一个专业的助手。请仔细阅读以下参考文档,然后回答用户的问题。 + +{context} + +【用户问题】 +{prompt_text} + +【重要提示】 +- 参考文档中包含了与用户问题相关的信息 +- 请仔细阅读参考文档,提取相关信息来回答用户的问题 +- 即使文档没有直接定义,也要基于文档中的相关内容进行解释和说明 +- 如果文档中提到了相关概念、政策、法规等,请基于这些内容进行回答 +- 回答要准确、详细、有条理,尽量引用文档中的具体内容""" + logger.info( + f"agentChat/stream 使用 RAG,知识库 {kb_ids},检索 {len(results)} 条,最高相似度 {max_score:.3f}" + ) + else: + logger.info( + f"agentChat/stream:知识库 {kb_ids} 检索结果相似度较低(最高 {max_score:.3f}),仅用提示词" + ) + + # 3. 调用大模型(流式) + llm_service = LLMService() + + async def generate(): + try: + async for chunk in llm_service.chat_completion_stream( + model_config=llm_config, + messages=[{"role": "user", "content": final_prompt}], + temperature=body.temperature + if body.temperature is not None + else llm_config.temperature, + max_tokens=body.max_tokens + if body.max_tokens is not None + else llm_config.max_tokens, + ): + if not chunk: + continue + # 和 /chat/stream 一样,直接输出文本内容 + yield chunk + except Exception as e: + logger.error(f"agentChat/stream LLM 流式调用失败: {e}") + # 将错误信息也推到流里,方便前端展示 + yield f"[ERROR] 大模型调用失败: {str(e)}" + + return StreamingResponse( + generate(), + media_type="text/plain", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) +