hyf-backend/utils/wssServer.py

136 lines
6.3 KiB
Python
Raw Normal View History

2026-01-21 13:45:39 +08:00
import time, websockets, json, inspect, os
import numpy as np
from loguru import logger
from typing import Dict, Set, Callable, Any, Optional
from traceback import format_exc
from DrGraph.utils.Helper import *
import DrGraph.utils.vclEnums as enums
class WebSocketServer:
"""
WebSocket服务器类
提供WebSocket服务器功能包括客户端连接管理消息处理和广播功能
"""
def __init__(self, host: str = "localhost", port: int = 8765):
    """Create a server for ``host``:``port``; nothing listens until ``start()``.

    Args:
        host: interface address later passed to ``websockets.serve``.
        port: TCP port later passed to ``websockets.serve``.
    """
    # Connection settings consumed by start().
    self.host, self.port = host, port
    # Maps a message "type" to an async handler(websocket, payload).
    self.message_handlers: Dict[str, Callable] = {}
    # Assigned by start(); None until the server has been started.
    self.server = None
    # Optional async callback(websocket, connected: bool) used by handle_client().
    self.onClientConnect = None
    # Built-in route for "file" messages.
    self.register_handler("file", self.handle_file)
def register_handler(self, message_type: str, handler: Callable):
self.message_handlers[message_type] = handler
async def response_text(self, websocket: websockets.WebSocketServerProtocol, message, t = 'default'):
    """Send ``message`` to the client as a text frame prefixed with ``'t'``.

    Args:
        websocket: connection to send on.
        message: a ``dict``/``list``/``tuple`` is JSON-encoded first, with
            non-ASCII kept verbatim. Any other value is formatted with ``str()``;
            a ``str`` is sent unchanged.
        t: label that only appears in the info log (callers pass the message type).

    Raises:
        Whatever ``websocket.send`` raises, e.g. when the connection is closed.
    """
    # Previously only dicts were encoded, so a list reply went out as a Python
    # repr (single quotes), which clients cannot parse as JSON.
    if isinstance(message, (dict, list, tuple)):
        message = json.dumps(message, ensure_ascii=False)
    if Helper.AppFlag_SaveLog:
        logger.warning(f"发送消息: {message}")
    await websocket.send(f't{message}')
    # Identify the caller cheaply. inspect.stack() builds a FrameInfo for every
    # frame (reading source lines) on each send, and the whole FrameInfo repr
    # was logged. Only file, line and function of the direct caller are kept.
    frame = inspect.currentframe()
    caller_frame = frame.f_back if frame is not None else None
    if caller_frame is not None:
        code = caller_frame.f_code
        caller = f"{code.co_filename}:{caller_frame.f_lineno} in {code.co_name}"
    else:
        # currentframe() may return None on interpreters without frame support.
        caller = "<unknown>"
    # Release frame references promptly to avoid reference cycles.
    del frame, caller_frame
    logger.info(f"(type={t})to {websocket.remote_address}: {message} - caller={caller} ")
async def response_binary(self, websocket: websockets.WebSocketServerProtocol, data):
    """Send ``data`` to the client as a binary frame prefixed with ``b'b'``.

    Args:
        websocket: connection to send on.
        data: a bytes-like object (``bytes``, ``bytearray``, contiguous
            ``memoryview``) or a ``numpy.ndarray``. Arrays are serialized with
            ``ndarray.tobytes()``: the raw buffer in C order, with no shape or
            dtype header.
    """
    # handle_message() routes ndarray results here, but ``b'b' + ndarray`` is
    # dispatched to numpy's ``add`` ufunc and raises TypeError instead of
    # concatenating bytes. Serialize arrays explicitly first.
    if isinstance(data, np.ndarray):
        data = data.tobytes()
    packet = bytearray(b'b')
    # In-place extend copies the payload once (the old form copied twice).
    packet += data
    await websocket.send(packet)
