463 lines
13 KiB
Python
Executable File
463 lines
13 KiB
Python
Executable File
# -*- coding: utf-8 -*-
|
||
"""
|
||
HTTP to Kafka proxy service for training requests with message transformation
|
||
|
||
Features:
|
||
- Receives HTTP POST requests with training parameters
|
||
- Transforms message format from incoming to Kafka-compatible format
|
||
- Validates the request format
|
||
- Forwards valid requests to Kafka
|
||
- Returns appropriate HTTP responses
|
||
"""
|
||
|
||
from flask import Flask, request, jsonify
|
||
import os
|
||
from kafka import KafkaProducer
|
||
import json
|
||
import random
|
||
import string
|
||
import logging
|
||
from multiprocessing import Process
|
||
import time
|
||
|
||
import os
|
||
import random
|
||
import yaml
|
||
from shutil import copyfile
|
||
|
||
from minio import Minio
|
||
from minio.error import S3Error
|
||
from minio import Minio
|
||
from minio.error import S3Error
|
||
import os
|
||
import urllib.parse
|
||
from minio import Minio
|
||
from minio.error import S3Error
|
||
import os
|
||
import urllib.parse
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
# Flask application instance for the HTTP front end.
app = Flask(__name__)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Kafka configuration
# NOTE(review): broker address is hard-coded; consider moving to env/config.
KAFKA_CONFIG = {
    'bootstrap_servers': '172.18.0.126:9094',
    # Messages are sent as UTF-8 encoded JSON.
    'value_serializer': lambda v: json.dumps(v).encode('utf-8')
}

# Kafka topics
TOPIC_TRAIN = 'training-tasks'

# Mapping from 3-digit model code (as received over HTTP) to model name.
# Names are used verbatim to build YAML config and weight paths — do not
# "fix" apparent spelling (e.g. "cityMangement"); they match on-disk layout.
model_dict = {
    "001": "AnglerSwimmer",
    "002": "channel2",
    "003": "channelEmergency",
    "004": "cityMangement",
    "005": "cityMangement2",
    "006": "cityMangement3",
    "007": "cityRoad",
    "008": "conf",
    "009": "countryRoad",
    "010": "crackMeasurement",
    "011": "crowdCounting",
    "012": "drowning",
    "013": "forest",
    "014": "forest2",
    "015": "forestCrowd",
    "016": "highWay2",
    "017": "illParking",
    "018": "noParking",
    "019": "ocr2",
    "020": "ocr_en",
    "021": "pedestrian",
    "022": "pothole",
    "023": "river",
    "024": "river2",
    "025": "riverT",
    "026": "road",
    "027": "ship",
    "028": "ship2",
    "029": "smogfire",
    "030": "trafficAccident",
    "031": "vehicle"
}
|
||
|
||
|
||
def validate_minio_url(url):
    """Validate that *url* is a usable MinIO object path.

    A valid URL has an http/https scheme, a host, and a path containing at
    least a bucket plus an object prefix (e.g. ``http://host:9000/bucket/key``).

    :param url: full MinIO URL string.
    :return: the ``urllib.parse.ParseResult`` for *url*.
    :raises ValueError: if the URL does not match the expected shape.
    """
    try:
        parts = urllib.parse.urlparse(url)
        scheme_ok = parts.scheme in ('http', 'https')
        has_host = bool(parts.netloc)
        # path like "/bucket/prefix..." splits into >= 3 components.
        deep_enough = len(parts.path.split('/')) >= 3
        if not (scheme_ok and has_host and deep_enough):
            raise ValueError("Invalid MinIO URL format")
        return parts
    except Exception as e:
        # Re-wrap everything (including our own ValueError) with context.
        raise ValueError(f"URL validation failed: {str(e)}")
