"""LLM service for workflow execution."""
import asyncio

from typing import List, Dict, Any, Optional, AsyncGenerator

from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

from ..models.llm_config import LLMConfig
from loguru import logger

class LLMService:
    """LLM service used for model invocation inside workflow execution.

    A fresh LangChain ``ChatOpenAI`` client is built per call from the given
    ``LLMConfig``, so each invocation may target a different model/endpoint.
    """

    @staticmethod
    def _to_langchain_messages(messages: List[Dict[str, str]]) -> List[Any]:
        """Convert OpenAI-style ``{"role", "content"}`` dicts to LangChain messages.

        A missing "role" defaults to "user" and a missing "content" to "".
        Entries with an unrecognized role are skipped, matching the original
        if/elif chain which appended nothing for unknown roles.
        """
        role_to_cls = {
            "system": SystemMessage,
            "user": HumanMessage,
            "assistant": AIMessage,
        }
        converted: List[Any] = []
        for msg in messages:
            cls = role_to_cls.get(msg.get("role", "user"))
            if cls is not None:
                converted.append(cls(content=msg.get("content", "")))
        return converted

    @staticmethod
    def _build_llm(
        model_config: "LLMConfig",
        temperature: Optional[float],
        max_tokens: Optional[int],
        streaming: bool,
    ) -> "ChatOpenAI":
        """Build a ``ChatOpenAI`` client; explicit args override config defaults.

        Uses ``is not None`` rather than ``or`` so that explicit overrides of
        ``0`` / ``0.0`` are honored instead of being silently replaced by the
        config defaults (the original ``x or config.x`` pattern dropped them).
        """
        return ChatOpenAI(
            model=model_config.model_name,
            api_key=model_config.api_key,
            base_url=model_config.base_url,
            temperature=temperature if temperature is not None else model_config.temperature,
            max_tokens=max_tokens if max_tokens is not None else model_config.max_tokens,
            streaming=streaming,
        )

    async def chat_completion(
        self,
        model_config: "LLMConfig",
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> str:
        """Run a single (non-streaming) chat completion.

        Args:
            model_config: Model connection settings (name, API key, base URL,
                default temperature / max_tokens).
            messages: OpenAI-style message dicts with "role" and "content".
            temperature: Optional override of ``model_config.temperature``.
            max_tokens: Optional override of ``model_config.max_tokens``.

        Returns:
            The model response content.

        Raises:
            Exception: Wraps any underlying failure; the original error is
                chained via ``from`` for debuggability.
        """
        try:
            llm = self._build_llm(model_config, temperature, max_tokens, streaming=False)
            response = await llm.ainvoke(self._to_langchain_messages(messages))
            return response.content
        except Exception as e:
            logger.error(f"LLM调用失败: {str(e)}")
            # Chain the cause instead of discarding the original traceback.
            raise Exception(f"LLM调用失败: {str(e)}") from e

    async def chat_completion_stream(
        self,
        model_config: "LLMConfig",
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Run a streaming chat completion, yielding content chunks as they arrive.

        Args:
            model_config: Model connection settings (name, API key, base URL,
                default temperature / max_tokens).
            messages: OpenAI-style message dicts with "role" and "content".
            temperature: Optional override of ``model_config.temperature``.
            max_tokens: Optional override of ``model_config.max_tokens``.

        Yields:
            Non-empty content fragments from the model stream.

        Raises:
            Exception: Wraps any underlying failure; the original error is
                chained via ``from`` for debuggability.
        """
        try:
            llm = self._build_llm(model_config, temperature, max_tokens, streaming=True)
            async for chunk in llm.astream(self._to_langchain_messages(messages)):
                # Some stream chunks carry no content (metadata-only); skip them.
                if hasattr(chunk, 'content') and chunk.content:
                    yield chunk.content
        except Exception as e:
            logger.error(f"LLM流式调用失败: {str(e)}")
            raise Exception(f"LLM流式调用失败: {str(e)}") from e

    def get_model_info(self, model_config: "LLMConfig") -> Dict[str, Any]:
        """Summarize a model configuration as a plain, serializable dict."""
        return {
            "id": model_config.id,
            "name": model_config.model_name,
            "provider": model_config.provider,
            "base_url": model_config.base_url,
            "temperature": model_config.temperature,
            "max_tokens": model_config.max_tokens,
            "is_active": model_config.is_active,
        }