463 lines
13 KiB
Python
463 lines
13 KiB
Python
|
|
# -*- coding: utf-8 -*-
|
|||
|
|
"""
|
|||
|
|
HTTP to Kafka proxy service for training requests with message transformation
|
|||
|
|
|
|||
|
|
Features:
|
|||
|
|
- Receives HTTP POST requests with training parameters
|
|||
|
|
- Transforms message format from incoming to Kafka-compatible format
|
|||
|
|
- Validates the request format
|
|||
|
|
- Forwards valid requests to Kafka
|
|||
|
|
- Returns appropriate HTTP responses
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import json
import logging
import os
import random
import string
import time
import urllib.parse
from multiprocessing import Process
from shutil import copyfile

import yaml
from flask import Flask, request, jsonify
from kafka import KafkaProducer
from minio import Minio
from minio.error import S3Error
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
app = Flask(__name__)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Kafka configuration.  The bootstrap address can be overridden with the
# KAFKA_BOOTSTRAP_SERVERS environment variable; the default preserves the
# previously hard-coded value so existing deployments are unaffected.
KAFKA_CONFIG = {
    'bootstrap_servers': os.getenv('KAFKA_BOOTSTRAP_SERVERS', '172.18.0.126:9094'),
    'value_serializer': lambda v: json.dumps(v).encode('utf-8')
}

# Kafka topic that validated training requests are published to.
TOPIC_TRAIN = 'training-tasks'

# Mapping from 3-digit model codes (as sent by clients) to model names.
model_dict = {
    "001": "AnglerSwimmer",
    "002": "channel2",
    "003": "channelEmergency",
    "004": "cityMangement",
    "005": "cityMangement2",
    "006": "cityMangement3",
    "007": "cityRoad",
    "008": "conf",
    "009": "countryRoad",
    "010": "crackMeasurement",
    "011": "crowdCounting",
    "012": "drowning",
    "013": "forest",
    "014": "forest2",
    "015": "forestCrowd",
    "016": "highWay2",
    "017": "illParking",
    "018": "noParking",
    "019": "ocr2",
    "020": "ocr_en",
    "021": "pedestrian",
    "022": "pothole",
    "023": "river",
    "024": "river2",
    "025": "riverT",
    "026": "road",
    "027": "ship",
    "028": "ship2",
    "029": "smogfire",
    "030": "trafficAccident",
    "031": "vehicle"
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def validate_minio_url(url):
    """Check that *url* looks like a usable MinIO object URL.

    A valid URL has an http/https scheme, a host, and a path containing at
    least a bucket and an object prefix.  Returns the parsed URL on success.

    :raises ValueError: when the URL is malformed.
    """
    try:
        parts = urllib.parse.urlparse(url)
        scheme_ok = parts.scheme in ('http', 'https')
        host_ok = bool(parts.netloc)
        # Path must contain at least "/<bucket>/<prefix>".
        path_ok = len(parts.path.split('/')) >= 3
        if not (scheme_ok and host_ok and path_ok):
            raise ValueError("Invalid MinIO URL format")
    except Exception as e:
        raise ValueError(f"URL validation failed: {str(e)}")
    return parts
|
|||
|
|
|
|||
|
|
def download_minio_directory(minio_url, local_base_dir):
    """
    Download a MinIO "directory" (object prefix) to the local filesystem.

    :param minio_url: full MinIO path (e.g. "http://172.18.0.129:8084/th-dsp/027123")
    :param local_base_dir: local base directory (e.g. "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets")
    :return: (success, message) tuple — on success the message is the local
             download directory; on failure it is an error description.
    """
    try:
        # Validate and parse the URL: host -> endpoint, first path segment ->
        # bucket name, remainder -> object prefix.
        parsed = validate_minio_url(minio_url)
        endpoint = parsed.netloc
        bucket_name = parsed.path.split('/')[1]
        prefix = '/'.join(parsed.path.split('/')[2:])

        # Credentials from environment variables (recommended).
        # NOTE(security): the hard-coded fallbacks below look like real keys;
        # they should be removed outside of test environments.
        access_key = os.getenv("MINIO_ACCESS_KEY", "IKf3A0ZSXsR1m0oalMjV")  # default for testing only
        secret_key = os.getenv("MINIO_SECRET_KEY", "yoC6qRo2hlyZu8Pdbt6eh9TVaTV4gD7KRudromrk")

        # Local target directory: last component of the prefix, with a
        # fallback name when the prefix is empty or ends with '/'.
        dir_name = os.path.basename(prefix) or "minio_download"
        local_dir = os.path.join(local_base_dir, dir_name)
        os.makedirs(local_dir, exist_ok=True)

        # Initialise the MinIO client; TLS is enabled only for https URLs.
        client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=parsed.scheme == 'https'
        )

        print(f"Starting download from {minio_url} to {local_dir}")

        # Recursively list and download every object under the prefix.
        objects = client.list_objects(bucket_name, prefix=prefix, recursive=True)
        downloaded_files = 0

        for obj in objects:
            # Build the local path (preserve the relative directory layout).
            relative_path = os.path.relpath(obj.object_name, prefix)
            local_path = os.path.join(local_dir, relative_path)

            # Make sure the parent directory exists.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)

            # Download the file.
            client.fget_object(bucket_name, obj.object_name, local_path)
            downloaded_files += 1

            print(f"Downloaded [{downloaded_files}]: {obj.object_name} -> {local_path}")

        print(f"Download completed! Total files: {downloaded_files}")
        return True, local_dir

    except S3Error as e:
        # bucket_name is always bound here: S3Error can only be raised by
        # the client calls, which happen after the assignment above.
        error_msg = f"MinIO Error: {str(e)}"
        if "NoSuchBucket" in str(e):
            error_msg = f"Bucket not found: {bucket_name}"
        return False, error_msg
    except ValueError as e:
        return False, f"Invalid URL: {str(e)}"
    except Exception as e:
        return False, f"Unexpected error: {str(e)}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def split_dataset(data_dir, train_ratio=0.8):
    """Randomly split the labelled images in *data_dir* into train/val lists.

    Only .jpg files with a matching .txt label file are considered.

    :param data_dir: directory containing image/label pairs
    :param train_ratio: fraction of files assigned to the training set
    :return: (train_files, val_files) — lists of bare file names
    """
    # Keep only images whose companion label file exists.
    labelled = [
        name for name in os.listdir(data_dir)
        if name.endswith('.jpg')
        and os.path.exists(os.path.join(data_dir, name.replace('.jpg', '.txt')))
    ]

    # Shuffle in place, then cut at the requested ratio.
    random.shuffle(labelled)
    cut = int(len(labelled) * train_ratio)
    return labelled[:cut], labelled[cut:]
|
|||
|
|
|
|||
|
|
def write_file_list(file_list, output_file, base_dir):
    """Write *output_file* with one path per line: base_dir joined to each name."""
    lines = [os.path.join(base_dir, name) + '\n' for name in file_list]
    with open(output_file, 'w') as out:
        out.writelines(lines)
|
|||
|
|
|
|||
|
|
def copy_files_to_folders(files, src_dir, dest_dir):
    """Copy each image in *files* and its matching .txt label into *dest_dir*.

    The destination directory is created if it does not already exist; both
    the image and its label are expected to be present in *src_dir*.
    """
    os.makedirs(dest_dir, exist_ok=True)
    for image_name in files:
        label_name = image_name.replace('.jpg', '.txt')
        # Image first, then its companion label file.
        copyfile(os.path.join(src_dir, image_name),
                 os.path.join(dest_dir, image_name))
        copyfile(os.path.join(src_dir, label_name),
                 os.path.join(dest_dir, label_name))
|
|||
|
|
|
|||
|
|
def update_yaml(yaml_file, train_txt, val_txt, nc, names):
    """Rewrite the dataset YAML in place with new train/val lists and classes.

    :param yaml_file: path of the YAML config to update
    :param train_txt: path of the training list file
    :param val_txt: path of the validation list file
    :param nc: number of classes
    :param names: list of class names
    """
    with open(yaml_file, 'r') as src:
        config = yaml.safe_load(src)

    config.update({
        'train': train_txt,
        'val': val_txt,
        'nc': nc,
        'names': names,
    })

    with open(yaml_file, 'w') as dst:
        yaml.dump(config, dst, default_flow_style=False)
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_model_name(code):
    """Resolve a model code to its name via ``model_dict``.

    Purely numeric codes are zero-padded to three digits first; unknown
    codes yield an "Unknown code: ..." string instead of raising.
    """
    key = str(code)
    if key.isdigit():
        # Normalise e.g. 26 / "26" -> "026" to match the dict keys.
        key = f"{int(key):03d}"
    return model_dict.get(key, f"Unknown code: {code}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_request_id(length=30):
    """Generate a pseudo-random request ID: 'bb' followed by *length* letters.

    Uses ``random.choices`` (sampling WITH replacement) instead of the
    previous ``random.sample``, which could never repeat a letter (shrinking
    the ID space) and raised ValueError for any length > 52.

    NOTE(security): ``random`` is not cryptographically secure; switch to the
    ``secrets`` module if these IDs must be unguessable.

    :param length: number of random letters after the 'bb' prefix
    :return: request ID string of length ``length + 2``
    """
    return 'bb' + ''.join(random.choices(string.ascii_letters, k=length))
|
|||
|
|
|
|||
|
|
def transform_message_format(input_data):
    """Transform an incoming HTTP training request into the Kafka message format.

    Besides building the outgoing dict, this function has heavy side effects:
    it downloads the training dataset from MinIO, splits it into train/val
    sets on disk, writes list files, and rewrites the model's YAML config.

    :param input_data: parsed JSON body of the HTTP request
    :return: dict in the format expected by validate_training_request
    """
    transformed = {
        "Scene": input_data.get("ModelScenetype", "train").capitalize(),
        "Command": "start",  # Default command
        "Request_ID": input_data.get("request_id", generate_request_id()),
        "Version": input_data.get("Version", "v1.0"),
        "Name": "channel2",  # Default name if not provided
        "Model": input_data.get("code", input_data.get("Code", "026")),
    }

    # Handle TrainParameter construction: raw hyper-parameters from the request
    # (each may be None when absent — no defaults are applied here).
    params = input_data.get("parameters", {})
    batch_size = params.get("batch_size")
    img_size = params.get("img_size")
    epochs = params.get("epochs")

    # Resolve the model code ("code" preferred over "Code") to a model name.
    model_name = get_model_name(input_data.get("code", input_data.get("Code", "026")))

    RAW_YAML_PATH = "/home/th/jcq/AI_AutoPlat/AI_web_dsj/config/yaml"

    yaml_path = os.path.join(RAW_YAML_PATH,f"{model_name}.yaml")

    print("####line137",yaml_path)

    # minio_url = "http://172.18.0.129:8084/th-dsp/027123"
    # Fetch the dataset URL from the request.
    minio_url = input_data.get("TrainImageDir")
    if not minio_url:
        # NOTE(review): when TrainImageDir is missing, `success`/`result`
        # below are never assigned and the function will raise NameError —
        # TODO confirm intended behaviour.
        print("Error: TrainImageDir not specified in input_data")
    else:
        # Local base directory where datasets are stored.
        local_base_dir = "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets"

        # Perform the download.
        success, result = download_minio_directory(minio_url, local_base_dir)

        if success:
            print(f"Successfully downloaded to: {result}")
        else:
            print(f"Download failed: {result}")

    if success:
        print("line277",f"Files downloaded successfully to: {result}")
    else:
        print(f"Download failed: {result}")

    # On success `result` is the local dataset directory; on failure it is an
    # error string — NOTE(review): the failure case is not short-circuited and
    # the code below would treat the error text as a path. TODO confirm.
    TrainImageDir = result

    # Create the train/ and val/ folders inside the dataset directory.
    train_dir = os.path.join(TrainImageDir, "train")
    val_dir = os.path.join(TrainImageDir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Split the dataset into train and validation file lists.
    train_files, val_files = split_dataset(TrainImageDir)

    # Copy the files into the train/ and val/ folders.
    copy_files_to_folders(train_files, TrainImageDir, train_dir)
    copy_files_to_folders(val_files, TrainImageDir, val_dir)

    # Write the train/val list files.
    train_txt = os.path.join(TrainImageDir, "train.txt")
    val_txt = os.path.join(TrainImageDir, "val.txt")

    write_file_list(train_files, train_txt, train_dir)
    write_file_list(val_files, val_txt, val_dir)

    # Update the YAML config.
    # NOTE(review): class count and names are hard-coded here regardless of
    # the selected model — TODO confirm this is intentional.
    nc = 5
    names = ['flag', 'buoy', 'shipname', 'ship', 'uncover']
    update_yaml(yaml_path, train_txt, val_txt, nc, names)

    print(f"数据集划分完成,训练集: {len(train_files)} 张,验证集: {len(val_files)} 张")
    print(f"训练集列表已写入: {train_txt}")
    print(f"验证集列表已写入: {val_txt}")
    print(f"YAML文件已更新: {yaml_path}")

    # Initial weights for the resolved model.
    RAW_WEIGHTS_PATH = "/home/th/jcq/AIlib/weights_init"

    model_path = os.path.join(RAW_WEIGHTS_PATH,model_name,"yolov5.pt")
    print("####line98",model_path)

    # Kafka expects TrainParameter as the string repr of this list.
    transformed["TrainParameter"] = str([
        batch_size,
        img_size,
        epochs,
        yaml_path,
        model_path
    ])

    return transformed
|
|||
|
|
|
|||
|
|
def validate_training_request(data):
    """Validate the transformed training message.

    :param data: message dict to validate
    :return: (True, "") when well-formed, otherwise (False, reason)
    """
    schema = (
        ("Scene", str),
        ("Command", str),
        ("Version", str),
        ("Name", str),
        ("Model", str),
        ("TrainParameter", str),
    )

    # Presence and type checks, in declaration order.
    for field_name, expected_type in schema:
        if field_name not in data:
            return False, f"Missing required field: {field_name}"
        if not isinstance(data[field_name], expected_type):
            return False, f"Field {field_name} must be {expected_type.__name__}"

    # Value checks on the constrained fields.
    if data["Scene"] != "Train":
        return False, "Invalid Scene value"
    if data["Command"] not in ("start", "stop"):
        return False, "Invalid Command value"

    return True, ""
|
|||
|
|
|
|||
|
|
def send_to_kafka(topic, message):
    """Publish *message* to the given Kafka *topic*, blocking until delivered.

    :return: (True, "") on success, (False, error_text) on failure
    """
    producer = None
    try:
        producer = KafkaProducer(**KAFKA_CONFIG)
        # Block (up to 10 s) until the broker confirms delivery.
        delivery = producer.send(topic, value=message).get(timeout=10)
        logger.info(f"Message sent to partition {delivery.partition}, offset {delivery.offset}")
        return True, ""
    except Exception as e:
        logger.error(f"Failed to send message: {e}")
        return False, str(e)
    finally:
        # Always release the producer, even on failure.
        if producer:
            producer.flush()
            producer.close()
|
|||
|
|
|
|||
|
|
@app.route('/api/train', methods=['POST'])
def handle_training_request():
    """HTTP endpoint for training requests; returns immediately (202)."""
    try:
        # Parse and sanity-check the request body.
        payload = request.get_json()
        if not payload or not isinstance(payload, dict):
            return jsonify({"status": "error", "message": "Invalid JSON data"}), 400

        request_id = generate_request_id()

        # Run the heavy work in a daemon process so the HTTP response is not
        # blocked; daemon=True makes the worker die with the server.
        worker = Process(
            target=process_training_task,
            args=(payload, request_id),
            daemon=True
        )
        worker.start()

        accepted = {
            "status": "success",
            "message": "Training task started",
            "request_id": request_id,
            "process_id": worker.pid
        }
        return jsonify(accepted), 202  # 202 Accepted

    except Exception as e:
        logger.error(f"Request handling error: {str(e)}")
        return jsonify({
            "status": "error",
            "message": "Internal server error"
        }), 500
|
|||
|
|
|
|||
|
|
def process_training_task(data, request_id):
    """Background worker: transform, validate, and publish one training task.

    Runs in its own process; failures are logged rather than propagated.
    """
    try:
        logger.info(f"[{request_id}] Starting background task (PID: {os.getpid()})")
        started = time.time()

        # 1. Convert the incoming payload to the Kafka message format.
        message = transform_message_format(data)
        message["Request_ID"] = request_id

        # 2. Validate the transformed message.
        ok, reason = validate_training_request(message)
        if not ok:
            raise ValueError(reason)

        # 3. Publish to Kafka.
        sent, error = send_to_kafka(TOPIC_TRAIN, message)
        if not sent:
            raise RuntimeError(error)

        logger.info(f"[{request_id}] Task completed in {time.time() - started:.2f}s")

    except Exception as e:
        logger.error(f"[{request_id}] Task failed: {str(e)}")
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Bind address, port and debug mode can be overridden via environment
    # variables; the defaults preserve the previous hard-coded values.
    # NOTE(security): debug=True enables the Werkzeug interactive debugger,
    # which allows arbitrary code execution — set FLASK_DEBUG=0 in production.
    app.run(
        host=os.getenv('TRAIN_PROXY_HOST', '192.168.10.11'),
        port=int(os.getenv('TRAIN_PROXY_PORT', '6688')),
        debug=os.getenv('FLASK_DEBUG', '1') == '1'
    )
|