|
||
|
||
def download_minio_directory(minio_url, local_base_dir):
    """
    Recursively download a MinIO "directory" (object prefix) to local disk.

    :param minio_url: full MinIO path (e.g. "http://172.18.0.129:8084/th-dsp/027123")
    :param local_base_dir: local base directory (e.g. "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets")
    :return: (success, message) tuple — on success the message is the local
             directory path the files were downloaded into; on failure it is
             a human-readable error string.
    """
    try:
        # Validate and split the URL into endpoint / bucket / object prefix.
        parsed = validate_minio_url(minio_url)
        endpoint = parsed.netloc
        bucket_name = parsed.path.split('/')[1]
        prefix = '/'.join(parsed.path.split('/')[2:])

        # Credentials come from environment variables (recommended).
        # SECURITY(review): the hard-coded fallback credentials below should
        # never ship to production — rotate them and require the env vars.
        access_key = os.getenv("MINIO_ACCESS_KEY", "IKf3A0ZSXsR1m0oalMjV")  # default for testing only
        secret_key = os.getenv("MINIO_SECRET_KEY", "yoC6qRo2hlyZu8Pdbt6eh9TVaTV4gD7KRudromrk")

        # Create the local target directory, named after the last prefix part.
        dir_name = os.path.basename(prefix) or "minio_download"
        local_dir = os.path.join(local_base_dir, dir_name)
        os.makedirs(local_dir, exist_ok=True)

        # Initialise the MinIO client; TLS only when the URL is https.
        client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=parsed.scheme == 'https'
        )

        print(f"Starting download from {minio_url} to {local_dir}")

        # Recursively enumerate and download every object under the prefix.
        objects = client.list_objects(bucket_name, prefix=prefix, recursive=True)
        downloaded_files = 0

        for obj in objects:
            # Build the local path, preserving the relative directory layout.
            relative_path = os.path.relpath(obj.object_name, prefix)
            local_path = os.path.join(local_dir, relative_path)

            # Ensure the parent directory exists before writing.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)

            # Fetch one object to disk.
            client.fget_object(bucket_name, obj.object_name, local_path)
            downloaded_files += 1

            print(f"Downloaded [{downloaded_files}]: {obj.object_name} -> {local_path}")

        print(f"Download completed! Total files: {downloaded_files}")
        return True, local_dir

    except S3Error as e:
        # bucket_name is always bound here: S3Error can only be raised by the
        # client calls, which happen after the assignment above.
        error_msg = f"MinIO Error: {str(e)}"
        if "NoSuchBucket" in str(e):
            error_msg = f"Bucket not found: {bucket_name}"
        return False, error_msg
    except ValueError as e:
        # Raised by validate_minio_url for malformed URLs.
        return False, f"Invalid URL: {str(e)}"
    except Exception as e:
        return False, f"Unexpected error: {str(e)}"
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
def split_dataset(data_dir, train_ratio=0.8):
    """Split labelled images in *data_dir* into train and validation lists.

    Only ``.jpg`` files with a matching ``.txt`` label file are considered.
    The valid files are shuffled in place, then cut at *train_ratio*.

    :param data_dir: directory containing image/label pairs.
    :param train_ratio: fraction of valid files assigned to the train split.
    :return: (train_files, val_files) — lists of bare file names.
    """
    # Keep only images whose label file exists alongside them.
    images = [
        name for name in os.listdir(data_dir)
        if name.endswith('.jpg')
        and os.path.exists(os.path.join(data_dir, name.replace('.jpg', '.txt')))
    ]

    # Randomise order so the split is not biased by listing order.
    random.shuffle(images)

    cut = int(len(images) * train_ratio)
    return images[:cut], images[cut:]
|
||
|
||
def write_file_list(file_list, output_file, base_dir):
    """Write one path per line to *output_file*.

    Each entry of *file_list* is joined onto *base_dir* so the resulting
    file contains full paths (YOLO-style train/val list format).
    """
    with open(output_file, 'w') as out:
        out.writelines(os.path.join(base_dir, name) + '\n' for name in file_list)
|
||
|
||
def copy_files_to_folders(files, src_dir, dest_dir):
    """Copy images and their label files into the target folder.

    For every ``.jpg`` name in *files*, both the image and the matching
    ``.txt`` label are copied from *src_dir* to *dest_dir* (created if
    missing).
    """
    os.makedirs(dest_dir, exist_ok=True)
    for image_name in files:
        label_name = image_name.replace('.jpg', '.txt')
        # Copy the image first, then its label, mirroring the source layout.
        for name in (image_name, label_name):
            copyfile(os.path.join(src_dir, name), os.path.join(dest_dir, name))
|
||
|
||
def update_yaml(yaml_file, train_txt, val_txt, nc, names):
    """Rewrite a dataset YAML in place with new split paths and classes.

    :param yaml_file: path of the YAML config to update.
    :param train_txt: path of the train file-list.
    :param val_txt: path of the validation file-list.
    :param nc: number of classes.
    :param names: list of class names.
    """
    with open(yaml_file) as fh:
        config = yaml.safe_load(fh)

    config.update(train=train_txt, val=val_txt, nc=nc, names=names)

    with open(yaml_file, 'w') as fh:
        yaml.dump(config, fh, default_flow_style=False)
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
def get_model_name(code):
    """Look up the model name for a numeric code.

    Accepts ints or digit strings and normalises them to the 3-digit,
    zero-padded keys used by ``model_dict`` (e.g. 27 -> "027").
    Returns "Unknown code: <code>" for unrecognised codes.
    """
    text = str(code)
    key = f"{int(text):03d}" if text.isdigit() else code
    return model_dict.get(key, f"Unknown code: {code}")
