tuoheng_AIPlatform/AI_web_dsj/AI_auto_train.py

1508 lines
45 KiB
Python
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from flask import Flask, request, jsonify
from flask_restful import Api, Resource
import random
import json
import string
import threading
import time
import os
import requests
from api_task import check_task_api
import datetime
import shutil
from flask import Flask, request, jsonify
from minio import Minio
from minio.error import S3Error
import os
import paramiko
from scp import SCPClient
import urllib.parse
import multiprocessing
import uuid
from flask import Flask, request, jsonify
import os
from kafka import KafkaProducer
import json
import random
import string
import logging
from multiprocessing import Process
import time
import os
import random
import yaml
from shutil import copyfile
from minio import Minio
from minio.error import S3Error
from minio import Minio
from minio.error import S3Error
import os
import urllib.parse
from minio import Minio
from minio.error import S3Error
import os
import urllib.parse
from api_task import check_task_api
app = Flask(__name__)
api = Api(app)
# Maps 3-digit model codes to model/algorithm names.
# NOTE(review): this dict is redefined twice more further down the file;
# the later copies shadow this one (all three are identical).
model_dict = {
    "001": "river",
    "002": "forest2",
    "003": "highWay2",
    "006": "vehicle",
    "007": "pedestrian",
    "008": "smogfire",
    "009": "AnglerSwimmer",
    "010": "countryRoad",
    "011": "ship2",
    "013": "channelEmergency",
    "014": "forest2",
    "015": "river2",
    "016": "cityMangement2",
    "017": "drowning",
    "018": "noParking",
    "019": "illParking",
    "020": "cityRoad",
    "023": "pothole",
    "024": "channel2",
    "025": "riverT",
    "026": "forestCrowd",
    "027": "highWay2T",
    "028": "smartSite",
    "029": "rubbish",
    "030": "firework"
}
# Default MinIO object prefix and local download directory.
PREFIX = "weights/026/v1.0/"
LOCAL_DIR = "downloaded_weights"  # local storage directory
# Remote server configuration
# NOTE(review): credentials are hard-coded in source; move to env vars or a
# secret store before this leaves a trusted environment.
REMOTE_HOST = "172.16.225.150"
REMOTE_USER = "root"
REMOTE_PASSWORD = "P#mnuJ6r4A"
REMOTE_BASE_PATH = "/home/th/jcq/AIlib/AIlib2_train"
# Mock in-memory database for storing task statuses
training_tasks = {}
model_versions = {}
active_threads = {}  # To track and manage active training threads
# Callback endpoints, one per process type; the keys drive the
# status-updater threads started below.
CALLBACK_URLS = {
    'training': "http://172.18.0.129:8084/api/admin/trainTask/updateTrainTask",
    'check': "http://172.18.0.129:8084/api/admin/checkTask/callback",
    'management': "http://172.18.0.129:8084/api/admin/trainTask/updateTrainTask",
    'deployment': "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
}
UPDATE_INTERVAL = 1  # Update interval in seconds
def get_model_name(code):
    """Resolve a model code to its model name.

    Numeric codes are normalised to a zero-padded 3-digit key ("5" -> "005");
    unknown codes yield the string "Unknown code: <code>".
    """
    key = f"{int(code):03d}" if str(code).isdigit() else code
    return model_dict.get(key, f"Unknown code: {code}")
class TrainThread(threading.Thread):
    """Background thread that simulates a training run.

    Progress snapshots are written into the module-level ``training_tasks``
    dict; the thread can be cancelled cooperatively via :meth:`stop`.
    """

    def __init__(self, request_id, process_type, version):
        super().__init__()
        self.request_id = request_id
        self.process_type = process_type
        self.version = version
        self._stop_event = threading.Event()

    def _snapshot(self, progress, status):
        # Build the status record stored under this task's request id.
        return {
            'request_id': self.request_id,
            'process_type': self.process_type,
            'progress': progress,
            'model_version': self.version,
            'status': status,
        }

    def run(self):
        """Simulate a long-running process with progress updates"""
        for step in range(1, 11):
            # Honour a stop request before doing this iteration's "work".
            if self._stop_event.is_set():
                print(f"Training {self.request_id} stopped by user request")
                training_tasks[self.request_id] = self._snapshot(f'{step*10}%', 'stopped')
                return
            time.sleep(2)
            training_tasks[self.request_id] = self._snapshot(f'{step*10}%', 'running')
        # Mark as completed when the loop ran to the end.
        if not self._stop_event.is_set():
            training_tasks[self.request_id] = self._snapshot('100%', 'completed')

    def stop(self):
        """Signal the thread to stop"""
        self._stop_event.set()
def generate_request_id(length=30):
    """Return a pseudo-random request id: 'bb' + *length* ASCII letters.

    Uses random.choices (sampling with replacement) so any length works;
    the original random.sample raised ValueError for length > 52 and could
    never repeat a letter.  NOTE: not cryptographically secure -- use the
    ``secrets`` module if these ids must be unguessable.
    """
    return 'bb' + ''.join(random.choices(string.ascii_letters, k=length))
