136 lines
6.3 KiB
Python
136 lines
6.3 KiB
Python
|
|
import time, websockets, json, inspect, os
|
|||
|
|
import numpy as np
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from typing import Dict, Set, Callable, Any, Optional
|
|||
|
|
from traceback import format_exc
|
|||
|
|
|
|||
|
|
from DrGraph.utils.Helper import *
|
|||
|
|
import DrGraph.utils.vclEnums as enums
|
|||
|
|
|
|||
|
|
class WebSocketServer:
|
|||
|
|
"""
|
|||
|
|
WebSocket服务器类
|
|||
|
|
提供WebSocket服务器功能,包括客户端连接管理、消息处理和广播功能
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
def __init__(self, host: str = "localhost", port: int = 8765):
|
|||
|
|
self.host = host
|
|||
|
|
self.port = port
|
|||
|
|
self.message_handlers: Dict[str, Callable] = {}
|
|||
|
|
self.server = None
|
|||
|
|
self.onClientConnect = None
|
|||
|
|
self.register_handler("file", self.handle_file)
|
|||
|
|
|
|||
|
|
def register_handler(self, message_type: str, handler: Callable):
|
|||
|
|
self.message_handlers[message_type] = handler
|
|||
|
|
|
|||
|
|
async def response_text(self, websocket: websockets.WebSocketServerProtocol, message, t = 'default'):
    """Send *message* to the client as a text frame prefixed with the 't' marker.

    Args:
        websocket: connection to send on.
        message: a dict or list, serialized to JSON with ``ensure_ascii=False``;
            any other object is formatted with str() by the f-string.
        t: label written to the info log only (handle_message passes the
            message type); it is not sent to the client.
    """
    if isinstance(message, (dict, list)):
        # Lists used to fall through to str(), which sends a Python repr
        # (single quotes, True/None) that a JSON client cannot parse.
        message = json.dumps(message, ensure_ascii=False)
    if Helper.AppFlag_SaveLog:
        logger.warning(f"发送消息: {message}")
    await websocket.send(f't{message}')

    # Only the direct caller is logged. inspect.stack() would build a FrameInfo
    # for every frame on the stack, reading source lines via linecache, on
    # every send. A single f_back step is enough.
    frame = inspect.currentframe()
    caller_frame = frame.f_back if frame is not None else None
    if caller_frame is not None:
        code = caller_frame.f_code
        caller = f"{code.co_filename}:{caller_frame.f_lineno} in {code.co_name}"
    else:
        # currentframe() may return None on interpreters without frame support.
        caller = "<unknown>"
    # Release frame references promptly to avoid frame/locals reference cycles.
    del frame, caller_frame
    logger.info(f"(type={t})to {websocket.remote_address}: {message} - caller={caller} ")
|
|||
|
|
|
|||
|
|
async def response_binary(self, websocket: websockets.WebSocketServerProtocol, data):
    """Send *data* to the client as one binary frame prefixed with the b'b' marker.

    Args:
        websocket: connection to send on.
        data: bytes, bytearray, memoryview or numpy array. handle_message()
            routes np.ndarray handler results here, so arrays are serialized
            with ``ndarray.tobytes()`` (C-order copy of the raw buffer).
    """
    if isinstance(data, np.ndarray):
        # Without this, b'b' + ndarray dispatches to numpy's element-wise add
        # instead of byte concatenation, so array responses fail.
        data = data.tobytes()
    # Build the frame in one buffer (marker byte + payload) with a single copy.
    frame = bytearray(b'b')
    frame += data
    await websocket.send(frame)
|
|||
|
|
|
|||
|
|
async def handle_message(self, websocket: websockets.WebSocketServerProtocol, message: str):
    """Parse one incoming frame as JSON and dispatch it by its "type" field.

    Expected frame shape: ``{"type": ..., "data": ...}``. Dispatch order:

    1. a handler registered for the type is awaited as
       ``handler(websocket, payload)``; a non-None result is sent back as a
       binary frame (bytes/bytearray/np.ndarray) or otherwise as a text frame;
    2. type "echo": the parsed object is sent back with type "echo_response";
    3. anything else (including JSON that is not an object): an ERROR
       response naming the unknown type.

    Frames that cannot be parsed get an EXCEPTION "invalid JSON" response.
    Other errors are logged and reported to the client, except a closed
    connection, which is re-raised so handle_client() treats it as a disconnect.

    Args:
        websocket: connection the frame arrived on.
        message: raw frame text (json.loads also accepts bytes).
    """
    try:
        # Parse separately so that JSON/Unicode errors raised later by a
        # handler are not misreported as an invalid incoming frame.
        try:
            data = json.loads(message)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # json.loads() decodes bytes input first, so an undecodable
            # binary frame raises UnicodeDecodeError rather than JSONDecodeError.
            logger.error(f"无效的JSON格式 - {message}")
            await self.response_text(websocket, Helper.build_response("JSONDecodeError", enums.Response.EXCEPTION, f"无效的JSON格式 - {message}"))
            return

        if isinstance(data, dict):
            message_type = data.get("type")
            payload = data.get("data")
        else:
            # Valid JSON that is not an object (list, number, string, ...) has no
            # "type"; route it to the unknown-type reply instead of crashing on .get().
            message_type, payload = None, None

        if Helper.AppFlag_SaveLog:
            logger.warning(f"收到消息: {message}")

        if message_type in self.message_handlers:
            response = await self.message_handlers[message_type](websocket, payload)
            if response is not None:
                if isinstance(response, (bytes, bytearray, np.ndarray)):
                    await self.response_binary(websocket, response)
                else:
                    await self.response_text(websocket, response, message_type)
        elif message_type == 'echo':
            # message_type can only be 'echo' when data is a dict.
            print("echo message ", int(time.time() * 1000))
            data["type"] = "echo_response"
            await self.response_text(websocket, data)
        else:
            logger.warning(f"未知类型消息: {message} - {websocket.remote_address}")
            await self.response_text(websocket, Helper.build_response(message_type, enums.Response.ERROR, f"未知消息类型 - {message_type}"))
    except websockets.exceptions.ConnectionClosed:
        # The peer is gone: replying with an error frame would only fail again.
        # Let handle_client() handle the disconnect.
        raise
    except BrokenPipeError as e:
        # The connection is already broken; just log it.
        logger.error(f"WebSocket BrokenPipeError: {str(e)} - 客户端: {websocket.remote_address}")
    except Exception:
        tb = format_exc()
        logger.error(f"处理消息时出错: {tb}")
        # NOTE(review): the full traceback is sent to the client; consider a
        # summary only if clients are not trusted.
        await self.response_text(websocket, Helper.build_response("Exception", enums.Response.EXCEPTION, f"服务器内部错误 - 处理消息时出错: {tb}"))
|
|||
|
|
|
|||
|
|
async def handle_client(self, websocket: websockets.WebSocketServerProtocol, path: str = ""):
    """Serve a single client connection until it ends.

    If onClientConnect is set, it is awaited with (websocket, True) before any
    frame is read. Each incoming frame is then passed to handle_message(). After
    the read loop finishes, it is awaited with (websocket, False), whether the
    loop ended normally or with an error. If the "connected" callback itself
    raises, the "disconnected" callback is not invoked.

    Args:
        websocket: the client connection.
        path: request path, used only in log lines (defaults to "").
    """
    logger.info(f"客户端 {websocket.remote_address} [path:{path}] 新建连接")
    if self.onClientConnect:
        await self.onClientConnect(websocket, True)

    try:
        async for incoming in websocket:
            await self.handle_message(websocket, incoming)
    except (websockets.exceptions.ConnectionClosed, BrokenPipeError) as err:
        # Both are expected ways for a connection to end; log which one occurred.
        if isinstance(err, websockets.exceptions.ConnectionClosed):
            reason = "连接已关闭"
        else:
            reason = "连接BrokenPipeError"
        logger.info(f"客户端 {websocket.remote_address} [path:{path}] {reason}")
    finally:
        if self.onClientConnect:
            await self.onClientConnect(websocket, False)
|
|||
|
|
|
|||
|
|
async def start(self):
    """Start listening on self.host:self.port.

    Each accepted connection is served by handle_client(). The server object is
    stored in self.server so stop() can close it later.
    """
    address = (self.host, self.port)
    self.server = await websockets.serve(self.handle_client, *address)
    logger.warning(f"WebSocket服务器已启动: {self.host}:{self.port}")
|
|||
|
|
|
|||
|
|
async def stop(self):
    """Close the server, if one was started, and wait for it to shut down.

    The "stopped" log line is emitted even if start() was never called.
    self.server is not cleared afterwards.
    """
    server = self.server
    if server:
        server.close()
        await server.wait_closed()
    logger.info("WebSocket服务器已停止")
|
|||
|
|
|
|||
|
|
async def handle_file(self, websocket: websockets.WebSocketServerProtocol, payload: str):
|
|||
|
|
logger.info(f"接收文件: {payload}, {payload['command']}")
|
|||
|
|
command = payload['command']
|
|||
|
|
if command == 'dir':
|
|||
|
|
path = payload.get('path', '/')
|
|||
|
|
files = []
|
|||
|
|
folders = []
|
|||
|
|
try:
|
|||
|
|
# 获取目录下的所有文件和文件夹
|
|||
|
|
if os.path.exists(path) and os.path.isdir(path):
|
|||
|
|
with os.scandir(path) as entries:
|
|||
|
|
for entry in entries:
|
|||
|
|
if entry.is_file():
|
|||
|
|
files.append(entry.name)
|
|||
|
|
elif entry.is_dir():
|
|||
|
|
folders.append(entry.name)
|
|||
|
|
else:
|
|||
|
|
logger.warning(f"路径不存在或不是目录: {path}")
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"读取目录时出错: {path}, 错误: {e}")
|
|||
|
|
|
|||
|
|
# 合并文件和文件夹列表
|
|||
|
|
all_items = folders + files
|
|||
|
|
logger.info(f"目录: {path}, 文件和文件夹列表: {all_items}")
|