|
||
|
||
|
||
def generate_request_id(length=30):
    """Generate a unique-ish request ID: 'bb' plus *length* random letters.

    Uses ``random.choices`` (sampling WITH replacement) instead of
    ``random.sample``: sample draws without replacement from the 52 ASCII
    letters, so it raised ValueError for any length > 52 and could never
    repeat a letter within one ID. choices supports any length.

    :param length: number of random letters after the 'bb' prefix.
    :return: string of ``length + 2`` characters starting with 'bb'.
    """
    return 'bb' + ''.join(random.choices(string.ascii_letters, k=length))
|
||
|
||
def transform_message_format(input_data):
    """Transform an incoming HTTP training request into the Kafka format.

    Besides building the message dict, this function performs the dataset
    preparation side effects: downloads the training images from MinIO,
    splits them into train/val folders, writes the file lists and updates
    the model's YAML config on local disk.

    :param input_data: dict parsed from the HTTP request JSON. Expected keys
        (all optional except TrainImageDir): ModelScenetype, request_id,
        Version, code/Code, parameters{batch_size, img_size, epochs},
        TrainImageDir (MinIO URL of the dataset).
    :return: dict matching the schema checked by validate_training_request.
    :raises ValueError: if TrainImageDir is missing.
    :raises RuntimeError: if the MinIO download fails.
    """
    model_code = input_data.get("code", input_data.get("Code", "026"))

    transformed = {
        "Scene": input_data.get("ModelScenetype", "train").capitalize(),
        "Command": "start",  # Default command
        "Request_ID": input_data.get("request_id", generate_request_id()),
        "Version": input_data.get("Version", "v1.0"),
        "Name": "channel2",  # Default name if not provided
        "Model": model_code,
    }

    # Training hyper-parameters; None when the caller omitted them.
    params = input_data.get("parameters", {})
    batch_size = params.get("batch_size")
    img_size = params.get("img_size")
    epochs = params.get("epochs")

    model_name = get_model_name(model_code)

    RAW_YAML_PATH = "/home/th/jcq/AI_AutoPlat/AI_web_dsj/config/yaml"
    yaml_path = os.path.join(RAW_YAML_PATH, f"{model_name}.yaml")
    print("####line137", yaml_path)

    # --- Download the training dataset from MinIO ---
    # BUGFIX: the original only printed an error when TrainImageDir was
    # missing and then crashed with NameError on the undefined `success`;
    # fail fast with a clear exception instead (the background task wrapper
    # catches and logs Exception).
    minio_url = input_data.get("TrainImageDir")
    if not minio_url:
        print("Error: TrainImageDir not specified in input_data")
        raise ValueError("TrainImageDir not specified in input_data")

    local_base_dir = "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets"
    success, result = download_minio_directory(minio_url, local_base_dir)
    if not success:
        print(f"Download failed: {result}")
        raise RuntimeError(f"Download failed: {result}")
    print(f"Successfully downloaded to: {result}")

    TrainImageDir = result

    # --- Split the dataset into train/ and val/ folders ---
    train_dir = os.path.join(TrainImageDir, "train")
    val_dir = os.path.join(TrainImageDir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    train_files, val_files = split_dataset(TrainImageDir)

    copy_files_to_folders(train_files, TrainImageDir, train_dir)
    copy_files_to_folders(val_files, TrainImageDir, val_dir)

    # Write YOLO-style train/val list files.
    train_txt = os.path.join(TrainImageDir, "train.txt")
    val_txt = os.path.join(TrainImageDir, "val.txt")
    write_file_list(train_files, train_txt, train_dir)
    write_file_list(val_files, val_txt, val_dir)

    # Update the model's YAML config.
    # NOTE(review): class count/names are hard-coded for the ship model;
    # other model codes will get the wrong classes — confirm intended.
    nc = 5
    names = ['flag', 'buoy', 'shipname', 'ship', 'uncover']
    update_yaml(yaml_path, train_txt, val_txt, nc, names)

    print(f"数据集划分完成,训练集: {len(train_files)} 张,验证集: {len(val_files)} 张")
    print(f"训练集列表已写入: {train_txt}")
    print(f"验证集列表已写入: {val_txt}")
    print(f"YAML文件已更新: {yaml_path}")

    # Initial weights for this model.
    RAW_WEIGHTS_PATH = "/home/th/jcq/AIlib/weights_init"
    model_path = os.path.join(RAW_WEIGHTS_PATH, model_name, "yolov5.pt")
    print("####line98", model_path)

    # TrainParameter is transported as the string form of a Python list.
    transformed["TrainParameter"] = str([
        batch_size,
        img_size,
        epochs,
        yaml_path,
        model_path
    ])

    return transformed
|
||
|
||
def validate_training_request(data):
    """Validate the Kafka training message structure.

    Checks presence and type of every required field, then the allowed
    values for Scene and Command.

    :param data: message dict to validate.
    :return: (is_valid, error_message) — message is "" when valid.
    """
    schema = (
        ("Scene", str),
        ("Command", str),
        ("Version", str),
        ("Name", str),
        ("Model", str),
        ("TrainParameter", str),
    )

    for key, expected in schema:
        if key not in data:
            return False, f"Missing required field: {key}"
        if not isinstance(data[key], expected):
            return False, f"Field {key} must be {expected.__name__}"

    if data["Scene"] != "Train":
        return False, "Invalid Scene value"

    if data["Command"] not in ("start", "stop"):
        return False, "Invalid Command value"

    return True, ""
|
||
|
||
def send_to_kafka(topic, message):
    """Send message to Kafka topic.

    Creates a short-lived producer per call, sends synchronously (waits up
    to 10s for the broker ack), and always flushes/closes the producer.

    NOTE(review): constructing a KafkaProducer per message is expensive;
    a module-level shared producer would perform better — confirm whether
    per-call isolation is intentional here.

    :param topic: Kafka topic name.
    :param message: JSON-serialisable payload (serialised by KAFKA_CONFIG).
    :return: (success, error_message) — error_message is "" on success.
    """
    producer = None
    try:
        producer = KafkaProducer(**KAFKA_CONFIG)
        future = producer.send(topic, value=message)

        # Wait for message to be delivered
        result = future.get(timeout=10)
        logger.info(f"Message sent to partition {result.partition}, offset {result.offset}")
        return True, ""
    except Exception as e:
        logger.error(f"Failed to send message: {e}")
        return False, str(e)
    finally:
        # Guard: producer stays None if the constructor itself raised.
        if producer:
            producer.flush()
            producer.close()
|
||
|
||
@app.route('/api/train', methods=['POST'])
def handle_training_request():
    """HTTP endpoint for training requests (non-blocking).

    Validates the JSON envelope, spawns a daemon worker process for the
    heavy lifting, and immediately answers 202 Accepted with the generated
    request ID and worker PID.
    """
    try:
        payload = request.get_json()
        # Reject anything that is not a non-empty JSON object.
        if not payload or not isinstance(payload, dict):
            return jsonify({"status": "error", "message": "Invalid JSON data"}), 400

        request_id = generate_request_id()

        # daemon=True ties the worker's lifetime to the main process.
        worker = Process(
            target=process_training_task,
            args=(payload, request_id),
            daemon=True,
        )
        worker.start()

        accepted = {
            "status": "success",
            "message": "Training task started",
            "request_id": request_id,
            "process_id": worker.pid,
        }
        return jsonify(accepted), 202  # 202 Accepted

    except Exception as e:
        logger.error(f"Request handling error: {str(e)}")
        return jsonify({
            "status": "error",
            "message": "Internal server error"
        }), 500
|
||
|
||
def process_training_task(data, request_id):
    """Run the full training pipeline in a separate worker process.

    Transform -> validate -> publish to Kafka; any failure is logged with
    the request ID (the process has no channel back to the HTTP caller).
    """
    try:
        logger.info(f"[{request_id}] Starting background task (PID: {os.getpid()})")
        started = time.time()

        # Step 1: build the Kafka message (also prepares the dataset on disk).
        message = transform_message_format(data)
        message["Request_ID"] = request_id

        # Step 2: validate before publishing.
        ok, why = validate_training_request(message)
        if not ok:
            raise ValueError(why)

        # Step 3: publish to the training topic.
        sent, error = send_to_kafka(TOPIC_TRAIN, message)
        if not sent:
            raise RuntimeError(error)

        elapsed = time.time() - started
        logger.info(f"[{request_id}] Task completed in {elapsed:.2f}s")

    except Exception as e:
        logger.error(f"[{request_id}] Task failed: {str(e)}")
|
||
|
||
if __name__ == '__main__':
    # WARNING(review): debug=True enables the Werkzeug reloader AND the
    # interactive debugger (arbitrary code execution if the port is
    # reachable) — must be disabled for production deployments.
    # Host/port are hard-coded; consider reading them from env/config.
    app.run(host='192.168.10.11', port=6688, debug=True)