def create_status_updater(process_type):
    """Create a status update function for a specific process type.

    The returned closure loops forever: every UPDATE_INTERVAL seconds it
    scans the global ``training_tasks`` dict, builds a callback payload for
    each task of this ``process_type`` and POSTs it to the matching URL in
    ``CALLBACK_URLS``.  Intended to run in a daemon thread.
    """
    def status_updater():
        while True:
            try:
                # Snapshot the items so concurrent inserts don't break iteration.
                for task_id, task_data in list(training_tasks.items()):
                    # Only update tasks of this process type
                    if task_data.get('process_type') != process_type:
                        continue
                    # Skip if task was stopped
                    if task_data.get('status') == 'stopped':
                        continue
                    # Create payload based on process type.
                    # NOTE(review): most field values below are hard-coded
                    # mock data, not real metrics/timestamps.
                    if process_type == 'training':
                        payload = {
                            "id": task_id,
                            "dspModelCode": 502,
                            "algorithmMap": "123",
                            "algorithmAccuracy": "123",
                            "algorithmRecall": "123",
                            "aiVersion": "V1.1",
                            "progress": task_data.get('progress', '0%'),
                            "traiStartTime": "2025-05-21 16:10:05",
                            "trainEndTime": "2025-05-21 16:10:05",
                            "trainDuration": "100",
                            # 4 once the task completed, otherwise 2
                            # (platform status codes -- presumed; confirm)
                            "status": 4 if task_data.get('status') == 'completed' else 2
                        }
                    elif process_type == 'check':
                        payload = {
                            "id": task_id,
                            "checkStatus": "2" if task_data.get('status') == 'completed' else "running",
                            "aiFileImage": "http://172.18.0.129:8084/5116fd79-7f98-4089-a59f-ec4df2cdb0f3_ai.jpg",
                            "checkTaskDetails": [
                                {
                                    "dspModelCode": "123",
                                    "dspModelName": "行人识别",
                                    "possibility": "30%"
                                },
                                {
                                    "dspModelCode": "123",
                                    "dspModelName": "车辆识别",
                                    "possibility": "50%"
                                }
                            ],
                            "checkTime": "2025-05-21 16:10:05"
                        }
                    elif process_type == 'management':
                        payload = {
                            "id": task_id,
                            "modelStatus": "active",
                            "lastUpdated": "2025-05-21 16:10:05"
                        }
                    elif process_type == 'deployment':
                        payload = {
                            "id": task_id,
                            "status": "4" if task_data.get('status') == 'completed' else "processing",
                            "deployTime": "2025-05-21 16:10:05"
                        }
                    print(f"Sending {process_type} status update:", payload)
                    try:
                        response = requests.post(CALLBACK_URLS[process_type], json=payload)
                        if response.status_code == 200:
                            print(f"{process_type} status update successful")
                        else:
                            print(f"{process_type} status update failed with status code: {response.status_code}")
                    except Exception as e:
                        # Network failures are logged and the loop continues.
                        print(f"Error sending {process_type} status update: {str(e)}")
            except Exception as e:
                print(f"Error in {process_type} status update thread: {str(e)}")
            time.sleep(UPDATE_INTERVAL)
    return status_updater
# Start one daemon status-updater thread per process type at import time;
# daemon=True makes them die with the main process.
for process_type in CALLBACK_URLS:
    thread = threading.Thread(
        target=create_status_updater(process_type),
        name=f"{process_type}_status_updater"
    )
    thread.daemon = True
    thread.start()
class TrainingTask(Resource):
    """Create (POST) and inspect (GET) simulated training tasks.

    NOTE(review): this resource is defined but never registered via
    api.add_resource further down -- confirm whether it is still needed.
    """

    def post(self):
        """Create a training task and start its simulation thread."""
        data = request.json
        # 'request_id' is dereferenced below, so validate it together with
        # the other fields (a missing id used to raise KeyError -> HTTP 500).
        required_fields = ['code', 'version', 'image_dir', 'parameters', 'request_id']
        if not all(field in data for field in required_fields):
            return {'error': 'Missing required fields'}, 400
        version = data['version']
        request_id = data['request_id']
        # Create and start the training thread
        thread = TrainThread(request_id, 'training', version)
        active_threads[request_id] = thread
        thread.start()
        training_tasks[request_id] = {
            'request_id': request_id,
            'process_type': 'training',
            'progress': '0%',
            'model_version': version,
            'status': 'initializing'
        }
        return {
            'request_id': request_id,
            'code': '0',
            'msg': 'Training task created',
            "data": None
        }, 201

    def get(self, request_id=None):
        """Return one task by id (404 when unknown), or a summary of all tasks."""
        if request_id:
            task = training_tasks.get(request_id)
            if not task:
                return {'error': 'Task not found'}, 404
            return task
        return {
            'tasks': [{
                'request_id': k,
                'process_type': v.get('process_type', 'unknown'),
                'status': v['status'],
                'progress': v['progress'],
                'model_version': v['model_version']
            } for k, v in training_tasks.items()]
        }
class TrainingTask_Stop0(Resource):
    """Legacy stop endpoint: stops a local simulation thread directly.

    NOTE(review): defined but never registered with api.add_resource;
    /api/train/control is served by TrainingTask_Stop instead.
    """

    def post(self):
        data = request.json or {}
        request_id = data.get('request_id')
        # Reject malformed bodies instead of crashing with KeyError -> 500.
        if not request_id:
            return {'code': '1', 'msg': 'Missing request_id', 'data': None}, 400
        if request_id in active_threads:
            # Signal the thread to stop and drop it from the registry.
            active_threads[request_id].stop()
            del active_threads[request_id]
        if request_id in training_tasks:
            training_tasks[request_id]['status'] = 'stopped'
        return {
            'request_id': request_id,
            'code': '0',
            'msg': 'Training stopped successfully',
            "data": None
        }, 200
