"""Mock AI model training / checking / deployment service.

Exposes a Flask + Flask-RESTful HTTP API that:
  * simulates long-running training tasks and reports progress,
  * pushes periodic status callbacks to the admin backend per process type,
  * forwards start/stop commands to Kafka,
  * downloads model weights / datasets from MinIO,
  * deploys model files to a local or remote path and fires a callback.

All task state is kept in in-memory dicts (this is a mock service).
"""

import json
import logging
import multiprocessing
import os
import random
import shutil
import string
import threading
import time
import urllib.parse
import uuid
from multiprocessing import Process
from shutil import copyfile

import paramiko
import requests
import yaml
from flask import Flask, jsonify, request
from flask_restful import Api, Resource
from kafka import KafkaProducer
from minio import Minio
from minio.error import S3Error
from scp import SCPClient

from api_task import check_task_api

app = Flask(__name__)
api = Api(app)

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration
# (The original file defined model_dict / get_model_name / the MinIO and
# remote-host constants three times; the later definitions silently shadowed
# the earlier ones. Only the effective (last) values are kept here.)
# ---------------------------------------------------------------------------

# Model code -> scene name lookup table.
model_dict = {
    "001": "river",
    "002": "forest2",
    "003": "highWay2",
    "006": "vehicle",
    "007": "pedestrian",
    "008": "smogfire",
    "009": "AnglerSwimmer",
    "010": "countryRoad",
    "011": "ship2",
    "013": "channelEmergency",
    "014": "forest2",
    "015": "river2",
    "016": "cityMangement2",
    "017": "drowning",
    "018": "noParking",
    "019": "illParking",
    "020": "cityRoad",
    "023": "pothole",
    "024": "channel2",
    "025": "riverT",
    "026": "forestCrowd",
    "027": "highWay2T",
    "028": "smartSite",
    "029": "rubbish",
    "030": "firework",
}

# MinIO configuration.
# SECURITY NOTE(review): credentials are hard-coded; they should be supplied
# via environment variables / a secret store instead of living in source.
ENDPOINT = "minio-jndsj.t-aaron.com:2443"
ACCESS_KEY = "PJM0c2qlauoXv5TMEHm2"
SECRET_KEY = "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI"
BUCKET_NAME = "algorithm"
PREFIX = "weights/026/v1.0/"
LOCAL_DIR = "downloaded_weights"  # local directory for downloaded weights

# Remote deployment server configuration.
REMOTE_HOST = "172.16.225.151"
REMOTE_USER = "root"
REMOTE_PASSWORD = "P#mnuJ6r4A"  # SECURITY NOTE(review): hard-coded password
REMOTE_BASE_PATH = "/home/th/jcq/AI_AutoPlat/"

# Kafka configuration.
KAFKA_CONFIG = {
    'bootstrap_servers': '172.18.0.126:9094',
    'value_serializer': lambda v: json.dumps(v).encode('utf-8'),
}
TOPIC_TRAIN = 'training-tasks'  # topic training commands are published to

# Mock in-memory "database".
training_tasks = {}   # request_id -> task status dict
model_versions = {}   # reserved; currently unused
active_threads = {}   # request_id -> TrainThread, for stop requests

# Callback endpoints, one per process type.
CALLBACK_URLS = {
    'training': "http://172.18.0.129:8084/api/admin/trainTask/updateTrainTask",
    'check': "http://172.18.0.129:8084/api/admin/checkTask/callback",
    'management': "http://172.18.0.129:8084/api/admin/trainTask/updateTrainTask",
    'deployment': "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack",
}
UPDATE_INTERVAL = 1  # seconds between status-callback rounds


def get_model_name(code):
    """Return the scene name for *code*, zero-padding numeric codes to 3 digits.

    Returns the string "Unknown code: <code>" when the code is not in
    ``model_dict``.
    """
    code_str = f"{int(code):03d}" if str(code).isdigit() else code
    return model_dict.get(code_str, f"Unknown code: {code}")


def generate_request_id(length=30):
    """Generate a pseudo-unique request id prefixed with 'bb'.

    NOTE: uses ``random.sample`` over ascii letters, so *length* must be
    <= 52 and the id has no repeated characters; fine for a mock service.
    """
    return 'bb' + ''.join(random.sample(string.ascii_letters, length))


