# -*- coding: utf-8 -*-
"""
HTTP to Kafka proxy service for training requests with message transformation

Features:
- Receives HTTP POST requests with training parameters
- Transforms message format from incoming to Kafka-compatible format
- Validates the request format
- Forwards valid requests to Kafka
- Returns appropriate HTTP responses
"""
import json
import logging
import os
import random
import string
import time
import urllib.parse
from multiprocessing import Process
from shutil import copyfile

import yaml
from flask import Flask, request, jsonify
from kafka import KafkaProducer
from minio import Minio
from minio.error import S3Error

app = Flask(__name__)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Kafka configuration
KAFKA_CONFIG = {
    'bootstrap_servers': '172.18.0.126:9094',
    'value_serializer': lambda v: json.dumps(v).encode('utf-8')
}

# Kafka topics
TOPIC_TRAIN = 'training-tasks'

# Maps 3-digit model codes (as strings) to model/scene names.
model_dict = {
    "001": "AnglerSwimmer",
    "002": "channel2",
    "003": "channelEmergency",
    "004": "cityMangement",
    "005": "cityMangement2",
    "006": "cityMangement3",
    "007": "cityRoad",
    "008": "conf",
    "009": "countryRoad",
    "010": "crackMeasurement",
    "011": "crowdCounting",
    "012": "drowning",
    "013": "forest",
    "014": "forest2",
    "015": "forestCrowd",
    "016": "highWay2",
    "017": "illParking",
    "018": "noParking",
    "019": "ocr2",
    "020": "ocr_en",
    "021": "pedestrian",
    "022": "pothole",
    "023": "river",
    "024": "river2",
    "025": "riverT",
    "026": "road",
    "027": "ship",
    "028": "ship2",
    "029": "smogfire",
    "030": "trafficAccident",
    "031": "vehicle"
}


def validate_minio_url(url):
    """Validate that *url* is a well-formed MinIO object URL.

    Expected shape: ``http(s)://host[:port]/<bucket>/<prefix...>``.

    :param url: full MinIO URL string.
    :return: the ``urllib.parse.ParseResult`` for a valid URL.
    :raises ValueError: if the scheme, host or path is malformed.
    """
    try:
        parsed = urllib.parse.urlparse(url)
        # Path must contain at least "/<bucket>/<something>" (3 split parts).
        if not all([parsed.scheme in ('http', 'https'),
                    parsed.netloc,
                    len(parsed.path.split('/')) >= 3]):
            raise ValueError("Invalid MinIO URL format")
        return parsed
    except Exception as e:
        raise ValueError(f"URL validation failed: {str(e)}")