class TrainingTask_Stop(Resource):
    """Stop-control endpoint: forwards a Stop command to Kafka.

    Expected request body, e.g.:
    {"Command":"Stop","Errorcode":"","ModelScenetype":"train",
     "request_id":"jcq13f4353a9312838c1d647f2593d3c3a6"}
    """

    def post(self):
        data = request.json
        # NOTE(review): required fields are not validated -- a body without
        # 'ModelScenetype'/'Command'/'request_id' raises KeyError (HTTP 500).
        # Note the key is 'ModelScenetype' (lower-case 't'), not
        # 'ModelSceneType' as some sample payloads suggest -- confirm.
        # Convert the message into the Kafka training-control format.
        training_request = {
            "Scene": data['ModelScenetype'].capitalize(),  # capitalise first letter
            "Command": data['Command'].lower(),  # lower-case the command
            "Request_ID": data['request_id']
        }
        # Send to Kafka.  KAFKA_CONFIG / TOPIC_TRAIN are defined later in
        # this module, which is fine because they exist by call time.
        try:
            producer = KafkaProducer(**KAFKA_CONFIG)
            future = producer.send(TOPIC_TRAIN, value=training_request)
            # Block until the broker acknowledges the message.
            future.get(timeout=10)
            producer.flush()
            # Update local task state to mirror the stop request.
            if data['request_id'] in active_threads:
                active_threads[data['request_id']].stop()
                del active_threads[data['request_id']]
            if data['request_id'] in training_tasks:
                training_tasks[data['request_id']]['status'] = 'stopped'
            return {
                'request_id': data['request_id'],
                'code': '0',
                'msg': 'Stop command sent successfully',
                'data': None
            }, 200
        except Exception as e:
            return {
                'request_id': data['request_id'],
                'code': '2',
                'msg': f'Failed to send stop command: {str(e)}',
                'data': None
            }, 500
def manager_simulate_process(ModelCode):
    """Report to the platform callback whether the given model code exists.

    isExist is 1 for codes "000"/"032" and 2 otherwise (payload shape
    required by the platform endpoint).  Network errors are logged, not raised.
    """
    if str(ModelCode) in ["000", "032"]:
        callback_data = {"isExist": 1}
    else:
        callback_data = {"isExist": 2}
    # (The original rebound a *local* name 'training_tasks' here, shadowing
    # the module-level dict with no effect; that dead assignment is removed.)
    callback_url = "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
    headers = {
        'Content-Type': 'application/json',
    }
    try:
        response = requests.post(
            callback_url,
            data=json.dumps(callback_data),
            headers=headers
        )
        # Log the callback outcome either way.
        if response.status_code == 200:
            print(f"回调成功: {response.json()}")
        else:
            print(f"回调失败,状态码: {response.status_code}, 响应: {response.text}")
    except Exception as e:
        print(f"回调请求异常: {str(e)}")
def stop_simulate_process(request_id, process_type, version="1.0"):
    """Report a task as finished (status 4) to the platform callback URL."""
    # Payload shape required by the platform endpoint.
    payload = {
        "id": request_id,
        "status": 4
    }
    # Mirror the final state locally as well.
    training_tasks[request_id] = payload
    callback_url = "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
    headers = {'Content-Type': 'application/json'}
    try:
        response = requests.post(
            callback_url,
            data=json.dumps(payload),
            headers=headers
        )
        if response.status_code == 200:
            print(f"回调成功: {response.json()}")
        else:
            print(f"回调失败,状态码: {response.status_code}, 响应: {response.text}")
    except Exception as exc:
        print(f"回调请求异常: {str(exc)}")
# √
def deploy_simulate_process(request_id, process_type, version="1.0"):
    """Report a deployment task as finished (status 4) to the platform callback."""
    result = {"id": request_id, "status": 4}
    # Keep the local task registry in sync with what we report upstream.
    training_tasks[request_id] = result
    url = "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
    try:
        resp = requests.post(
            url,
            data=json.dumps(result),
            headers={'Content-Type': 'application/json'}
        )
        if resp.status_code == 200:
            print(f"回调成功: {resp.json()}")
        else:
            print(f"回调失败,状态码: {resp.status_code}, 响应: {resp.text}")
    except Exception as exc:
        print(f"回调请求异常: {str(exc)}")
# √
def check_simulate_process(input_json):
    """Run yolov5 inference for a check task and POST the result to the callback URL."""
    callback_data = check_task_api.infer_yolov5(input_json)
    print("*******line366",callback_data)
    request_id = input_json.get("request_id", "")
    # Keep a local copy of the inference result keyed by request id.
    training_tasks[request_id] = callback_data
    url = "http://172.18.0.129:8084/api/admin/checkTask/callback"
    try:
        resp = requests.post(
            url,
            data=json.dumps(callback_data),
            headers={'Content-Type': 'application/json'}
        )
        if resp.status_code == 200:
            print(f"回调成功: {resp.json()}")
        else:
            print(f"回调失败,状态码: {resp.status_code}, 响应: {resp.text}")
    except Exception as exc:
        print(f"回调请求异常: {str(exc)}")