class TrainThread(threading.Thread):
    """Simulated training job: advances progress by 10% every 2 seconds."""

    def __init__(self, request_id, process_type, version):
        super().__init__()
        self.request_id = request_id
        self.process_type = process_type
        self.version = version
        self._stop_event = threading.Event()

    def run(self):
        """Simulate a long-running process with progress updates."""
        for step in range(1, 11):
            if self._stop_event.is_set():
                print(f"Training {self.request_id} stopped by user request")
                training_tasks[self.request_id] = {
                    'request_id': self.request_id,
                    'process_type': self.process_type,
                    'progress': f'{step * 10}%',
                    'model_version': self.version,
                    'status': 'stopped',
                }
                return
            time.sleep(2)
            training_tasks[self.request_id] = {
                'request_id': self.request_id,
                'process_type': self.process_type,
                'progress': f'{step * 10}%',
                'model_version': self.version,
                'status': 'running',
            }
        # Mark as completed when done (unless a stop arrived on the last step).
        if not self._stop_event.is_set():
            training_tasks[self.request_id] = {
                'request_id': self.request_id,
                'process_type': self.process_type,
                'progress': '100%',
                'model_version': self.version,
                'status': 'completed',
            }

    def stop(self):
        """Signal the thread to stop at the next progress step."""
        self._stop_event.set()


def _build_status_payload(process_type, task_id, task_data):
    """Build the callback payload for one task of the given process type.

    Payload shapes are dictated by the admin backend; most fields are mock
    placeholder values.
    """
    if process_type == 'training':
        return {
            "id": task_id,
            "dspModelCode": 502,
            "algorithmMap": "123",
            "algorithmAccuracy": "123",
            "algorithmRecall": "123",
            "aiVersion": "V1.1",
            "progress": task_data.get('progress', '0%'),
            "traiStartTime": "2025-05-21 16:10:05",
            "trainEndTime": "2025-05-21 16:10:05",
            "trainDuration": "100",
            "status": 4 if task_data.get('status') == 'completed' else 2,
        }
    if process_type == 'check':
        return {
            "id": task_id,
            "checkStatus": "2" if task_data.get('status') == 'completed' else "running",
            "aiFileImage": "http://172.18.0.129:8084/5116fd79-7f98-4089-a59f-ec4df2cdb0f3_ai.jpg",
            "checkTaskDetails": [
                {
                    "dspModelCode": "123",
                    "dspModelName": "行人识别",
                    "possibility": "30%",
                },
                {
                    "dspModelCode": "123",
                    "dspModelName": "车辆识别",
                    "possibility": "50%",
                },
            ],
            "checkTime": "2025-05-21 16:10:05",
        }
    if process_type == 'management':
        return {
            "id": task_id,
            "modelStatus": "active",
            "lastUpdated": "2025-05-21 16:10:05",
        }
    # process_type == 'deployment' (only the four CALLBACK_URLS keys reach here)
    return {
        "id": task_id,
        "status": "4" if task_data.get('status') == 'completed' else "processing",
        "deployTime": "2025-05-21 16:10:05",
    }


def create_status_updater(process_type):
    """Create a daemon-loop function that POSTs status updates for one type."""

    def status_updater():
        while True:
            try:
                for task_id, task_data in list(training_tasks.items()):
                    # Only update tasks of this process type.
                    if task_data.get('process_type') != process_type:
                        continue
                    # Skip tasks that were explicitly stopped.
                    if task_data.get('status') == 'stopped':
                        continue
                    payload = _build_status_payload(process_type, task_id, task_data)
                    print(f"Sending {process_type} status update:", payload)
                    try:
                        response = requests.post(CALLBACK_URLS[process_type], json=payload)
                        if response.status_code == 200:
                            print(f"{process_type} status update successful")
                        else:
                            print(f"{process_type} status update failed with status code: {response.status_code}")
                    except Exception as e:
                        print(f"Error sending {process_type} status update: {str(e)}")
            except Exception as e:
                print(f"Error in {process_type} status update thread: {str(e)}")
            time.sleep(UPDATE_INTERVAL)

    return status_updater


# Start one daemon status-update thread per process type (at import time,
# matching the original behavior).
for process_type in CALLBACK_URLS:
    thread = threading.Thread(
        target=create_status_updater(process_type),
        name=f"{process_type}_status_updater",
    )
    thread.daemon = True
    thread.start()


