import inspect
import json
import os
import time
from traceback import format_exc
from typing import Any, Callable, Dict, Optional, Set

import numpy as np
import websockets
from loguru import logger

from DrGraph.utils.Helper import *
import DrGraph.utils.vclEnums as enums


class WebSocketServer:
    """WebSocket server.

    Manages client connections, dispatches incoming JSON messages to
    registered per-type handlers, and sends text or binary replies.

    Outgoing frames carry a one-character tag so the client can tell them
    apart: text frames start with ``'t'``, binary frames with ``b'b'``.
    """

    def __init__(self, host: str = "localhost", port: int = 8765):
        self.host = host
        self.port = port
        # message "type" -> async handler(websocket, payload); a non-None
        # return value is sent back to the client by handle_message.
        self.message_handlers: Dict[str, Callable] = {}
        self.server = None
        # Optional async callback(websocket, connected: bool), awaited with
        # True when a client connects and False when it disconnects.
        self.onClientConnect: Optional[Callable] = None
        self.register_handler("file", self.handle_file)

    def register_handler(self, message_type: str, handler: Callable):
        """Register ``handler`` for messages whose "type" equals ``message_type``.

        Registering the same type again replaces the previous handler.
        """
        self.message_handlers[message_type] = handler

    async def response_text(self, websocket: websockets.WebSocketServerProtocol, message, t='default'):
        """Send ``message`` to the client as a text frame prefixed with ``'t'``.

        Args:
            websocket: Connection to send on.
            message: A dict (serialized to JSON with non-ASCII characters kept
                as-is) or any value that can be formatted in an f-string.
            t: Message type label; used only in the log line.
        """
        if isinstance(message, dict):
            message = json.dumps(message, ensure_ascii=False)
        if Helper.AppFlag_SaveLog:
            logger.warning(f"发送消息: {message}")
        await websocket.send(f't{message}')
        # Identify the caller cheaply. inspect.stack() would build FrameInfo
        # objects (and read source lines) for the whole stack on every send.
        frame = inspect.currentframe()
        caller_frame = frame.f_back if frame is not None else None
        if caller_frame is not None:
            code = caller_frame.f_code
            caller = f"{code.co_filename}:{caller_frame.f_lineno} in {code.co_name}"
        else:
            caller = "<unknown>"
        # Drop frame references promptly to avoid reference cycles.
        del frame, caller_frame
        logger.info(f"(type={t})to {websocket.remote_address}: {message} - caller={caller} ")

    async def response_binary(self, websocket: websockets.WebSocketServerProtocol, data):
        """Send ``data`` as a binary frame prefixed with ``b'b'``.

        ``data`` may be bytes, bytearray, memoryview or a numpy array. Arrays
        are serialized with ``ndarray.tobytes()`` (raw C-order buffer, no
        shape/dtype header).
        """
        # ``b'b' + ndarray`` is dispatched to numpy's element-wise add rather
        # than byte concatenation, so arrays must be turned into bytes first.
        if isinstance(data, np.ndarray):
            data = data.tobytes()
        await websocket.send(b'b' + bytes(data))

    async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message: str):
        """Parse one JSON message and dispatch it by its "type" field.

        Expected shape: ``{"type": ..., "data": ...}``.

        Dispatch rules:

        - A registered handler is awaited as ``handler(websocket, data)``. A
          bytes/bytearray/ndarray result is sent as binary; any other non-None
          result is sent as text.
        - If no handler is registered, type "echo" sends the same object back
          with its "type" set to "echo_response".
        - Any other type gets an error reply.

        Processing errors are reported to the client instead of being raised.
        ``ConnectionClosed`` is the exception: it propagates to handle_client.
        """
        try:
            data = json.loads(message)
            if Helper.AppFlag_SaveLog:
                logger.warning(f"收到消息: {message}")
            if not isinstance(data, dict):
                # E.g. a JSON array or number. This would otherwise fail on
                # data.get() and surface as a generic server exception.
                logger.warning(f"消息不是JSON对象: {message} - {websocket.remote_address}")
                await self.response_text(websocket, Helper.build_response("InvalidMessage", enums.Response.ERROR, f"消息必须是JSON对象 - {message}"))
                return
            message_type = data.get("type")
            payload = data.get("data")
            if message_type in self.message_handlers:
                response = await self.message_handlers[message_type](websocket, payload)
                if response is not None:
                    if isinstance(response, (bytes, bytearray, np.ndarray)):
                        await self.response_binary(websocket, response)
                    else:
                        await self.response_text(websocket, response, message_type)
            elif message_type == 'echo':
                print("echo message ", int(time.time() * 1000))
                data["type"] = "echo_response"
                await self.response_text(websocket, data)
            else:
                logger.warning(f"未知类型消息: {message} - {websocket.remote_address}")
                await self.response_text(websocket, Helper.build_response(message_type, enums.Response.ERROR, f"未知消息类型 - {message_type}"))
        except json.JSONDecodeError:
            logger.error(f"无效的JSON格式 - {message}")
            await self.response_text(websocket, Helper.build_response("JSONDecodeError", enums.Response.EXCEPTION, f"无效的JSON格式 - {message}"))
        except BrokenPipeError as e:
            # The connection is already broken; let handle_client deal with
            # the ConnectionClosed that follows.
            logger.error(f"WebSocket BrokenPipeError: {str(e)} - 客户端: {websocket.remote_address}")
        except websockets.exceptions.ConnectionClosed:
            # The peer is gone, so replying would only raise again. Propagate
            # so handle_client logs the disconnect and runs its cleanup.
            raise
        except Exception:
            trace = format_exc()
            logger.error(f"处理消息时出错: {trace}")
            # NOTE(review): the full server traceback is sent to the client;
            # consider trimming it if clients are not trusted.
            await self.response_text(websocket, Helper.build_response("Exception", enums.Response.EXCEPTION, f"服务器内部错误 - 处理消息时出错: {trace}"))

    async def handle_client(self, websocket: websockets.WebSocketServerProtocol, path: str = ""):
        """Serve one client connection until it closes.

        Awaits ``onClientConnect(websocket, True)`` on connect and
        ``onClientConnect(websocket, False)`` on disconnect, if set.

        Args:
            websocket: The client connection.
            path: Request path. It defaults to "" — presumably so the handler
                also works with websockets versions that pass only the
                connection.
        """
        logger.info(f"客户端 {websocket.remote_address} [path:{path}] 新建连接")
        if self.onClientConnect:
            await self.onClientConnect(websocket, True)
        try:
            async for message in websocket:
                await self.handle_message(websocket, message)
        except websockets.exceptions.ConnectionClosed:
            logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接已关闭")
        except BrokenPipeError:
            logger.info(f"客户端 {websocket.remote_address} [path:{path}] 连接BrokenPipeError")
        finally:
            if self.onClientConnect:
                await self.onClientConnect(websocket, False)

    async def start(self):
        """Start the server on ``self.host:self.port`` and keep it in ``self.server``."""
        self.server = await websockets.serve(self.handle_client, self.host, self.port)
        logger.warning(f"WebSocket服务器已启动: {self.host}:{self.port}")

    async def stop(self):
        """Close the server, if it was started, and wait until it has shut down."""
        if self.server:
            self.server.close()
            await self.server.wait_closed()
            logger.info("WebSocket服务器已停止")

    async def handle_file(self, websocket: websockets.WebSocketServerProtocol, payload: Dict[str, Any]):
        """Handle a "file" message; ``payload["command"]`` selects the action.

        Commands handled here:

        - ``"dir"``: list the entries of ``payload["path"]`` (default ``"/"``),
          with folders first and then files.

        Returns an error response if the payload is not a dict containing a
        "command" key.
        """
        # Validate up front. A missing or non-dict payload used to raise
        # KeyError/TypeError and surface as a generic server exception.
        if not isinstance(payload, dict) or 'command' not in payload:
            logger.warning(f"无效的文件请求: {payload}")
            return Helper.build_response("file", enums.Response.ERROR, f"无效的文件请求 - {payload}")
        command = payload['command']
        logger.info(f"接收文件: {payload}, {command}")
        if command == 'dir':
            # SECURITY NOTE(review): ``path`` comes straight from the client,
            # so any directory readable by the server process can be listed.
            path = payload.get('path', '/')
            files = []
            folders = []
            try:
                # isdir() is False for missing paths, so no separate exists() check.
                if os.path.isdir(path):
                    with os.scandir(path) as entries:
                        for entry in entries:
                            if entry.is_file():
                                files.append(entry.name)
                            elif entry.is_dir():
                                folders.append(entry.name)
                else:
                    logger.warning(f"路径不存在或不是目录: {path}")
            except Exception as e:
                # Best effort: log and continue with whatever was collected.
                logger.error(f"读取目录时出错: {path}, 错误: {e}")
            # Folders first, then files.
            all_items = folders + files
            logger.info(f"目录: {path}, 文件和文件夹列表: {all_items}")
            # NOTE(review): the visible code only logs the listing and returns
            # nothing, so handle_message sends no reply for "dir". Confirm the
            # intended reply format before adding a return value.