class TrainingTask_Check(Resource):
    """Kicks off an asynchronous model-check (inference) job."""

    def post(self):
        input_json = request.json
        print("###line587",input_json)
        data = request.json
        request_id = data['request_id']
        # Run inference off the request thread; the result is delivered via
        # the callback inside check_simulate_process.
        worker = threading.Thread(target=check_simulate_process, args=(input_json,))
        worker.start()
        return {
            'id': request_id,
            'code': '0',
            'msg': 'Model check started',
            'data': None
        }, 201
class TrainingTask_Manager(Resource):
    """Reports whether a model exists for the given ModelCode (mock logic)."""

    def post(self):
        data = request.json or {}
        ModelCode = data.get("ModelCode")
        # Normalise to a zero-padded 3-digit string (5 / "5" -> "005").
        if isinstance(ModelCode, int):
            model_code_str = f"{ModelCode:03d}"
        else:
            model_code_str = str(ModelCode).zfill(3)
        # Non-numeric codes previously raised ValueError -> HTTP 500;
        # treat them as "does not exist" instead.
        try:
            exists = 0 <= int(model_code_str) <= 32
        except (TypeError, ValueError):
            exists = False
        callback_data = {"isExist": 1 if exists else 0}
        return callback_data, 201
# Second copy of the code->name map, identical to the one at the top of the
# file; this definition shadows it.
model_dict = {
    "001": "river",
    "002": "forest2",
    "003": "highWay2",
    "006": "vehicle",
    "007": "pedestrian",
    "008": "smogfire",
    "009": "AnglerSwimmer",
    "010": "countryRoad",
    "011": "ship2",
    "013": "channelEmergency",
    "014": "forest2",
    "015": "river2",
    "016": "cityMangement2",
    "017": "drowning",
    "018": "noParking",
    "019": "illParking",
    "020": "cityRoad",
    "023": "pothole",
    "024": "channel2",
    "025": "riverT",
    "026": "forestCrowd",
    "027": "highWay2T",
    "028": "smartSite",
    "029": "rubbish",
    "030": "firework"
}
# MinIO connection settings.
# NOTE(review): credentials are hard-coded; move to env vars / secret store.
ENDPOINT = "minio-jndsj.t-aaron.com:2443"
ACCESS_KEY = "PJM0c2qlauoXv5TMEHm2"
SECRET_KEY = "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI"
BUCKET_NAME = "algorithm"
PREFIX = "weights/026/v1.0/"
LOCAL_DIR = "downloaded_weights"  # local storage directory
# Remote server configuration.
# NOTE(review): these overwrite the values set near the top of the file
# (REMOTE_HOST .150 -> .151, different base path) -- confirm which is intended.
REMOTE_HOST = "172.16.225.151"
REMOTE_USER = "root"
REMOTE_PASSWORD = "P#mnuJ6r4A"
REMOTE_BASE_PATH = "/home/th/jcq/AI_AutoPlat/"
def get_model_name(code):
    """Look up the model name for *code* in model_dict.

    Digit-only codes are zero-padded to three characters first; unknown
    codes produce "Unknown code: <code>".
    """
    lookup_key = f"{int(code):03d}" if str(code).isdigit() else code
    return model_dict.get(lookup_key, f"Unknown code: {code}")
def parse_minio_path(model_dir):
    """Split a MinIO object URL into (bucket, prefix, filename).

    Raises ValueError when the host is not the expected MinIO endpoint or
    the path has fewer than three components (bucket/prefix/file).
    """
    parsed = urllib.parse.urlparse(model_dir)
    if not parsed.netloc.startswith('minio-jndsj.t-aaron.com:2443'):
        raise ValueError("Invalid MinIO URL")
    parts = parsed.path.lstrip('/').split('/')
    if len(parts) < 3:
        raise ValueError("Invalid MinIO path format")
    # First component is the bucket, last is the file, the rest is the prefix.
    bucket, *middle, filename = parts
    return bucket, '/'.join(middle), filename
def download_minio_files(model_dir, local_dir='downloads'):
    """Download a single object from MinIO given its full URL.

    :param model_dir: full MinIO URL including bucket, prefix and filename
    :param local_dir: local directory the file is saved into
    :return: local file path on success, None on failure
    """
    try:
        # Hard-coded connection settings.
        # NOTE(review): move credentials to config/env before production use.
        ENDPOINT = "minio-jndsj.t-aaron.com:2443"
        ACCESS_KEY = "PJM0c2qlauoXv5TMEHm2"
        SECRET_KEY = "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI"
        BUCKET_NAME = "algorithm"
        bucket, prefix, filename = parse_minio_path(model_dir)
        # Create MinIO client
        client = Minio(
            ENDPOINT,
            access_key=ACCESS_KEY,
            secret_key=SECRET_KEY,
            secure=False
        )
        # Ensure local directory exists
        os.makedirs(local_dir, exist_ok=True)
        # Rebuild the full object key from prefix and filename (this was a
        # corrupted literal placeholder before, so the download always missed).
        object_path = f"{prefix}/{filename}"
        local_path = os.path.join(local_dir, filename)
        # Download the file
        client.fget_object(bucket, object_path, local_path)
        print(f"Downloaded: {object_path} -> {local_path}")
        return local_path
    except S3Error as err:
        print(f"MinIO Error occurred: {err}")
        return None
    except Exception as e:
        print(f"Error occurred: {e}")
        return None