class TrainingTask(Resource):
    """Create simulated training tasks and query their state."""

    def post(self):
        """Start a simulated training thread for the given request."""
        data = request.json
        required_fields = ['code', 'version', 'image_dir', 'parameters']
        if not all(field in data for field in required_fields):
            return {'error': 'Missing required fields'}, 400
        version = data['version']
        # BUG FIX: the original indexed data['request_id'] unchecked, turning
        # a missing field into a 500; fall back to a generated id instead.
        request_id = data.get('request_id') or generate_request_id()
        # Create and start the training thread.
        thread = TrainThread(request_id, 'training', version)
        active_threads[request_id] = thread
        thread.start()
        training_tasks[request_id] = {
            'request_id': request_id,
            'process_type': 'training',
            'progress': '0%',
            'model_version': version,
            'status': 'initializing',
        }
        return {
            'request_id': request_id,
            'code': '0',
            'msg': 'Training task created',
            "data": None,
        }, 201

    def get(self, request_id=None):
        """Return one task by id, or a summary list of all tasks."""
        if request_id:
            task = training_tasks.get(request_id)
            if not task:
                return {'error': 'Task not found'}, 404
            return task
        return {
            'tasks': [{
                'request_id': k,
                'process_type': v.get('process_type', 'unknown'),
                'status': v['status'],
                'progress': v['progress'],
                'model_version': v['model_version'],
            } for k, v in training_tasks.items()]
        }


class TrainingTask_Stop0(Resource):
    """Legacy stop endpoint: stops the local simulation thread only."""

    def post(self):
        data = request.json
        request_id = data['request_id']
        if request_id in active_threads:
            # Signal the thread to stop and forget it.
            active_threads[request_id].stop()
            del active_threads[request_id]
            if request_id in training_tasks:
                training_tasks[request_id]['status'] = 'stopped'
        return {
            'request_id': request_id,
            'code': '0',
            'msg': 'Training stopped successfully',
            "data": None,
        }, 200


class TrainingTask_Stop(Resource):
    """Stop endpoint: forwards a stop command to Kafka and stops local state.

    Expected input, e.g.:
        {"Command": "Stop", "ModelScenetype": "train",
         "request_id": "jcq13f4353a9312838c1d647f2593d3c3a6"}
    """

    def post(self):
        data = request.json
        # Translate to the Kafka message format consumed by the trainer.
        training_request = {
            "Scene": data['ModelScenetype'].capitalize(),  # e.g. "Train"
            "Command": data['Command'].lower(),            # e.g. "stop"
            "Request_ID": data['request_id'],
        }
        try:
            producer = KafkaProducer(**KAFKA_CONFIG)
            future = producer.send(TOPIC_TRAIN, value=training_request)
            # Block until the broker acknowledges the message.
            future.get(timeout=10)
            producer.flush()
            # Mirror the stop into local simulation state.
            if data['request_id'] in active_threads:
                active_threads[data['request_id']].stop()
                del active_threads[data['request_id']]
            if data['request_id'] in training_tasks:
                training_tasks[data['request_id']]['status'] = 'stopped'
            return {
                'request_id': data['request_id'],
                'code': '0',
                'msg': 'Stop command sent successfully',
                'data': None,
            }, 200
        except Exception as e:
            return {
                'request_id': data['request_id'],
                'code': '2',
                'msg': f'Failed to send stop command: {str(e)}',
                'data': None,
            }, 500


def _post_callback(callback_url, callback_data):
    """POST *callback_data* as JSON to *callback_url*, logging the outcome.

    Best-effort: all request errors are caught and printed, never raised.
    """
    headers = {'Content-Type': 'application/json'}
    try:
        response = requests.post(
            callback_url,
            data=json.dumps(callback_data),
            headers=headers,
        )
        if response.status_code == 200:
            print(f"回调成功: {response.json()}")
        else:
            print(f"回调失败,状态码: {response.status_code}, 响应: {response.text}")
    except Exception as e:
        print(f"回调请求异常: {str(e)}")