def download_minio_directory(minio_url, local_base_dir):
    """Download a MinIO "directory" (object prefix) to the local disk.

    :param minio_url: full MinIO path
        (e.g. "http://172.18.0.129:8084/th-dsp/027123")
    :param local_base_dir: local base directory
        (e.g. "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets")
    :return: ``(True, local_dir)`` on success, ``(False, error_message)``
        on any failure — this function does not raise.
    """
    bucket_name = None
    try:
        # Validate and parse the URL
        parsed = validate_minio_url(minio_url)
        endpoint = parsed.netloc
        path_parts = parsed.path.split('/')
        bucket_name = path_parts[1]
        prefix = '/'.join(path_parts[2:])

        # Read credentials from environment variables (recommended).
        # SECURITY NOTE(review): the hard-coded fallbacks below are only
        # suitable for testing and should be removed for production.
        access_key = os.getenv("MINIO_ACCESS_KEY", "IKf3A0ZSXsR1m0oalMjV")
        secret_key = os.getenv("MINIO_SECRET_KEY", "yoC6qRo2hlyZu8Pdbt6eh9TVaTV4gD7KRudromrk")

        # Create the local directory, named after the last prefix component.
        dir_name = os.path.basename(prefix) or "minio_download"
        local_dir = os.path.join(local_base_dir, dir_name)
        os.makedirs(local_dir, exist_ok=True)

        # Initialise the MinIO client
        client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=parsed.scheme == 'https'
        )

        print(f"Starting download from {minio_url} to {local_dir}")

        # Recursively download every object under the prefix.
        objects = client.list_objects(bucket_name, prefix=prefix, recursive=True)
        downloaded_files = 0
        for obj in objects:
            # Build the local path, preserving the relative directory layout.
            relative_path = os.path.relpath(obj.object_name, prefix)
            local_path = os.path.join(local_dir, relative_path)

            # Make sure the parent directory exists
            os.makedirs(os.path.dirname(local_path), exist_ok=True)

            # Download the file
            client.fget_object(bucket_name, obj.object_name, local_path)
            downloaded_files += 1
            print(f"Downloaded [{downloaded_files}]: {obj.object_name} -> {local_path}")

        print(f"Download completed! Total files: {downloaded_files}")
        return True, local_dir

    except S3Error as e:
        error_msg = f"MinIO Error: {str(e)}"
        if "NoSuchBucket" in str(e):
            error_msg = f"Bucket not found: {bucket_name}"
        return False, error_msg
    except ValueError as e:
        return False, f"Invalid URL: {str(e)}"
    except Exception as e:
        return False, f"Unexpected error: {str(e)}"


def split_dataset(data_dir, train_ratio=0.8):
    """Randomly split labelled images in *data_dir* into train/val lists.

    Only .jpg files that have a matching .txt label file are considered.

    :param data_dir: directory containing jpg/txt pairs.
    :param train_ratio: fraction of files assigned to the training set.
    :return: ``(train_files, val_files)`` lists of jpg file names.
    """
    jpg_files = [f for f in os.listdir(data_dir) if f.endswith('.jpg')]

    # Keep only images whose corresponding label file exists.
    valid_files = []
    for jpg in jpg_files:
        txt = jpg.replace('.jpg', '.txt')
        if os.path.exists(os.path.join(data_dir, txt)):
            valid_files.append(jpg)

    # Shuffle, then split at the ratio boundary.
    random.shuffle(valid_files)
    split_idx = int(len(valid_files) * train_ratio)
    train_files = valid_files[:split_idx]
    val_files = valid_files[split_idx:]
    return train_files, val_files


def write_file_list(file_list, output_file, base_dir):
    """Write one absolute path per line (base_dir joined with each name)."""
    with open(output_file, 'w') as f:
        for file in file_list:
            f.write(os.path.join(base_dir, file) + '\n')


def copy_files_to_folders(files, src_dir, dest_dir):
    """Copy images and their label files into the destination folder."""
    os.makedirs(dest_dir, exist_ok=True)
    for file in files:
        # Copy the image
        copyfile(os.path.join(src_dir, file), os.path.join(dest_dir, file))
        # Copy the matching label file
        label = file.replace('.jpg', '.txt')
        copyfile(os.path.join(src_dir, label), os.path.join(dest_dir, label))


def update_yaml(yaml_file, train_txt, val_txt, nc, names):
    """Rewrite a YOLO dataset YAML in place with new lists/classes.

    :param yaml_file: path to the YAML config to update.
    :param train_txt: path to the train file-list.
    :param val_txt: path to the val file-list.
    :param nc: number of classes.
    :param names: list of class names.
    """
    with open(yaml_file, 'r') as f:
        data = yaml.safe_load(f)
    data['train'] = train_txt
    data['val'] = val_txt
    data['nc'] = nc
    data['names'] = names
    with open(yaml_file, 'w') as f:
        yaml.dump(data, f, default_flow_style=False)


def get_model_name(code):
    """Resolve a model code to its name via ``model_dict``.

    Numeric codes are normalised to 3-digit zero-padded strings, so
    "27", 27 and "027" all map to the same entry.
    """
    code_str = f"{int(code):03d}" if str(code).isdigit() else code
    return model_dict.get(code_str, f"Unknown code: {code}")