def download_minio_deploy(endpoint, access_key, secret_key, bucket_name, prefix='', local_dir='downloaded_weights'):
    """Download every object under *prefix* from a MinIO bucket.

    :return: local path of the *last* downloaded file (callers deploy a
             single model file), or None when nothing matched the prefix
             or an error occurred.
    """
    # Create the MinIO client (HTTPS).
    client = Minio(
        endpoint,
        access_key=access_key,
        secret_key=secret_key,
        secure=True
    )
    try:
        # Make sure the local target directory exists.
        os.makedirs(local_dir, exist_ok=True)
        local_path = None  # stays None when the prefix matches no objects
        for obj in client.list_objects(bucket_name, prefix=prefix, recursive=True):
            # Re-create the object's relative path below local_dir.
            local_path = os.path.join(local_dir, os.path.relpath(obj.object_name, prefix))
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            print(f"正在下载: {obj.object_name} -> {local_path}")
            client.fget_object(bucket_name, obj.object_name, local_path)
        if local_path is None:
            # An empty listing previously raised UnboundLocalError here.
            print(f"未找到任何对象: {bucket_name}/{prefix}")
            return None
        print("所有文件下载完成!")
        return local_path
    except S3Error as e:
        print(f"MinIO 错误: {e}")
        return None
    except Exception as e:
        print(f"发生错误: {e}")
        return None
def deploy_to_remote(local_path, model_code):
    """Deploy a file to the remote server via SCP.

    Copies *local_path* to REMOTE_BASE_PATH/<model name>/<basename> on
    REMOTE_HOST and returns the remote path, or None on failure.
    NOTE(review): assumes the remote target directory already exists --
    scp.put does not create intermediate directories; confirm on the host.
    NOTE(review): AutoAddPolicy blindly trusts unknown host keys.
    """
    try:
        model_name = get_model_name(model_code)
        remote_path = os.path.join(REMOTE_BASE_PATH, model_name, os.path.basename(local_path))
        # Create SSH client
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(REMOTE_HOST, username=REMOTE_USER, password=REMOTE_PASSWORD)
        # Create SCP client
        scp = SCPClient(ssh.get_transport())
        # Upload file
        scp.put(local_path, remote_path)
        print(f"Successfully deployed {local_path} to {remote_path}")
        # Close connections
        scp.close()
        ssh.close()
        return remote_path
    except Exception as e:
        print(f"Deployment failed: {e}")
        return None
def deploy_to_local(local_path, model_code):
    """Move a downloaded model file into the local deployment tree.

    Unlike deploy_to_remote this variant works on the local filesystem:
    the file is moved to /home/th/jcq/AIlib/Deployment/<model name>/.
    Returns the destination path, or None on failure.
    """
    try:
        model_name = get_model_name(model_code)
        REMOTE_BASE_PATH = "/home/th/jcq/AIlib/Deployment"
        remote_path = os.path.join(REMOTE_BASE_PATH, model_name, os.path.basename(local_path))
        print("###line858",REMOTE_BASE_PATH,model_name,local_path)
        # Make sure the destination directory exists, then move the file.
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)
        shutil.move(local_path, remote_path)
        print(f"Successfully deployed {local_path} to {remote_path}")
        return remote_path
    except Exception as e:
        print(f"Deployment failed: {e}")
        return None
def deployment_worker(model_dir, model_code, request_id):
    """Worker run in a separate process: download a model from MinIO,
    deploy it locally, and report the outcome to the platform callback.

    :param model_dir: MinIO URL of the model file (.../bucket/prefix/file)
    :param model_code: model code used to pick the deployment directory
    :param request_id: id echoed back in the callback payload
    :return: True when deployment succeeded, False otherwise
    """
    callback_url = "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
    success = False
    message = ""
    remote_path = ""
    model_path = None  # set once the download succeeds; checked in cleanup
    try:
        # Parse the path to get bucket and prefix
        bucket, prefix, filename = parse_minio_path(model_dir)
        print(f"[Deployment Worker] Downloading from bucket: {bucket}, prefix: {prefix}, filename: {filename}")
        # Connection settings (NOTE(review): hard-coded credentials).
        ENDPOINT = "minio-jndsj.t-aaron.com:2443"
        ACCESS_KEY = "PJM0c2qlauoXv5TMEHm2"
        SECRET_KEY = "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI"
        model_path = download_minio_deploy(
            endpoint=ENDPOINT,
            access_key=ACCESS_KEY,
            secret_key=SECRET_KEY,
            bucket_name=bucket,
            prefix=prefix,
            local_dir=LOCAL_DIR
        )
        print("###line890", model_path)
        if not model_path:
            # Guard: deploy_to_local would otherwise choke on None.
            raise Exception("Failed to download model from MinIO")
        # Deploy into the local model tree.
        remote_path = deploy_to_local(model_path, model_code)
        if not remote_path:
            message = "Failed to deploy model to remote server"
            print(f"[Deployment Worker] {message}")
            raise Exception(message)
        # If we get here, deployment was successful.
        success = True
        message = "操作成功"
        print(f"[Deployment Worker] Deployment successful: {remote_path}")
    except Exception as e:
        message = f"Deployment failed: {str(e)}"
        print(f"[Deployment Worker] {message}")
        success = False
    finally:
        # Clean up the downloaded file if it is still around.  (The original
        # tested the never-assigned name 'local_path', so cleanup never ran
        # and could NameError when model_path was unbound.)
        if model_path and os.path.exists(model_path):
            try:
                os.remove(model_path)
            except OSError as e:
                print(f"[Deployment Worker] Warning: Failed to clean up local file: {e}")
        # Report the outcome; status 3 on success, 4 on failure.
        try:
            if success:
                callback_data = {
                    "id": request_id,
                    "status": 3,
                    "msg": "操作成功",
                    "data": None
                }
            else:
                callback_data = {
                    "id": request_id,
                    "status": 4,
                    "msg": "未查询到数据",
                    "data": None
                }
            response = requests.post(callback_url, json=callback_data, timeout=10)
            if response.status_code != 200:
                print(f"[Deployment Worker] Callback failed with status {response.status_code}: {response.text}")
                print("line956", callback_data)
            else:
                print("[Deployment Worker] Callback sent successfully")
                print("line959", callback_data)
        except Exception as e:
            print(f"[Deployment Worker] Failed to send callback: {str(e)}")
    return success