def manager_simulate_process(ModelCode):
    """Report whether a model code 'exists' (mocked) via the AI callback URL."""
    # Codes "000" and "032" are treated as existing; everything else is not.
    if str(ModelCode) in ["000", "032"]:
        callback_data = {"isExist": 1}
    else:
        callback_data = {"isExist": 2}
    # NOTE: the original rebound a *local* ``training_tasks`` here, which had
    # no effect on the global registry; that dead assignment was removed.
    _post_callback(
        "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack",
        callback_data,
    )


def stop_simulate_process(request_id, process_type, version="1.0"):
    """Mark *request_id* as finished (status 4) and notify the backend."""
    callback_data = {"id": request_id, "status": 4}
    training_tasks[request_id] = callback_data
    _post_callback(
        "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack",
        callback_data,
    )


def deploy_simulate_process(request_id, process_type, version="1.0"):
    """Mark a deployment as finished (status 4) and notify the backend."""
    callback_data = {"id": request_id, "status": 4}
    training_tasks[request_id] = callback_data
    _post_callback(
        "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack",
        callback_data,
    )


def check_simulate_process(input_json):
    """Run YOLOv5 inference for a check task and send the result callback."""
    callback_data = check_task_api.infer_yolov5(input_json)
    print("*******line366", callback_data)
    request_id = input_json.get("request_id", "")
    training_tasks[request_id] = callback_data
    _post_callback(
        "http://172.18.0.129:8084/api/admin/checkTask/callback",
        callback_data,
    )


class TrainingTask_Check(Resource):
    """Kick off an asynchronous model-check (inference) task."""

    def post(self):
        input_json = request.json
        print("###line587", input_json)
        request_id = input_json['request_id']
        # Run the inference in a background thread; respond immediately.
        thread = threading.Thread(
            target=check_simulate_process,
            args=(input_json,),
        )
        thread.start()
        return {
            'id': request_id,
            'code': '0',
            'msg': 'Model check started',
            'data': None,
        }, 201


class TrainingTask_Manager(Resource):
    """Answer whether a model code is within the supported range (0-32)."""

    def post(self):
        data = request.json
        ModelCode = data.get("ModelCode")
        # Normalize to a zero-padded 3-digit string (e.g. 5 -> "005").
        if isinstance(ModelCode, int):
            model_code_str = f"{ModelCode:03d}"
        else:
            model_code_str = str(ModelCode).zfill(3)
        # ROBUSTNESS FIX: non-numeric codes previously raised ValueError
        # (-> HTTP 500); treat them as "not existing" instead.
        try:
            exists = 0 <= int(model_code_str) <= 32
        except ValueError:
            exists = False
        return {"isExist": 1 if exists else 0}, 201


def parse_minio_path(model_dir):
    """Split a MinIO URL into (bucket, prefix, filename).

    Raises ValueError when the host is not the expected MinIO endpoint or
    the path has fewer than three segments.
    """
    parsed = urllib.parse.urlparse(model_dir)
    if not parsed.netloc.startswith('minio-jndsj.t-aaron.com:2443'):
        raise ValueError("Invalid MinIO URL")
    path_parts = parsed.path.lstrip('/').split('/')
    if len(path_parts) < 3:
        raise ValueError("Invalid MinIO path format")
    bucket = path_parts[0]
    # The prefix is everything between the bucket and the filename.
    prefix = '/'.join(path_parts[1:-1])
    filename = path_parts[-1]
    return bucket, prefix, filename


def download_minio_files(model_dir, local_dir='downloads'):
    """Download a single object referenced by *model_dir* from MinIO.

    Returns the local file path, or None on any error.
    """
    try:
        bucket, prefix, filename = parse_minio_path(model_dir)
        client = Minio(
            ENDPOINT,
            access_key=ACCESS_KEY,
            secret_key=SECRET_KEY,
            secure=False,
        )
        os.makedirs(local_dir, exist_ok=True)
        # BUG FIX: the object path previously contained a corrupted
        # placeholder instead of the parsed filename.
        object_path = f"{prefix}/{filename}"
        local_path = os.path.join(local_dir, filename)
        client.fget_object(bucket, object_path, local_path)
        print(f"Downloaded: {object_path} -> {local_path}")
        return local_path
    except S3Error as err:
        print(f"MinIO Error occurred: {err}")
        return None
    except Exception as e:
        print(f"Error occurred: {e}")
        return None