async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message: str):
    """Decode one client message and dispatch it by its ``type`` field.

    Expected input is a JSON object such as ``{"type": ..., "data": ...}``.

    - A registered handler is awaited as ``handler(websocket, payload)``. A
      non-None result is sent as binary if it is bytes/bytearray/memoryview/
      ndarray, and as text otherwise.
    - ``type == "echo"`` (when not registered) sends the object back with
      ``type`` set to ``"echo_response"``.
    - Unknown types, invalid JSON, non-object JSON and handler errors are
      answered with a response built by ``Helper.build_response``.

    Raises:
        websockets.exceptions.ConnectionClosed: re-raised so handle_client()
            observes the disconnect instead of it being logged as a server error.
    """
    try:
        data = json.loads(message)
        if not isinstance(data, dict):
            # Valid JSON that is not an object (list, number, string). This used
            # to crash on data.get() and surface as an internal-server error.
            logger.error(f"无效的消息格式, 需要JSON对象 - {message}")
            await self.response_text(websocket, Helper.build_response("JSONDecodeError", enums.Response.EXCEPTION, f"无效的JSON格式 - {message}"))
            return
        message_type = data.get("type")
        payload = data.get("data")
        if Helper.AppFlag_SaveLog:
            logger.warning(f"收到消息: {message}")
        if message_type in self.message_handlers:
            response = await self.message_handlers[message_type](websocket, payload)
            if response is not None:
                if isinstance(response, (bytes, bytearray, memoryview, np.ndarray)):
                    await self.response_binary(websocket, response)
                else:
                    await self.response_text(websocket, response, message_type)
        elif message_type == 'echo':
            print("echo message ", int(time.time() * 1000))
            data["type"] = "echo_response"
            await self.response_text(websocket, data)
        else:
            logger.warning(f"未知类型消息: {message} - {websocket.remote_address}")
            await self.response_text(websocket, Helper.build_response(message_type, enums.Response.ERROR, f"未知消息类型 - {message_type}"))
    except json.JSONDecodeError:
        logger.error(f"无效的JSON格式 - {message}")
        await self.response_text(websocket, Helper.build_response("JSONDecodeError", enums.Response.EXCEPTION, f"无效的JSON格式 - {message}"))
    except websockets.exceptions.ConnectionClosed:
        # The peer is gone. Replying would only raise again, so let the
        # connection loop in handle_client() deal with it.
        raise
    except BrokenPipeError as e:
        # Connection already broken; just log it. Any follow-up closure is
        # handled by the caller.
        logger.error(f"WebSocket BrokenPipeError: {str(e)} - 客户端: {websocket.remote_address}")
    except Exception:
        tb = format_exc()
        logger.error(f"处理消息时出错: {tb}")
        # NOTE(review): the full server traceback is sent to the client. If
        # clients are untrusted, consider sending only a summary.
        await self.response_text(websocket, Helper.build_response("Exception", enums.Response.EXCEPTION, f"服务器内部错误 - 处理消息时出错: {tb}"))
async def handle_client(self, websocket: websockets.WebSocketServerProtocol, path: str = ""):
    """Serve one client connection until it ends.

    If ``onClientConnect`` is set, it is awaited with ``(websocket, True)``
    before reading. Each incoming message is passed to handle_message(). When
    the loop finishes, ``onClientConnect`` (if still set) is awaited with
    ``(websocket, False)``, even after an error.

    Args:
        websocket: the accepted connection.
        path: request path, used only in log lines. It defaults to "" so the
            handler also works when invoked without a path argument (the
            handler signature differs across ``websockets`` versions; confirm
            for the pinned version).
    """
    peer_tag = f"[path:{path}]"
    logger.info(f"客户端 {websocket.remote_address} {peer_tag} 新建连接")
    if self.onClientConnect:
        await self.onClientConnect(websocket, True)
    try:
        async for raw in websocket:
            await self.handle_message(websocket, raw)
    except (websockets.exceptions.ConnectionClosed, BrokenPipeError) as exc:
        # Both are normal ways for a session to end; log which one occurred.
        outcome = "连接已关闭" if isinstance(exc, websockets.exceptions.ConnectionClosed) else "连接BrokenPipeError"
        logger.info(f"客户端 {websocket.remote_address} {peer_tag} {outcome}")
    finally:
        if self.onClientConnect:
            await self.onClientConnect(websocket, False)
async def start(self):
    """Start listening on ``self.host``:``self.port``.

    Each accepted connection is served by handle_client(). The server object
    returned by ``websockets.serve`` is stored in ``self.server`` for stop().
    """
    endpoint = (self.host, self.port)
    self.server = await websockets.serve(self.handle_client, *endpoint)
    logger.warning(f"WebSocket服务器已启动: {self.host}:{self.port}")
async def stop(self):
    """Close the listening server and wait until it has shut down.

    Does nothing if start() has not been called. ``self.server`` is not cleared
    afterwards.
    """
    server = self.server
    if not server:
        return
    server.close()
    await server.wait_closed()
    logger.info("WebSocket服务器已停止")
async def handle_file(self, websocket: websockets.WebSocketServerProtocol, payload: str):
logger.info(f"接收文件: {payload}, {payload['command']}")
command = payload['command']
if command == 'dir':
path = payload.get('path', '/')
files = []
folders = []
try:
# 获取目录下的所有文件和文件夹
if os.path.exists(path) and os.path.isdir(path):
with os.scandir(path) as entries:
for entry in entries:
if entry.is_file():
files.append(entry.name)
elif entry.is_dir():
folders.append(entry.name)
else:
logger.warning(f"路径不存在或不是目录: {path}")
except Exception as e:
logger.error(f"读取目录时出错: {path}, 错误: {e}")
# 合并文件和文件夹列表
all_items = folders + files
logger.info(f"目录: {path}, 文件和文件夹列表: {all_items}")