@app.route('/api/train/deploy', methods=['POST'])
def deploy_model():
    """Accept a deployment request and run it in a background process."""
    data = request.get_json()
    print("###line879",data)
    if not data or 'ModelDir' not in data or 'ModelCode' not in data:
        return jsonify({
            "status": "error",
            "message": "Missing required fields (ModelDir or ModelCode)"
        }), 400
    # The deployable artifact is the yolov5.pt inside the given directory.
    model_dir = os.path.join(data['ModelDir'], 'yolov5.pt')
    model_code = data['ModelCode']
    # Fall back to a generated id when the caller did not provide one.
    request_id = data.get('request_id', str(uuid.uuid4()))
    try:
        worker = multiprocessing.Process(
            target=deployment_worker,
            args=(model_dir, model_code, request_id)
        )
        worker.start()
        # Respond immediately; the worker reports completion via callback.
        return jsonify({
            "request_id": request_id,
            "status": 4
        }), 202
    except Exception as e:
        return jsonify({
            "status": "error",
            "message": f"Failed to start deployment: {str(e)}",
            "request_id": request_id
        }), 500
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Kafka configuration
KAFKA_CONFIG = {
    'bootstrap_servers': '172.18.0.126:9094',
    # serialise every outgoing message as UTF-8 JSON
    'value_serializer': lambda v: json.dumps(v).encode('utf-8')
}
# Kafka topics
TOPIC_TRAIN = 'training-tasks'
# Third copy of the code->name map; shadows the two earlier, identical
# definitions.
model_dict = {
    "001": "river",
    "002": "forest2",
    "003": "highWay2",
    "006": "vehicle",
    "007": "pedestrian",
    "008": "smogfire",
    "009": "AnglerSwimmer",
    "010": "countryRoad",
    "011": "ship2",
    "013": "channelEmergency",
    "014": "forest2",
    "015": "river2",
    "016": "cityMangement2",
    "017": "drowning",
    "018": "noParking",
    "019": "illParking",
    "020": "cityRoad",
    "023": "pothole",
    "024": "channel2",
    "025": "riverT",
    "026": "forestCrowd",
    "027": "highWay2T",
    "028": "smartSite",
    "029": "rubbish",
    "030": "firework"
}
def validate_minio_url(url):
    """Validate that *url* looks like http(s)://host/bucket/key...

    Returns the urllib parse result on success; raises ValueError otherwise.
    """
    try:
        parsed = urllib.parse.urlparse(url)
        scheme_ok = parsed.scheme in ('http', 'https')
        depth_ok = len(parsed.path.split('/')) >= 3
        # Require a usable scheme, a host, and at least bucket + key in the path.
        if not (scheme_ok and parsed.netloc and depth_ok):
            raise ValueError("Invalid MinIO URL format")
        return parsed
    except Exception as e:
        raise ValueError(f"URL validation failed: {str(e)}")
def download_minio_directory(minio_url, local_base_dir):
    """
    Download a MinIO "directory" (object prefix) to local disk.

    :param minio_url: full MinIO path (e.g. "http://172.18.0.129:8084/th-dsp/027123")
    :param local_base_dir: local base directory (e.g. "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets")
    :return: (success, message) tuple -- message is the local directory on
             success, an error string on failure
    """
    try:
        # Validate and parse the URL into endpoint / bucket / prefix.
        parsed = validate_minio_url(minio_url)
        endpoint = parsed.netloc
        bucket_name = parsed.path.split('/')[1]
        prefix = '/'.join(parsed.path.split('/')[2:])
        # Credentials from environment variables (recommended); the
        # hard-coded defaults are for testing only.
        access_key = os.getenv("MINIO_ACCESS_KEY", "PJM0c2qlauoXv5TMEHm2")
        secret_key = os.getenv("MINIO_SECRET_KEY", "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI")
        # Local target dir is named after the last prefix component.
        dir_name = os.path.basename(prefix) or "minio_download"
        local_dir = os.path.join(local_base_dir, dir_name)
        os.makedirs(local_dir, exist_ok=True)
        # Initialise the MinIO client; TLS only for https URLs.
        client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=parsed.scheme == 'https'
        )
        print(f"Starting download from {minio_url} to {local_dir}")
        # Recursively download every object under the prefix.
        objects = client.list_objects(bucket_name, prefix=prefix, recursive=True)
        downloaded_files = 0
        for obj in objects:
            # Preserve the object's relative directory structure locally.
            relative_path = os.path.relpath(obj.object_name, prefix)
            local_path = os.path.join(local_dir, relative_path)
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            client.fget_object(bucket_name, obj.object_name, local_path)
            downloaded_files += 1
            print(f"Downloaded [{downloaded_files}]: {obj.object_name} -> {local_path}")
        print(f"Download completed! Total files: {downloaded_files}")
        return True, local_dir
    except S3Error as e:
        error_msg = f"MinIO Error: {str(e)}"
        if "NoSuchBucket" in str(e):
            error_msg = f"Bucket not found: {bucket_name}"
        return False, error_msg
    except ValueError as e:
        return False, f"Invalid URL: {str(e)}"
    except Exception as e:
        return False, f"Unexpected error: {str(e)}"