def download_minio_deploy(endpoint, access_key, secret_key, bucket_name,
                          prefix='', local_dir='downloaded_weights'):
    """Recursively download every object under *prefix* from a MinIO bucket.

    Returns the local path of the last downloaded file, or None when the
    prefix matched no objects or an error occurred.
    """
    client = Minio(
        endpoint,
        access_key=access_key,
        secret_key=secret_key,
        secure=True,  # HTTPS
    )
    # BUG FIX: previously unbound when the prefix matched no objects,
    # causing a NameError at the return statement.
    local_path = None
    try:
        os.makedirs(local_dir, exist_ok=True)
        objects = client.list_objects(bucket_name, prefix=prefix, recursive=True)
        for obj in objects:
            # Preserve the relative directory structure locally.
            local_path = os.path.join(local_dir, os.path.relpath(obj.object_name, prefix))
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            print(f"正在下载: {obj.object_name} -> {local_path}")
            client.fget_object(bucket_name, obj.object_name, local_path)
        print("所有文件下载完成!")
        return local_path
    except S3Error as e:
        print(f"MinIO 错误: {e}")
    except Exception as e:
        print(f"发生错误: {e}")


def deploy_to_remote(local_path, model_code):
    """Copy *local_path* to the remote server via SCP; return the remote path."""
    try:
        model_name = get_model_name(model_code)
        remote_path = os.path.join(REMOTE_BASE_PATH, model_name, os.path.basename(local_path))
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(REMOTE_HOST, username=REMOTE_USER, password=REMOTE_PASSWORD)
        scp = SCPClient(ssh.get_transport())
        scp.put(local_path, remote_path)
        print(f"Successfully deployed {local_path} to {remote_path}")
        scp.close()
        ssh.close()
        return remote_path
    except Exception as e:
        print(f"Deployment failed: {e}")
        return None


def deploy_to_local(local_path, model_code):
    """Move *local_path* into the local deployment directory for its model.

    Returns the destination path, or None on failure.
    """
    try:
        model_name = get_model_name(model_code)
        deploy_base = "/home/th/jcq/AIlib/Deployment"
        remote_path = os.path.join(deploy_base, model_name, os.path.basename(local_path))
        print("###line858", deploy_base, model_name, local_path)
        # Ensure the destination directory exists, then move the file.
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)
        shutil.move(local_path, remote_path)
        print(f"Successfully deployed {local_path} to {remote_path}")
        return remote_path
    except Exception as e:
        print(f"Deployment failed: {e}")
        return None


def deployment_worker(model_dir, model_code, request_id):
    """Background-process entry point: download weights, deploy, then callback.

    Always POSTs a final status callback (3 = success, 4 = failure).
    Returns True on successful deployment.
    """
    callback_url = "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
    success = False
    message = ""
    remote_path = ""
    model_path = None  # set once the download succeeds
    try:
        bucket, prefix, filename = parse_minio_path(model_dir)
        print(f"[Deployment Worker] Downloading from bucket: {bucket}, "
              f"prefix: {prefix}, filename: {filename}")
        model_path = download_minio_deploy(
            endpoint=ENDPOINT,
            access_key=ACCESS_KEY,
            secret_key=SECRET_KEY,
            bucket_name=bucket,
            prefix=prefix,
            local_dir=LOCAL_DIR,
        )
        print("###line890", model_path)
        remote_path = deploy_to_local(model_path, model_code)
        if not remote_path:
            message = "Failed to deploy model to remote server"
            print(f"[Deployment Worker] {message}")
            raise Exception(message)
        success = True
        message = "操作成功"
        print(f"[Deployment Worker] Deployment successful: {remote_path}")
    except Exception as e:
        message = f"Deployment failed: {str(e)}"
        print(f"[Deployment Worker] {message}")
        success = False
    finally:
        # Clean up the downloaded file if it is still around (deploy_to_local
        # moves it on success, so usually nothing remains).
        # BUG FIX: the original tested a never-assigned 'local_path' variable,
        # so cleanup never ran and model_path could be unbound.
        if model_path and os.path.exists(model_path):
            try:
                os.remove(model_path)
            except OSError as e:
                print(f"[Deployment Worker] Warning: Failed to clean up local file: {e}")
        # Always report the final status back to the admin backend.
        try:
            if success:
                callback_data = {
                    "id": request_id,
                    "status": 3,
                    "msg": "操作成功",
                    "data": None,
                }
            else:
                callback_data = {
                    "id": request_id,
                    "status": 4,
                    "msg": "未查询到数据",
                    "data": None,
                }
            response = requests.post(callback_url, json=callback_data, timeout=10)
            if response.status_code != 200:
                print(f"[Deployment Worker] Callback failed with status {response.status_code}: {response.text}")
                print("line956", callback_data)
            else:
                print("[Deployment Worker] Callback sent successfully")
                print("line959", callback_data)
        except Exception as e:
            print(f"[Deployment Worker] Failed to send callback: {str(e)}")
    return success