def generate_request_id(length=30):
    """Generate a unique request ID of *length* random letters, prefixed 'bb'.

    Uses random.choices (sampling with replacement) so lengths above 52
    work too; random.sample would raise for length > len(ascii_letters).
    """
    return 'bb' + ''.join(random.choices(string.ascii_letters, k=length))


def transform_message_format(input_data):
    """Transform an incoming HTTP payload into the Kafka training message.

    Side effects: downloads the training dataset from MinIO, splits it into
    train/val folders on disk, writes train.txt/val.txt and rewrites the
    model's YAML config.

    :param input_data: dict parsed from the HTTP request body.
    :return: dict in the shape expected by ``validate_training_request``.
    :raises ValueError: if ``TrainImageDir`` is missing.
    :raises RuntimeError: if the dataset download fails.
    """
    code = input_data.get("code", input_data.get("Code", "026"))
    transformed = {
        "Scene": input_data.get("ModelScenetype", "train").capitalize(),
        "Command": "start",  # Default command
        "Request_ID": input_data.get("request_id", generate_request_id()),
        "Version": input_data.get("Version", "v1.0"),
        "Name": "channel2",  # Default name if not provided
        "Model": code,
    }

    # Handle TrainParameter construction
    params = input_data.get("parameters", {})
    batch_size = params.get("batch_size")
    img_size = params.get("img_size")
    epochs = params.get("epochs")

    model_name = get_model_name(code)
    RAW_YAML_PATH = "/home/th/jcq/AI_AutoPlat/AI_web_dsj/config/yaml"
    yaml_path = os.path.join(RAW_YAML_PATH, f"{model_name}.yaml")
    print("####line137", yaml_path)

    minio_url = input_data.get("TrainImageDir")
    if not minio_url:
        # Previously this only printed and fell through, crashing later on
        # undefined variables; fail fast so the caller logs a clear error.
        raise ValueError("TrainImageDir not specified in input_data")

    # Download the dataset to the local base directory.
    local_base_dir = "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets"
    success, result = download_minio_directory(minio_url, local_base_dir)
    if not success:
        # 'result' is an error message here, not a path — do not continue.
        raise RuntimeError(f"Download failed: {result}")
    print(f"Successfully downloaded to: {result}")

    train_txt, val_txt = _prepare_dataset(result)

    # Update the dataset YAML.
    # NOTE(review): class count and names are hard-coded for one scene;
    # confirm they match the selected model.
    nc = 5
    names = ['flag', 'buoy', 'shipname', 'ship', 'uncover']
    update_yaml(yaml_path, train_txt, val_txt, nc, names)
    print(f"YAML文件已更新: {yaml_path}")

    RAW_WEIGHTS_PATH = "/home/th/jcq/AIlib/weights_init"
    model_path = os.path.join(RAW_WEIGHTS_PATH, model_name, "yolov5.pt")
    print("####line98", model_path)

    transformed["TrainParameter"] = str([
        batch_size,
        img_size,
        epochs,
        yaml_path,
        model_path
    ])
    return transformed