def split_dataset(data_dir, train_ratio=0.8):
    """Split a labelled image directory into train/validation file lists.

    Only images (.jpg/.jpeg/.png, case-insensitive) that have a matching
    .txt label file are considered.  The list is shuffled in place, so the
    split is random on every call.

    :param data_dir: directory containing images and label .txt files
    :param train_ratio: fraction assigned to the training set (default 0.8)
    :return: (train_files, val_files) lists of image file names
    """
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    valid_files = [
        name
        for name in os.listdir(data_dir)
        if name.lower().endswith(IMAGE_EXTENSIONS)
        and os.path.exists(os.path.join(data_dir, os.path.splitext(name)[0] + '.txt'))
    ]
    # Randomise order, then cut at the requested ratio.
    random.shuffle(valid_files)
    cut = int(len(valid_files) * train_ratio)
    return valid_files[:cut], valid_files[cut:]
def write_file_list(file_list, output_file, base_dir):
    """Write one path per line to *output_file*: base_dir joined with each name."""
    lines = [os.path.join(base_dir, name) + '\n' for name in file_list]
    with open(output_file, 'w') as fh:
        fh.writelines(lines)
def copy_files_to_folders(files, src_dir, dest_dir):
    """Copy each image and its .txt label file from src_dir to dest_dir.

    The label name is derived with os.path.splitext so that .jpeg images
    work too (the old str.replace handled only .jpg/.png and could also
    corrupt file names containing '.jpg' in the stem).
    """
    os.makedirs(dest_dir, exist_ok=True)
    for name in files:
        # Copy the image itself.
        copyfile(os.path.join(src_dir, name), os.path.join(dest_dir, name))
        # Copy the matching label file.
        label = os.path.splitext(name)[0] + '.txt'
        copyfile(os.path.join(src_dir, label), os.path.join(dest_dir, label))
def update_yaml(yaml_file, train_txt, val_txt):
    """Rewrite the dataset yaml so 'train'/'val' point at the given list files."""
    with open(yaml_file, 'r') as fh:
        config = yaml.safe_load(fh)
    config['train'] = train_txt
    config['val'] = val_txt
    with open(yaml_file, 'w') as fh:
        yaml.dump(config, fh, default_flow_style=False)
def get_model_name(code):
    """Translate a model code into its model name via model_dict.

    Purely numeric codes are zero-padded to 3 digits before the lookup;
    misses return "Unknown code: <code>".
    """
    normalised = f"{int(code):03d}" if str(code).isdigit() else code
    return model_dict.get(normalised, f"Unknown code: {code}")
def generate_request_id(length=30):
    """Generate a unique request ID: 'bb' + *length* random ASCII letters.

    random.choices samples with replacement, so any length is valid (the
    original random.sample raised ValueError for length > 52).  NOTE: not
    cryptographically secure; use ``secrets`` for unguessable ids.
    """
    return 'bb' + ''.join(random.choices(string.ascii_letters, k=length))