@app.route('/api/train/deploy', methods=['POST'])
def deploy_model():
    """Start an asynchronous deployment of a model weight file.

    Expects JSON with 'ModelDir' and 'ModelCode'; responds 202 immediately
    while the deployment runs in a separate process.
    """
    data = request.get_json()
    print("###line879", data)
    if not data or 'ModelDir' not in data or 'ModelCode' not in data:
        return jsonify({
            "status": "error",
            "message": "Missing required fields (ModelDir or ModelCode)",
        }), 400
    # The weight file name is fixed to yolov5.pt under the given directory.
    model_dir = os.path.join(data['ModelDir'], 'yolov5.pt')
    model_code = data['ModelCode']
    # Generate a unique request ID if not provided.
    request_id = data.get('request_id', str(uuid.uuid4()))
    try:
        process = multiprocessing.Process(
            target=deployment_worker,
            args=(model_dir, model_code, request_id),
        )
        process.start()
        # Respond immediately; deployment continues in the background.
        return jsonify({
            "request_id": request_id,
            "status": 4,
        }), 202
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": f"Failed to start deployment: {str(e)}",
            "request_id": request_id,
        }), 500


def validate_minio_url(url):
    """Validate a MinIO URL and return the parsed result.

    Raises ValueError when the scheme/host/path shape is wrong.
    """
    try:
        parsed = urllib.parse.urlparse(url)
        if not all([
            parsed.scheme in ('http', 'https'),
            parsed.netloc,
            len(parsed.path.split('/')) >= 3,
        ]):
            raise ValueError("Invalid MinIO URL format")
        return parsed
    except Exception as e:
        raise ValueError(f"URL validation failed: {str(e)}")


def download_minio_directory(minio_url, local_base_dir):
    """Download a whole MinIO 'directory' to a local folder.

    :param minio_url: full MinIO path, e.g. "http://172.18.0.129:8084/th-dsp/027123"
    :param local_base_dir: local base directory to download into
    :return: (success, result) — result is the local directory on success,
             or an error message on failure.
    """
    bucket_name = None
    try:
        parsed = validate_minio_url(minio_url)
        endpoint = parsed.netloc
        bucket_name = parsed.path.split('/')[1]
        prefix = '/'.join(parsed.path.split('/')[2:])
        # Credentials come from the environment when available; the inline
        # defaults are for testing only.
        access_key = os.getenv("MINIO_ACCESS_KEY", "PJM0c2qlauoXv5TMEHm2")
        secret_key = os.getenv("MINIO_SECRET_KEY", "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI")
        # Local directory named after the last prefix component.
        dir_name = os.path.basename(prefix) or "minio_download"
        local_dir = os.path.join(local_base_dir, dir_name)
        os.makedirs(local_dir, exist_ok=True)
        client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=parsed.scheme == 'https',
        )
        print(f"Starting download from {minio_url} to {local_dir}")
        objects = client.list_objects(bucket_name, prefix=prefix, recursive=True)
        downloaded_files = 0
        for obj in objects:
            # Keep the relative directory structure.
            relative_path = os.path.relpath(obj.object_name, prefix)
            local_path = os.path.join(local_dir, relative_path)
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            client.fget_object(bucket_name, obj.object_name, local_path)
            downloaded_files += 1
            print(f"Downloaded [{downloaded_files}]: {obj.object_name} -> {local_path}")
        print(f"Download completed! Total files: {downloaded_files}")
        return True, local_dir
    except S3Error as e:
        error_msg = f"MinIO Error: {str(e)}"
        if "NoSuchBucket" in str(e):
            error_msg = f"Bucket not found: {bucket_name}"
        return False, error_msg
    except ValueError as e:
        return False, f"Invalid URL: {str(e)}"
    except Exception as e:
        return False, f"Unexpected error: {str(e)}"


