# hyf-backend/th_agenter/services/llm_service.py

"""LLM service for workflow execution."""
import asyncio
from typing import List, Dict, Any, Optional, AsyncGenerator
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from ..models.llm_config import LLMConfig
from loguru import logger
class LLMService:
    """LLM service used to invoke large language models from workflows.

    Wraps LangChain's ``ChatOpenAI`` for both one-shot and streaming chat
    completions, normalizing provider-specific quirks (trailing
    ``/chat/completions`` in the base URL, SiliconFlow's ``org/model``
    naming) before each call.
    """

    def __init__(self):
        # Stateless service: all per-call state lives inside each method.
        pass

    @staticmethod
    def _normalize_endpoint(model_config: "LLMConfig") -> tuple:
        """Return ``(base_url, model_name)`` adjusted for the provider.

        - Strips a trailing ``/chat/completions`` from ``base_url`` because
          ``ChatOpenAI`` appends that path itself.
        - Rewrites bare DeepSeek model names into SiliconFlow's required
          ``org/model`` format (e.g. ``deepseek-ai/DeepSeek-R1``).
        """
        base_url = model_config.base_url
        if base_url and '/chat/completions' in base_url:
            base_url = base_url.replace('/chat/completions', '').rstrip('/')
            logger.debug(f"调整 base_url: {model_config.base_url} -> {base_url}")

        model_name = model_config.model_name
        if 'siliconflow' in (base_url or '').lower() and '/' not in model_name:
            lowered = model_name.lower()
            if 'deepseek' in lowered or 'r1' in lowered:
                # Try the common DeepSeek model-name spellings.
                if 'r1' in lowered:
                    model_name = 'deepseek-ai/DeepSeek-R1'
                elif 'v3' in lowered:
                    model_name = 'deepseek-ai/DeepSeek-V3'
                else:
                    model_name = f'deepseek-ai/{model_name}'
                logger.debug(f"调整 SiliconFlow 模型名称: {model_config.model_name} -> {model_name}")
        return base_url, model_name

    @staticmethod
    def _convert_messages(messages: List[Dict[str, str]]) -> list:
        """Map ``{"role": ..., "content": ...}`` dicts onto LangChain messages.

        Messages with unrecognized roles are dropped, matching the previous
        inline conversion loops.
        """
        converted = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if role == "system":
                converted.append(SystemMessage(content=content))
            elif role == "user":
                converted.append(HumanMessage(content=content))
            elif role == "assistant":
                converted.append(AIMessage(content=content))
        return converted

    @staticmethod
    def _extract_error_detail(e: Exception) -> str:
        """Build a human-readable detail string from an LLM call failure.

        Enriches ``str(e)`` with the HTTP status code and any ``message``
        fields found in a JSON error body, when the exception carries a
        ``response`` attribute (as httpx/openai errors do).
        """
        error_detail = str(e)
        response = getattr(e, 'response', None)
        if response is not None:
            if hasattr(response, 'status_code'):
                error_detail = f"HTTP {response.status_code}: {error_detail}"
            if hasattr(response, 'text'):
                try:
                    import json
                    error_body = json.loads(response.text)
                    if isinstance(error_body, dict):
                        if 'message' in error_body:
                            error_detail = f"{error_detail} - {error_body['message']}"
                        if 'error' in error_body:
                            error_info = error_body['error']
                            if isinstance(error_info, dict) and 'message' in error_info:
                                error_detail = f"{error_detail} - {error_info['message']}"
                except ValueError:
                    # Body was not valid JSON; keep the detail gathered so far.
                    pass
        return error_detail

    async def chat_completion(
        self,
        model_config: "LLMConfig",
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Run a non-streaming chat completion and return the reply text.

        Args:
            model_config: Provider credentials, endpoint, and default sampling
                parameters.
            messages: Chat history as role/content dicts.
            temperature: Optional override for the configured temperature.
            max_tokens: Optional override for the configured max tokens.

        Raises:
            Exception: Wraps any underlying failure, annotated with the
                effective model/base_url and likely-cause hints.
        """
        base_url, model_name = self._normalize_endpoint(model_config)
        try:
            llm = ChatOpenAI(
                model=model_name,
                api_key=model_config.api_key,
                base_url=base_url,
                # `is not None` (not `or`) so explicit 0 / 0.0 overrides win.
                temperature=temperature if temperature is not None else model_config.temperature,
                max_tokens=max_tokens if max_tokens is not None else model_config.max_tokens,
                streaming=False
            )
            response = await llm.ainvoke(self._convert_messages(messages))
            return response.content
        except Exception as e:
            error_detail = self._extract_error_detail(e)
            # Report the effective (post-normalization) configuration.
            model_info = f"模型: {model_name}, base_url: {base_url}"
            lowered = error_detail.lower()
            if 'Not Found' in error_detail or '404' in error_detail:
                error_detail = f"{error_detail} ({model_info})。可能的原因1) 模型名称格式不正确SiliconFlow需要org/model格式如deepseek-ai/DeepSeek-R12) base_url配置错误3) API端点不存在"
            elif '403' in error_detail or 'account balance' in lowered or 'insufficient' in lowered:
                error_detail = f"{error_detail} ({model_info})。可能的原因账户余额不足或API密钥权限不足"
            elif '401' in error_detail or 'authentication' in lowered:
                error_detail = f"{error_detail} ({model_info})。可能的原因API密钥无效或已过期"
            else:
                error_detail = f"{error_detail} ({model_info})"
            logger.error(f"LLM调用失败: {error_detail}")
            raise Exception(f"LLM调用失败: {error_detail}") from e

    async def chat_completion_stream(
        self,
        model_config: "LLMConfig",
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion, yielding content chunks as they arrive.

        Raises:
            Exception: Wraps any underlying failure.
        """
        # Apply the same provider normalization as chat_completion; the
        # previous implementation skipped it on the streaming path.
        base_url, model_name = self._normalize_endpoint(model_config)
        try:
            llm = ChatOpenAI(
                model=model_name,
                api_key=model_config.api_key,
                base_url=base_url,
                temperature=temperature if temperature is not None else model_config.temperature,
                max_tokens=max_tokens if max_tokens is not None else model_config.max_tokens,
                streaming=True
            )
            async for chunk in llm.astream(self._convert_messages(messages)):
                # Skip empty/metadata-only chunks.
                if hasattr(chunk, 'content') and chunk.content:
                    yield chunk.content
        except Exception as e:
            logger.error(f"LLM流式调用失败: {str(e)}")
            raise Exception(f"LLM流式调用失败: {str(e)}") from e

    def get_model_info(self, model_config: "LLMConfig") -> Dict[str, Any]:
        """Return a plain-dict summary of a model configuration."""
        return {
            "id": model_config.id,
            "name": model_config.model_name,
            "provider": model_config.provider,
            "base_url": model_config.base_url,
            "temperature": model_config.temperature,
            "max_tokens": model_config.max_tokens,
            "is_active": model_config.is_active
        }