def transform_message_format(input_data):
    """Convert a platform training request into the Kafka message format.

    Side effects: downloads the training dataset from MinIO, splits it into
    train/ and val/ folders, writes train.txt/val.txt and updates the
    per-model yaml config.

    :raises ValueError: when TrainImageDir is missing or the dataset
        download fails (the original code fell through and crashed with
        NameError on the unbound download result instead).
    :return: dict in the Kafka training-message format
    """
    transformed = {
        "Scene": input_data.get("ModelScenetype", "train").capitalize(),
        "Command": "start",  # Default command
        "Request_ID": input_data.get("request_id", generate_request_id()),
        "Version": input_data.get("version"),
        "Name": "channel2",  # Default name if not provided
        "Model": input_data.get("code", input_data.get("Code", "026")),
    }
    # Extract the raw training hyper-parameters.
    params = input_data.get("parameters", {})
    batch_size = params.get("batch_size")
    img_size = params.get("img_size")
    epochs = params.get("epochs")
    model_name = get_model_name(input_data.get("code", input_data.get("Code", "026")))
    RAW_YAML_PATH = "/home/th/jcq/AI_AutoPlat/AI_web_dsj/config/yaml"
    yaml_path = os.path.join(RAW_YAML_PATH, f"{model_name}.yaml")
    print("####line137", yaml_path)
    # Download the training dataset from MinIO.
    minio_url = input_data.get("TrainImageDir")
    if not minio_url:
        # Everything below needs the downloaded dataset directory, so fail
        # fast instead of continuing with an unbound result.
        raise ValueError("TrainImageDir not specified in input_data")
    local_base_dir = "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets"
    success, result = download_minio_directory(minio_url, local_base_dir)
    if not success:
        raise ValueError(f"Download failed: {result}")
    print("line277", f"Files downloaded successfully to: {result}")
    TrainImageDir = result
    # Create the train/ and val/ folders.
    train_dir = os.path.join(TrainImageDir, "train")
    val_dir = os.path.join(TrainImageDir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)
    # Randomly split the dataset and copy files into train/ and val/.
    train_files, val_files = split_dataset(TrainImageDir)
    copy_files_to_folders(train_files, TrainImageDir, train_dir)
    copy_files_to_folders(val_files, TrainImageDir, val_dir)
    # Write the train/val list files referenced by the yaml config.
    train_txt = os.path.join(TrainImageDir, "train.txt")
    val_txt = os.path.join(TrainImageDir, "val.txt")
    write_file_list(train_files, train_txt, train_dir)
    write_file_list(val_files, val_txt, val_dir)
    # Point the training yaml at the freshly written list files.
    update_yaml(yaml_path, train_txt, val_txt)
    print(f"数据集划分完成,训练集: {len(train_files)} 张,验证集: {len(val_files)}")
    print(f"训练集列表已写入: {train_txt}")
    print(f"验证集列表已写入: {val_txt}")
    print(f"YAML文件已更新: {yaml_path}")
    # Initial weights for this model.
    RAW_WEIGHTS_PATH = "/home/th/jcq/AIlib/weights_init"
    model_path = os.path.join(RAW_WEIGHTS_PATH, model_name, "yolov5.pt")
    print("####line98", model_path)
    transformed["TrainParameter"] = str([
        batch_size,
        img_size,
        epochs,
        yaml_path,
        model_path
    ])
    return transformed
def validate_training_request(data):
    """Validate a Kafka training message; return (is_valid, error_message)."""
    checks = (
        ("Scene", str),
        ("Command", str),
        ("Version", str),
        ("Name", str),
        ("Model", str),
        ("TrainParameter", str),
    )
    # Presence and type checks first, in declaration order.
    for field, expected in checks:
        if field not in data:
            return False, f"Missing required field: {field}"
        if not isinstance(data[field], expected):
            return False, f"Field {field} must be {expected.__name__}"
    # Value checks.
    if data["Scene"] not in ["Train"]:
        return False, "Invalid Scene value"
    if data["Command"] not in ["start", "stop"]:
        return False, "Invalid Command value"
    return True, ""
def send_to_kafka(topic, message):
    """Publish *message* to the Kafka *topic*; return (success, error_string)."""
    producer = None
    try:
        producer = KafkaProducer(**KAFKA_CONFIG)
        # Block until the broker acknowledges delivery (10s timeout).
        metadata = producer.send(topic, value=message).get(timeout=10)
        logger.info(f"Message sent to partition {metadata.partition}, offset {metadata.offset}")
        return True, ""
    except Exception as exc:
        logger.error(f"Failed to send message: {exc}")
        return False, str(exc)
    finally:
        # Always release the producer, even on failure.
        if producer:
            producer.flush()
            producer.close()
@app.route('/api/train', methods=['POST'])
def handle_training_request():
    """HTTP endpoint for training requests (non-blocking)."""
    try:
        payload = request.get_json()
        if not payload or not isinstance(payload, dict):
            return jsonify({"status": "error", "message": "Invalid JSON data"}), 400
        request_id = payload.get("request_id")
        # Heavy lifting (dataset download, Kafka publish) runs in a daemon
        # child process so the HTTP response can return immediately.
        worker = Process(
            target=process_training_task,
            args=(payload, request_id),
            daemon=True
        )
        worker.start()
        return jsonify({
            "status": "success",
            "message": "Training task started",
            "request_id": request_id,
            "process_id": worker.pid
        }), 202  # 202 Accepted
    except Exception as exc:
        logger.error(f"Request handling error: {str(exc)}")
        return jsonify({
            "status": "error",
            "message": "Internal server error"
        }), 500
def process_training_task(data, request_id):
    """Background-process entry point: transform, validate and publish a training task."""
    try:
        logger.info(f"[{request_id}] Starting background task (PID: {os.getpid()})")
        started = time.time()
        # 1. Convert the request into the Kafka message format.
        message = transform_message_format(data)
        message["Request_ID"] = request_id
        # 2. Validate the converted message.
        ok, reason = validate_training_request(message)
        if not ok:
            raise ValueError(reason)
        # 3. Publish to Kafka.
        sent, error = send_to_kafka(TOPIC_TRAIN, message)
        if not sent:
            raise RuntimeError(error)
        elapsed = time.time() - started
        logger.info(f"[{request_id}] Task completed in {elapsed:.2f}s")
    except Exception as e:
        logger.error(f"[{request_id}] Task failed: {str(e)}")
# Route registrations.  Note: TrainingTask and TrainingTask_Stop0 are defined
# above but never registered here.
api.add_resource(TrainingTask_Stop, '/api/train/control')
api.add_resource(TrainingTask_Check, '/api/train/check')
api.add_resource(TrainingTask_Manager, '/api/train/manager')
if __name__ == '__main__':
    # NOTE(review): debug=True enables the interactive Werkzeug debugger --
    # never expose this on a reachable interface in production.
    app.run(host='172.16.225.150', port=5001, debug=True)