def split_dataset(data_dir, train_ratio=0.8):
    """Split a dataset into training and validation sets.

    Only images (.jpg/.jpeg/.png) that have a matching .txt label file are
    included. The result is shuffled, so the split is random.

    Args:
        data_dir (str): directory containing images and labels
        train_ratio (float): fraction assigned to training (default 0.8)

    Returns:
        tuple: (train_files, val_files) lists of image filenames
    """
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    valid_files = []
    for f in os.listdir(data_dir):
        if f.lower().endswith(IMAGE_EXTENSIONS):
            base_name = os.path.splitext(f)[0]
            txt_file = f"{base_name}.txt"
            # Only keep images whose label file exists.
            if os.path.exists(os.path.join(data_dir, txt_file)):
                valid_files.append(f)
    random.shuffle(valid_files)
    split_idx = int(len(valid_files) * train_ratio)
    return valid_files[:split_idx], valid_files[split_idx:]


def write_file_list(file_list, output_file, base_dir):
    """Write one absolute image path per line to *output_file*."""
    with open(output_file, 'w') as f:
        for file in file_list:
            f.write(os.path.join(base_dir, file) + '\n')


def copy_files_to_folders(files, src_dir, dest_dir):
    """Copy each image and its .txt label file from src_dir to dest_dir."""
    os.makedirs(dest_dir, exist_ok=True)
    for file in files:
        # Copy the image.
        copyfile(os.path.join(src_dir, file), os.path.join(dest_dir, file))
        # BUG FIX: the original chained .replace('.jpg'/'.png') and missed
        # '.jpeg' images (which split_dataset accepts), crashing on the
        # label copy; swap the extension properly instead.
        label = os.path.splitext(file)[0] + '.txt'
        copyfile(os.path.join(src_dir, label), os.path.join(dest_dir, label))


def update_yaml(yaml_file, train_txt, val_txt):
    """Point the dataset YAML's 'train'/'val' entries at the given txt lists."""
    with open(yaml_file, 'r') as f:
        data = yaml.safe_load(f)
    data['train'] = train_txt
    data['val'] = val_txt
    with open(yaml_file, 'w') as f:
        yaml.dump(data, f, default_flow_style=False)