def _prepare_dataset(dataset_dir):
    """Split a downloaded dataset into train/val folders and file lists.

    :param dataset_dir: directory containing the downloaded jpg/txt pairs.
    :return: ``(train_txt, val_txt)`` paths of the written file lists.
    """
    # Create train and val folders
    train_dir = os.path.join(dataset_dir, "train")
    val_dir = os.path.join(dataset_dir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Split the dataset and copy files into train/ and val/
    train_files, val_files = split_dataset(dataset_dir)
    copy_files_to_folders(train_files, dataset_dir, train_dir)
    copy_files_to_folders(val_files, dataset_dir, val_dir)

    # Write the train/val file lists
    train_txt = os.path.join(dataset_dir, "train.txt")
    val_txt = os.path.join(dataset_dir, "val.txt")
    write_file_list(train_files, train_txt, train_dir)
    write_file_list(val_files, val_txt, val_dir)

    print(f"数据集划分完成,训练集: {len(train_files)} 张,验证集: {len(val_files)} 张")
    print(f"训练集列表已写入: {train_txt}")
    print(f"验证集列表已写入: {val_txt}")
    return train_txt, val_txt


def validate_training_request(data):
    """Validate the training request structure.

    :param data: transformed message dict.
    :return: ``(True, "")`` if valid, else ``(False, reason)``.
    """
    required_fields = {
        "Scene": str,
        "Command": str,
        "Version": str,
        "Name": str,
        "Model": str,
        "TrainParameter": str
    }

    for field, field_type in required_fields.items():
        if field not in data:
            return False, f"Missing required field: {field}"
        if not isinstance(data[field], field_type):
            return False, f"Field {field} must be {field_type.__name__}"

    if data["Scene"] not in ["Train"]:
        return False, "Invalid Scene value"
    if data["Command"] not in ["start", "stop"]:
        return False, "Invalid Command value"

    return True, ""


def send_to_kafka(topic, message):
    """Send *message* to the given Kafka *topic*.

    Creates a short-lived producer, waits up to 10s for delivery and
    always flushes/closes it.

    :return: ``(True, "")`` on success, ``(False, error_string)`` on failure.
    """
    producer = None
    try:
        producer = KafkaProducer(**KAFKA_CONFIG)
        future = producer.send(topic, value=message)
        # Wait for the message to be delivered
        result = future.get(timeout=10)
        logger.info(f"Message sent to partition {result.partition}, offset {result.offset}")
        return True, ""
    except Exception as e:
        logger.error(f"Failed to send message: {e}")
        return False, str(e)
    finally:
        if producer:
            producer.flush()
            producer.close()


@app.route('/api/train', methods=['POST'])
def handle_training_request():
    """HTTP endpoint for training requests (non-blocking).

    Validates the JSON envelope, then hands the heavy work off to a
    daemon background process and returns 202 immediately.
    """
    try:
        # Parse and validate the basic request structure
        data = request.get_json()
        if not data or not isinstance(data, dict):
            return jsonify({"status": "error", "message": "Invalid JSON data"}), 400

        # Generate request ID
        request_id = generate_request_id()

        # Start the background process (daemon=True: exits with the server)
        p = Process(
            target=process_training_task,
            args=(data, request_id),
            daemon=True
        )
        p.start()

        # Immediate response
        return jsonify({
            "status": "success",
            "message": "Training task started",
            "request_id": request_id,
            "process_id": p.pid
        }), 202  # 202 Accepted

    except Exception as e:
        logger.error(f"Request handling error: {str(e)}")
        return jsonify({
            "status": "error",
            "message": "Internal server error"
        }), 500


def process_training_task(data, request_id):
    """Run the long-lived training pipeline in a separate process."""
    try:
        logger.info(f"[{request_id}] Starting background task (PID: {os.getpid()})")
        start_time = time.time()

        # 1. Transform the message (downloads and prepares the dataset)
        transformed = transform_message_format(data)
        transformed["Request_ID"] = request_id

        # 2. Validate the transformed message
        is_valid, msg = validate_training_request(transformed)
        if not is_valid:
            raise ValueError(msg)

        # 3. Forward to Kafka
        success, error = send_to_kafka(TOPIC_TRAIN, transformed)
        if not success:
            raise RuntimeError(error)

        elapsed = time.time() - start_time
        logger.info(f"[{request_id}] Task completed in {elapsed:.2f}s")
    except Exception as e:
        logger.error(f"[{request_id}] Task failed: {str(e)}")


if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug debugger on a
    # network-reachable interface — disable for production.
    app.run(host='192.168.10.11', port=6688, debug=True)