tuoheng_AIPlatform/yolov5-th/yolov5_kafka.py

463 lines
13 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
HTTP to Kafka proxy service for training requests with message transformation
Features:
- Receives HTTP POST requests with training parameters
- Transforms message format from incoming to Kafka-compatible format
- Validates the request format
- Forwards valid requests to Kafka
- Returns appropriate HTTP responses
"""
from flask import Flask, request, jsonify
import os
from kafka import KafkaProducer
import json
import random
import string
import logging
from multiprocessing import Process
import time
import os
import random
import yaml
from shutil import copyfile
from minio import Minio
from minio.error import S3Error
from minio import Minio
from minio.error import S3Error
import os
import urllib.parse
from minio import Minio
from minio.error import S3Error
import os
import urllib.parse
# Flask application instance; routes are registered against it below.
app = Flask(__name__)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Kafka configuration: broker address and JSON-encoding serializer for
# message values (every produced value is UTF-8 JSON).
KAFKA_CONFIG = {
    'bootstrap_servers': '172.18.0.126:9094',
    'value_serializer': lambda v: json.dumps(v).encode('utf-8')
}

# Kafka topic that validated training tasks are published to.
TOPIC_TRAIN = 'training-tasks'
# Lookup table mapping zero-padded 3-digit model codes to model names.
# Used by get_model_name(); the names also select the dataset YAML and the
# initial-weights directory on disk.
model_dict = {
    "001": "AnglerSwimmer",
    "002": "channel2",
    "003": "channelEmergency",
    "004": "cityMangement",
    "005": "cityMangement2",
    "006": "cityMangement3",
    "007": "cityRoad",
    "008": "conf",
    "009": "countryRoad",
    "010": "crackMeasurement",
    "011": "crowdCounting",
    "012": "drowning",
    "013": "forest",
    "014": "forest2",
    "015": "forestCrowd",
    "016": "highWay2",
    "017": "illParking",
    "018": "noParking",
    "019": "ocr2",
    "020": "ocr_en",
    "021": "pedestrian",
    "022": "pothole",
    "023": "river",
    "024": "river2",
    "025": "riverT",
    "026": "road",
    "027": "ship",
    "028": "ship2",
    "029": "smogfire",
    "030": "trafficAccident",
    "031": "vehicle"
}
def validate_minio_url(url):
    """Validate that *url* looks like a MinIO object URL.

    A valid URL has an http/https scheme, a non-empty host, and a path of
    the form /<bucket>/<prefix...> (bucket plus at least one segment).

    :param url: full URL string, e.g. "http://172.18.0.129:8084/th-dsp/027123"
    :return: the urllib.parse.ParseResult for *url*
    :raises ValueError: if the URL does not match the expected shape

    Fix: the original raised ValueError inside its own try block and then
    caught it with a broad ``except Exception``, re-wrapping its own message.
    Validate directly and raise exactly once instead.
    """
    try:
        parsed = urllib.parse.urlparse(url)
    except Exception as e:  # urlparse raises only on pathological input
        raise ValueError(f"URL validation failed: {str(e)}") from e
    scheme_ok = parsed.scheme in ('http', 'https')
    # path "/bucket/prefix" splits into at least ['', 'bucket', 'prefix']
    if not (scheme_ok and parsed.netloc and len(parsed.path.split('/')) >= 3):
        raise ValueError("URL validation failed: Invalid MinIO URL format")
    return parsed
def download_minio_directory(minio_url, local_base_dir):
    """
    Download a MinIO directory (object prefix) to the local filesystem.

    :param minio_url: full MinIO path (e.g. "http://172.18.0.129:8084/th-dsp/027123")
    :param local_base_dir: local storage base directory
                           (e.g. "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets")
    :return: (success, message) tuple — on success the message is the local
             directory path, on failure it is an error description
    """
    try:
        # Validate and parse the URL into endpoint / bucket / prefix parts.
        parsed = validate_minio_url(minio_url)
        endpoint = parsed.netloc
        bucket_name = parsed.path.split('/')[1]
        prefix = '/'.join(parsed.path.split('/')[2:])
        # Credentials come from the environment (recommended); the in-code
        # defaults are for testing only.
        # NOTE(review): shipping real-looking credentials as source defaults
        # is a security risk — rotate these keys and require the env vars.
        access_key = os.getenv("MINIO_ACCESS_KEY", "IKf3A0ZSXsR1m0oalMjV")  # test-only default
        secret_key = os.getenv("MINIO_SECRET_KEY", "yoC6qRo2hlyZu8Pdbt6eh9TVaTV4gD7KRudromrk")
        # Create the local target directory, named after the last prefix segment.
        dir_name = os.path.basename(prefix) or "minio_download"
        local_dir = os.path.join(local_base_dir, dir_name)
        os.makedirs(local_dir, exist_ok=True)
        # Initialise the MinIO client; TLS only when the URL scheme is https.
        client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=parsed.scheme == 'https'
        )
        print(f"Starting download from {minio_url} to {local_dir}")
        # Recursively download every object under the prefix.
        objects = client.list_objects(bucket_name, prefix=prefix, recursive=True)
        downloaded_files = 0
        for obj in objects:
            # Build the local path, preserving the relative directory layout.
            relative_path = os.path.relpath(obj.object_name, prefix)
            local_path = os.path.join(local_dir, relative_path)
            # Ensure the parent directory exists.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            # Download the object.
            client.fget_object(bucket_name, obj.object_name, local_path)
            downloaded_files += 1
            print(f"Downloaded [{downloaded_files}]: {obj.object_name} -> {local_path}")
        print(f"Download completed! Total files: {downloaded_files}")
        return True, local_dir
    except S3Error as e:
        # S3Error can only come from client calls, so bucket_name is bound here.
        error_msg = f"MinIO Error: {str(e)}"
        if "NoSuchBucket" in str(e):
            error_msg = f"Bucket not found: {bucket_name}"
        return False, error_msg
    except ValueError as e:
        # Raised by validate_minio_url for malformed URLs.
        return False, f"Invalid URL: {str(e)}"
    except Exception as e:
        return False, f"Unexpected error: {str(e)}"
def split_dataset(data_dir, train_ratio=0.8):
    """Randomly partition labelled images in *data_dir* into train/val lists.

    Only .jpg files with a sibling .txt label file are considered.

    :param data_dir: flat directory containing images and label files
    :param train_ratio: fraction of valid images assigned to the train split
    :return: (train_files, val_files) — lists of bare file names
    """
    # Keep only images whose matching label file exists alongside them.
    candidates = [
        name for name in os.listdir(data_dir)
        if name.endswith('.jpg')
        and os.path.exists(os.path.join(data_dir, name.replace('.jpg', '.txt')))
    ]
    # Shuffle in place, then cut at the train/val boundary.
    random.shuffle(candidates)
    cut = int(len(candidates) * train_ratio)
    return candidates[:cut], candidates[cut:]
def write_file_list(file_list, output_file, base_dir):
    """Write one path per line to *output_file*: base_dir joined with each name."""
    lines = [os.path.join(base_dir, name) + '\n' for name in file_list]
    with open(output_file, 'w') as out:
        out.writelines(lines)
def copy_files_to_folders(files, src_dir, dest_dir):
    """Copy each image and its matching .txt label from src_dir into dest_dir.

    dest_dir is created if it does not already exist.
    """
    os.makedirs(dest_dir, exist_ok=True)
    for image_name in files:
        label_name = image_name.replace('.jpg', '.txt')
        # Image first, then its label — both must exist in src_dir.
        for name in (image_name, label_name):
            copyfile(os.path.join(src_dir, name), os.path.join(dest_dir, name))
def update_yaml(yaml_file, train_txt, val_txt, nc, names):
    """Rewrite the dataset YAML in place with new split lists and classes.

    :param yaml_file: path to an existing dataset YAML file
    :param train_txt: path of the train-list file to record under 'train'
    :param val_txt: path of the val-list file to record under 'val'
    :param nc: number of classes
    :param names: list of class names
    """
    with open(yaml_file, 'r') as src:
        config = yaml.safe_load(src)
    config.update({'train': train_txt, 'val': val_txt, 'nc': nc, 'names': names})
    with open(yaml_file, 'w') as dst:
        yaml.dump(config, dst, default_flow_style=False)
def get_model_name(code):
    """Map a model code (e.g. 27, "27", "027") to its model name.

    Numeric codes are normalised to a zero-padded 3-digit string before the
    lookup; unknown codes yield an "Unknown code: ..." marker string.
    """
    text = str(code)
    key = f"{int(text):03d}" if text.isdigit() else code
    return model_dict.get(key, f"Unknown code: {code}")
def generate_request_id(length=30):
    """Generate a pseudo-unique request ID: 'bb' followed by *length* letters.

    :param length: number of random ASCII letters after the 'bb' prefix
    :return: string of total length ``length + 2``

    Fix: the original used random.sample, which draws without replacement
    from the 52 ASCII letters and raises ValueError for length > 52;
    random.choices draws with replacement and works for any length.
    """
    return 'bb' + ''.join(random.choices(string.ascii_letters, k=length))
def transform_message_format(input_data):
    """Transform an incoming HTTP training request into the Kafka message schema.

    Side effects: downloads the training dataset from MinIO, splits it into
    train/ and val/ folders, writes train.txt / val.txt list files, and
    updates the model's dataset YAML on disk.

    :param input_data: parsed JSON body of the HTTP request
    :return: dict matching the "training-tasks" Kafka message schema
    :raises ValueError: if TrainImageDir is missing from the request
    :raises RuntimeError: if the MinIO download fails

    Fixes: the original only printed when TrainImageDir was missing and then
    crashed with NameError on the unbound download result; it also logged
    the download outcome twice. Fail fast and log once instead.
    """
    transformed = {
        "Scene": input_data.get("ModelScenetype", "train").capitalize(),
        "Command": "start",  # this endpoint only ever produces 'start'
        "Request_ID": input_data.get("request_id", generate_request_id()),
        "Version": input_data.get("Version", "v1.0"),
        "Name": "channel2",  # Default name if not provided
        "Model": input_data.get("code", input_data.get("Code", "026")),
    }

    # Training hyper-parameters supplied by the caller (may be None).
    params = input_data.get("parameters", {})
    batch_size = params.get("batch_size")
    img_size = params.get("img_size")
    epochs = params.get("epochs")

    # Resolve the dataset YAML path from the model code.
    model_name = get_model_name(input_data.get("code", input_data.get("Code", "026")))
    RAW_YAML_PATH = "/home/th/jcq/AI_AutoPlat/AI_web_dsj/config/yaml"
    yaml_path = os.path.join(RAW_YAML_PATH, f"{model_name}.yaml")
    print("####line137", yaml_path)

    # Fetch the dataset from MinIO; a missing URL is a caller error.
    minio_url = input_data.get("TrainImageDir")
    if not minio_url:
        raise ValueError("TrainImageDir not specified in input_data")
    local_base_dir = "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets"
    success, result = download_minio_directory(minio_url, local_base_dir)
    if not success:
        print(f"Download failed: {result}")
        raise RuntimeError(f"Dataset download failed: {result}")
    print(f"Successfully downloaded to: {result}")

    TrainImageDir = result
    # Create the train/ and val/ folders inside the downloaded dataset.
    train_dir = os.path.join(TrainImageDir, "train")
    val_dir = os.path.join(TrainImageDir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Split the flat download and materialise both folder layouts.
    train_files, val_files = split_dataset(TrainImageDir)
    copy_files_to_folders(train_files, TrainImageDir, train_dir)
    copy_files_to_folders(val_files, TrainImageDir, val_dir)

    # Write the train/val list files referenced by the YAML.
    train_txt = os.path.join(TrainImageDir, "train.txt")
    val_txt = os.path.join(TrainImageDir, "val.txt")
    write_file_list(train_files, train_txt, train_dir)
    write_file_list(val_files, val_txt, val_dir)

    # NOTE(review): class count/names are hard-coded for one model family;
    # they should come from the request or a per-model config — confirm.
    nc = 5
    names = ['flag', 'buoy', 'shipname', 'ship', 'uncover']
    update_yaml(yaml_path, train_txt, val_txt, nc, names)
    print(f"数据集划分完成,训练集: {len(train_files)} 张,验证集: {len(val_files)}")
    print(f"训练集列表已写入: {train_txt}")
    print(f"验证集列表已写入: {val_txt}")
    print(f"YAML文件已更新: {yaml_path}")

    # Initial weights path for this model.
    RAW_WEIGHTS_PATH = "/home/th/jcq/AIlib/weights_init"
    model_path = os.path.join(RAW_WEIGHTS_PATH, model_name, "yolov5.pt")
    print("####line98", model_path)

    # TrainParameter is the stringified list the training consumer expects.
    transformed["TrainParameter"] = str([
        batch_size,
        img_size,
        epochs,
        yaml_path,
        model_path
    ])
    return transformed
def validate_training_request(data):
    """Validate the transformed training message before it is sent to Kafka.

    Every required field must be present and a string; Scene must be
    "Train" and Command one of "start"/"stop".

    :param data: candidate message dict
    :return: (is_valid, error_message) — error_message is "" when valid
    """
    required = ("Scene", "Command", "Version", "Name", "Model", "TrainParameter")
    for field in required:
        if field not in data:
            return False, f"Missing required field: {field}"
        if not isinstance(data[field], str):
            return False, f"Field {field} must be str"
    if data["Scene"] != "Train":
        return False, "Invalid Scene value"
    if data["Command"] in ("start", "stop"):
        return True, ""
    return False, "Invalid Command value"
def send_to_kafka(topic, message):
    """Publish *message* to *topic*, blocking until delivery or a 10s timeout.

    A short-lived producer is created per call and is always flushed and
    closed, even when sending fails.

    :param topic: Kafka topic name
    :param message: JSON-serialisable payload
    :return: (True, "") on success, (False, error_string) on failure
    """
    producer = None
    try:
        producer = KafkaProducer(**KAFKA_CONFIG)
        # .get() blocks until the broker acks (or the timeout expires).
        delivery = producer.send(topic, value=message).get(timeout=10)
        logger.info(f"Message sent to partition {delivery.partition}, offset {delivery.offset}")
        return True, ""
    except Exception as exc:
        logger.error(f"Failed to send message: {exc}")
        return False, str(exc)
    finally:
        if producer:
            producer.flush()
            producer.close()
@app.route('/api/train', methods=['POST'])
def handle_training_request():
    """HTTP endpoint for training requests (non-blocking).

    Spawns a daemon worker process for the heavy lifting and immediately
    returns 202 with the generated request ID and worker PID.
    """
    try:
        payload = request.get_json()
        # Reject anything that is not a non-empty JSON object.
        if not isinstance(payload, dict) or not payload:
            return jsonify({"status": "error", "message": "Invalid JSON data"}), 400

        request_id = generate_request_id()

        # daemon=True ties the worker's lifetime to the server process.
        worker = Process(
            target=process_training_task,
            args=(payload, request_id),
            daemon=True
        )
        worker.start()

        # Respond immediately; the task continues in the background.
        return jsonify({
            "status": "success",
            "message": "Training task started",
            "request_id": request_id,
            "process_id": worker.pid
        }), 202  # Accepted
    except Exception as exc:
        logger.error(f"Request handling error: {str(exc)}")
        return jsonify({
            "status": "error",
            "message": "Internal server error"
        }), 500
def process_training_task(data, request_id):
    """Background worker: transform, validate, and publish one training task.

    Runs in its own process; all failures are caught and logged rather than
    propagated, since there is no caller left to handle them.
    """
    try:
        logger.info(f"[{request_id}] Starting background task (PID: {os.getpid()})")
        started = time.time()

        # 1. Transform the HTTP payload into the Kafka message schema.
        message = transform_message_format(data)
        message["Request_ID"] = request_id

        # 2. Validate before publishing.
        ok, reason = validate_training_request(message)
        if not ok:
            raise ValueError(reason)

        # 3. Publish to Kafka.
        sent, error = send_to_kafka(TOPIC_TRAIN, message)
        if not sent:
            raise RuntimeError(error)

        logger.info(f"[{request_id}] Task completed in {time.time() - started:.2f}s")
    except Exception as exc:
        logger.error(f"[{request_id}] Task failed: {str(exc)}")
if __name__ == '__main__':
    # Host/port can now be overridden via environment variables; the
    # defaults preserve the original hard-coded binding.
    host = os.getenv("APP_HOST", "192.168.10.11")
    port = int(os.getenv("APP_PORT", "6688"))
    # NOTE(review): debug=True enables the Werkzeug debugger/reloader and
    # must not be used in production; kept here for backward compatibility.
    app.run(host=host, port=port, debug=True)