def _prepare_dataset(train_image_dir, yaml_path):
    """Split, copy, and list the dataset under *train_image_dir*; update YAML."""
    train_dir = os.path.join(train_image_dir, "train")
    val_dir = os.path.join(train_image_dir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    train_files, val_files = split_dataset(train_image_dir)
    copy_files_to_folders(train_files, train_image_dir, train_dir)
    copy_files_to_folders(val_files, train_image_dir, val_dir)
    train_txt = os.path.join(train_image_dir, "train.txt")
    val_txt = os.path.join(train_image_dir, "val.txt")
    write_file_list(train_files, train_txt, train_dir)
    write_file_list(val_files, val_txt, val_dir)
    update_yaml(yaml_path, train_txt, val_txt)
    print(f"数据集划分完成,训练集: {len(train_files)} 张,验证集: {len(val_files)} 张")
    print(f"训练集列表已写入: {train_txt}")
    print(f"验证集列表已写入: {val_txt}")
    print(f"YAML文件已更新: {yaml_path}")


def transform_message_format(input_data):
    """Convert an API training request into the Kafka trainer message format.

    Downloads the training dataset from MinIO, prepares train/val splits,
    updates the model's dataset YAML, and builds the TrainParameter string.
    """
    transformed = {
        "Scene": input_data.get("ModelScenetype", "train").capitalize(),
        "Command": "start",  # default command
        "Request_ID": input_data.get("request_id", generate_request_id()),
        "Version": input_data.get("version"),
        "Name": "channel2",  # default name if not provided
        "Model": input_data.get("code", input_data.get("Code", "026")),
    }
    params = input_data.get("parameters", {})
    batch_size = params.get("batch_size")
    img_size = params.get("img_size")
    epochs = params.get("epochs")
    model_name = get_model_name(input_data.get("code", input_data.get("Code", "026")))
    RAW_YAML_PATH = "/home/th/jcq/AI_AutoPlat/AI_web_dsj/config/yaml"
    yaml_path = os.path.join(RAW_YAML_PATH, f"{model_name}.yaml")
    print("####line137", yaml_path)
    minio_url = input_data.get("TrainImageDir")
    if not minio_url:
        print("Error: TrainImageDir not specified in input_data")
    else:
        local_base_dir = "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets"
        success, result = download_minio_directory(minio_url, local_base_dir)
        if success:
            print(f"Successfully downloaded to: {result}")
        else:
            print(f"Download failed: {result}")
            # BUG FIX: the original fell through and used the error *message*
            # as a directory path; fail fast with a clear error instead.
            raise RuntimeError(f"Dataset download failed: {result}")
        _prepare_dataset(result, yaml_path)
    RAW_WEIGHTS_PATH = "/home/th/jcq/AIlib/weights_init"
    model_path = os.path.join(RAW_WEIGHTS_PATH, model_name, "yolov5.pt")
    print("####line98", model_path)
    transformed["TrainParameter"] = str([
        batch_size,
        img_size,
        epochs,
        yaml_path,
        model_path,
    ])
    return transformed


def validate_training_request(data):
    """Validate the trainer message structure; return (is_valid, message)."""
    required_fields = {
        "Scene": str,
        "Command": str,
        "Version": str,
        "Name": str,
        "Model": str,
        "TrainParameter": str,
    }
    for field, field_type in required_fields.items():
        if field not in data:
            return False, f"Missing required field: {field}"
        if not isinstance(data[field], field_type):
            return False, f"Field {field} must be {field_type.__name__}"
    if data["Scene"] not in ["Train"]:
        return False, "Invalid Scene value"
    if data["Command"] not in ["start", "stop"]:
        return False, "Invalid Command value"
    return True, ""


def send_to_kafka(topic, message):
    """Send *message* to the Kafka *topic*; return (success, error_message)."""
    producer = None
    try:
        producer = KafkaProducer(**KAFKA_CONFIG)
        future = producer.send(topic, value=message)
        # Block until delivery is confirmed (or the timeout fires).
        result = future.get(timeout=10)
        logger.info(f"Message sent to partition {result.partition}, offset {result.offset}")
        return True, ""
    except Exception as e:
        logger.error(f"Failed to send message: {e}")
        return False, str(e)
    finally:
        if producer:
            producer.flush()
            producer.close()


@app.route('/api/train', methods=['POST'])
def handle_training_request():
    """HTTP endpoint for training requests (non-blocking).

    Validates the JSON body, then hands the slow work (dataset download,
    message transformation, Kafka publish) to a daemon child process and
    returns 202 immediately.
    """
    try:
        data = request.get_json()
        if not data or not isinstance(data, dict):
            return jsonify({"status": "error", "message": "Invalid JSON data"}), 400
        request_id = data.get("request_id")
        # daemon=True ensures the worker dies with the main process.
        p = Process(
            target=process_training_task,
            args=(data, request_id),
            daemon=True,
        )
        p.start()
        return jsonify({
            "status": "success",
            "message": "Training task started",
            "request_id": request_id,
            "process_id": p.pid,
        }), 202
    except Exception as e:
        logger.error(f"Request handling error: {str(e)}")
        return jsonify({
            "status": "error",
            "message": "Internal server error",
        }), 500


def process_training_task(data, request_id):
    """Child-process body: transform, validate, and publish a training task."""
    try:
        logger.info(f"[{request_id}] Starting background task (PID: {os.getpid()})")
        start_time = time.time()
        # 1. Transform to the trainer message format (includes dataset prep).
        transformed = transform_message_format(data)
        transformed["Request_ID"] = request_id
        # 2. Validate the resulting message.
        is_valid, msg = validate_training_request(transformed)
        if not is_valid:
            raise ValueError(msg)
        # 3. Publish to Kafka.
        success, error = send_to_kafka(TOPIC_TRAIN, transformed)
        if not success:
            raise RuntimeError(error)
        elapsed = time.time() - start_time
        logger.info(f"[{request_id}] Task completed in {elapsed:.2f}s")
    except Exception as e:
        logger.error(f"[{request_id}] Task failed: {str(e)}")


api.add_resource(TrainingTask_Stop, '/api/train/control')
api.add_resource(TrainingTask_Check, '/api/train/check')
api.add_resource(TrainingTask_Manager, '/api/train/manager')

if __name__ == '__main__':
    app.run(host='172.16.225.150', port=5001, debug=True)