1508 lines
45 KiB
Python
1508 lines
45 KiB
Python
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
from flask import Flask, request, jsonify
|
|||
|
|
from flask_restful import Api, Resource
|
|||
|
|
import random
|
|||
|
|
import json
|
|||
|
|
import string
|
|||
|
|
import threading
|
|||
|
|
import time
|
|||
|
|
import os
|
|||
|
|
import requests
|
|||
|
|
|
|||
|
|
from api_task import check_task_api
|
|||
|
|
import datetime
|
|||
|
|
import shutil
|
|||
|
|
|
|||
|
|
|
|||
|
|
from flask import Flask, request, jsonify
|
|||
|
|
from minio import Minio
|
|||
|
|
from minio.error import S3Error
|
|||
|
|
import os
|
|||
|
|
import paramiko
|
|||
|
|
from scp import SCPClient
|
|||
|
|
import urllib.parse
|
|||
|
|
import multiprocessing
|
|||
|
|
import uuid
|
|||
|
|
|
|||
|
|
|
|||
|
|
from flask import Flask, request, jsonify
|
|||
|
|
import os
|
|||
|
|
from kafka import KafkaProducer
|
|||
|
|
import json
|
|||
|
|
import random
|
|||
|
|
import string
|
|||
|
|
import logging
|
|||
|
|
from multiprocessing import Process
|
|||
|
|
import time
|
|||
|
|
|
|||
|
|
import os
|
|||
|
|
import random
|
|||
|
|
import yaml
|
|||
|
|
from shutil import copyfile
|
|||
|
|
|
|||
|
|
from minio import Minio
|
|||
|
|
from minio.error import S3Error
|
|||
|
|
from minio import Minio
|
|||
|
|
from minio.error import S3Error
|
|||
|
|
import os
|
|||
|
|
import urllib.parse
|
|||
|
|
from minio import Minio
|
|||
|
|
from minio.error import S3Error
|
|||
|
|
import os
|
|||
|
|
import urllib.parse
|
|||
|
|
|
|||
|
|
from api_task import check_task_api
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Flask application plus the Flask-RESTful API wrapper used by all the
# Resource classes registered below.
app = Flask(__name__)
api = Api(app)

# Maps zero-padded 3-digit model codes to internal model names.
# NOTE(review): codes 004, 005, 012, 021, 022 and 031 are absent, and
# "002"/"014" both map to "forest2" — confirm this is intentional.
model_dict = {
    "001": "river",
    "002": "forest2",
    "003": "highWay2",
    "006": "vehicle",
    "007": "pedestrian",
    "008": "smogfire",
    "009": "AnglerSwimmer",
    "010": "countryRoad",
    "011": "ship2",
    "013": "channelEmergency",
    "014": "forest2",
    "015": "river2",
    "016": "cityMangement2",
    "017": "drowning",
    "018": "noParking",
    "019": "illParking",
    "020": "cityRoad",
    "023": "pothole",
    "024": "channel2",
    "025": "riverT",
    "026": "forestCrowd",
    "027": "highWay2T",
    "028": "smartSite",
    "029": "rubbish",
    "030": "firework"
}

# Default MinIO object prefix for model weights.
PREFIX = "weights/026/v1.0/"
LOCAL_DIR = "downloaded_weights"  # local directory for downloaded weights

# Remote server configuration
# SECURITY NOTE(review): credentials are hard-coded in source; move them to
# environment variables or a secrets store.
REMOTE_HOST = "172.16.225.150"
REMOTE_USER = "root"
REMOTE_PASSWORD = "P#mnuJ6r4A"
REMOTE_BASE_PATH = "/home/th/jcq/AIlib/AIlib2_train"

# Mock database for storing task statuses
training_tasks = {}  # request_id -> task state dict published by worker threads
model_versions = {}  # model version bookkeeping (not used in this chunk)
active_threads = {}  # To track and manage active training threads

# Configuration for different endpoints (process_type -> callback URL).
# 'training' and 'management' deliberately share the same endpoint.
CALLBACK_URLS = {
    'training': "http://172.18.0.129:8084/api/admin/trainTask/updateTrainTask",
    'check': "http://172.18.0.129:8084/api/admin/checkTask/callback",
    'management': "http://172.18.0.129:8084/api/admin/trainTask/updateTrainTask",
    'deployment': "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
}

UPDATE_INTERVAL = 1  # Update interval in seconds (status-updater poll period)
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_model_name(code):
    """Look up the internal model name for a model code.

    Numeric codes are normalised to a zero-padded 3-digit string
    (e.g. 5 -> "005") before the lookup; unknown codes produce a
    descriptive placeholder string instead of raising.
    """
    text = str(code)
    if text.isdigit():
        key = "%03d" % int(text)
    else:
        key = code
    return model_dict.get(key, f"Unknown code: {code}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TrainThread(threading.Thread):
    """Background thread that simulates a training run.

    Progress advances in ten steps of 10% (one step every two seconds) and
    each state change is published into the module-level ``training_tasks``
    dict. Calling :meth:`stop` asks the thread to abort at the next step
    boundary.
    """

    def __init__(self, request_id, process_type, version):
        super().__init__()
        self.request_id = request_id
        self.process_type = process_type
        self.version = version
        self._stop_event = threading.Event()

    def _publish(self, progress, status):
        # Single place that writes this task's state into the shared dict.
        training_tasks[self.request_id] = {
            'request_id': self.request_id,
            'process_type': self.process_type,
            'progress': progress,
            'model_version': self.version,
            'status': status,
        }

    def run(self):
        """Simulate a long-running process with progress updates"""
        for step in range(1, 11):
            if self._stop_event.is_set():
                print(f"Training {self.request_id} stopped by user request")
                self._publish(f'{step*10}%', 'stopped')
                return

            time.sleep(2)
            self._publish(f'{step * 10}%', 'running')

        # Mark as completed when done
        if not self._stop_event.is_set():
            self._publish('100%', 'completed')

    def stop(self):
        """Signal the thread to stop"""
        self._stop_event.set()
|
|||
|
|
|
|||
|
|
def generate_request_id(length=30):
    """Return a pseudo-random request id: ``'bb'`` + *length* ASCII letters.

    BUG FIX: the original used ``random.sample``, which draws WITHOUT
    replacement from the 52 ASCII letters — it raised ``ValueError`` for
    any length > 52 and could never repeat a character, reducing entropy.
    ``random.choices`` samples with replacement and supports any length.

    :param length: number of random letters appended after the 'bb' prefix
    :return: string of ``length + 2`` characters
    """
    return 'bb' + ''.join(random.choices(string.ascii_letters, k=length))
|
|||
|
|
|
|||
|
|
def create_status_updater(process_type):
    """Create a status update function for a specific process type.

    The returned callable loops forever: every UPDATE_INTERVAL seconds it
    scans ``training_tasks`` for tasks of *process_type*, builds the payload
    shape that the matching CALLBACK_URLS endpoint expects, and POSTs it.

    NOTE(review): completed tasks are never removed from ``training_tasks``,
    so their status is re-sent on every poll cycle indefinitely — confirm
    the callback endpoint tolerates repeated updates.
    """
    def status_updater():
        while True:
            try:
                # Snapshot items() so worker threads may mutate the dict
                # concurrently without breaking this iteration.
                for task_id, task_data in list(training_tasks.items()):
                    # Only update tasks of this process type
                    if task_data.get('process_type') != process_type:
                        continue

                    # Skip if task was stopped
                    if task_data.get('status') == 'stopped':
                        continue

                    # Create payload based on process type.
                    # Metric values and timestamps below are simulated
                    # placeholders, not real training results.
                    if process_type == 'training':
                        payload = {
                            "id": task_id,
                            "dspModelCode": 502,
                            "algorithmMap": "123",
                            "algorithmAccuracy": "123",
                            "algorithmRecall": "123",
                            "aiVersion": "V1.1",
                            "progress": task_data.get('progress', '0%'),
                            # NOTE(review): "traiStartTime" looks like a typo
                            # of "trainStartTime" — but it may mirror the
                            # server-side field name; confirm before changing.
                            "traiStartTime": "2025-05-21 16:10:05",
                            "trainEndTime": "2025-05-21 16:10:05",
                            "trainDuration": "100",
                            # 4 = finished, 2 = still running (server codes).
                            "status": 4 if task_data.get('status') == 'completed' else 2
                        }
                    elif process_type == 'check':
                        payload = {
                            "id": task_id,
                            "checkStatus": "2" if task_data.get('status') == 'completed' else "running",
                            "aiFileImage": "http://172.18.0.129:8084/5116fd79-7f98-4089-a59f-ec4df2cdb0f3_ai.jpg",
                            "checkTaskDetails": [
                                {
                                    "dspModelCode": "123",
                                    "dspModelName": "行人识别",
                                    "possibility": "30%"
                                },
                                {
                                    "dspModelCode": "123",
                                    "dspModelName": "车辆识别",
                                    "possibility": "50%"
                                }
                            ],
                            "checkTime": "2025-05-21 16:10:05"
                        }
                    elif process_type == 'management':
                        payload = {
                            "id": task_id,
                            "modelStatus": "active",
                            "lastUpdated": "2025-05-21 16:10:05"
                        }
                    elif process_type == 'deployment':
                        payload = {
                            "id": task_id,
                            "status": "4" if task_data.get('status') == 'completed' else "processing",
                            "deployTime": "2025-05-21 16:10:05"
                        }

                    print(f"Sending {process_type} status update:", payload)

                    # Inner try: one failed POST must not abort the scan of
                    # the remaining tasks.
                    try:
                        response = requests.post(CALLBACK_URLS[process_type], json=payload)
                        if response.status_code == 200:
                            print(f"{process_type} status update successful")
                        else:
                            print(f"{process_type} status update failed with status code: {response.status_code}")
                    except Exception as e:
                        print(f"Error sending {process_type} status update: {str(e)}")

            except Exception as e:
                # Outer try: keep the updater thread alive no matter what.
                print(f"Error in {process_type} status update thread: {str(e)}")

            time.sleep(UPDATE_INTERVAL)
    return status_updater
|
|||
|
|
|
|||
|
|
# Start status update threads for each process type.
# Each daemon thread polls `training_tasks` and POSTs status payloads to the
# matching callback URL. `process_type` is bound per-thread through the
# factory call, so there is no late-binding issue with the loop variable.
# Daemon threads are killed automatically when the interpreter exits.
for process_type in CALLBACK_URLS:
    thread = threading.Thread(
        target=create_status_updater(process_type),
        name=f"{process_type}_status_updater"
    )
    thread.daemon = True
    thread.start()
|
|||
|
|
|
|||
|
|
class TrainingTask(Resource):
    """REST resource that creates and inspects simulated training tasks."""

    def post(self):
        """Create a training task and start its background thread.

        Expects a JSON body containing ``code``, ``version``, ``image_dir``,
        ``parameters`` and ``request_id``; returns 400 when any is missing.

        BUG FIX: ``request_id`` was read below without being validated, so a
        request missing it crashed with a KeyError (HTTP 500). It is now part
        of the required-field check, and a missing/non-JSON body no longer
        raises TypeError.
        """
        data = request.json

        required_fields = ['code', 'version', 'image_dir', 'parameters', 'request_id']
        if not data or not all(field in data for field in required_fields):
            return {'error': 'Missing required fields'}, 400

        version = data['version']
        request_id = data['request_id']

        # Create and start the training thread.
        thread = TrainThread(request_id, 'training', version)
        active_threads[request_id] = thread
        thread.start()

        # Seed the shared task table; the thread overwrites this as it runs.
        training_tasks[request_id] = {
            'request_id': request_id,
            'process_type': 'training',
            'progress': '0%',
            'model_version': version,
            'status': 'initializing'
        }

        return {
            'request_id': request_id,
            'code': '0',
            'msg': 'Training task created',
            "data": None
        }, 201

    def get(self, request_id=None):
        """Return one task (404 when unknown) or a summary of all tasks."""
        if request_id:
            task = training_tasks.get(request_id)
            if not task:
                return {'error': 'Task not found'}, 404
            return task
        else:
            return {
                'tasks': [{
                    'request_id': k,
                    'process_type': v.get('process_type', 'unknown'),
                    'status': v['status'],
                    'progress': v['progress'],
                    'model_version': v['model_version']
                } for k, v in training_tasks.items()]
            }
|
|||
|
|
|
|||
|
|
class TrainingTask_Stop0(Resource):
    """Stops a locally running training thread (legacy, non-Kafka variant)."""

    def post(self):
        """Stop the thread registered for ``request_id`` and mark it stopped.

        BUG FIX: ``data['request_id']`` was indexed without validation, so a
        body missing the field crashed with a KeyError (HTTP 500); it now
        returns 400. As in the original, a success response is returned even
        when the id is unknown (best-effort semantics).
        """
        data = request.json
        if not data or 'request_id' not in data:
            return {
                'code': '1',
                'msg': 'Missing required field: request_id',
                'data': None
            }, 400
        request_id = data['request_id']

        if request_id in active_threads:
            # Stop the thread and remove it from the active registry.
            active_threads[request_id].stop()
            del active_threads[request_id]

        if request_id in training_tasks:
            training_tasks[request_id]['status'] = 'stopped'

        return {
            'request_id': request_id,
            'code': '0',
            'msg': 'Training stopped successfully',
            "data": None
        }, 200
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TrainingTask_Stop(Resource):
    """Publishes a stop command for a running training task to Kafka."""

    def post(self):
        """Forward a platform stop request to the Kafka training topic.

        Incoming body, e.g.:
        {"Command":"Stop","Errorcode":"","ModelScenetype":"train","request_id":"jcq13f4353a9312838c1d647f2593d3c3a6"}

        NOTE(review): data['ModelScenetype'], data['Command'] and
        data['request_id'] are indexed without validation — a missing key
        raises KeyError (HTTP 500). The disabled validation below even checks
        a differently-cased key ('ModelSceneType'); confirm the real field
        name before re-enabling it.

        NOTE(review): KAFKA_CONFIG and TOPIC_TRAIN are defined later in this
        module; that works because this method only runs after import
        completes, but the ordering is fragile.
        """
        data = request.json

        # Alternative (camel-cased) payload shape seen in the platform docs:
        # {
        #     "request_id": "jcq13f4353a9312838c1d647f2593d3c3a6",
        #     "ModelSceneType": "train",
        #     "Command": "Stop"
        # }

        # # Validate required fields (disabled):
        # if 'request_id' not in data or 'ModelSceneType' not in data or 'Command' not in data:
        #     return {
        #         'code': '1',
        #         'msg': 'Missing required fields (request_id, ModelSceneType, Command)',
        #         'data': None
        #     }, 400

        # Translate to the internal Kafka message format.
        training_request = {
            "Scene": data['ModelScenetype'].capitalize(),  # capitalise first letter
            "Command": data['Command'].lower(),  # lower-case the command
            "Request_ID": data['request_id']
        }

        # Publish to Kafka.
        try:
            producer = KafkaProducer(**KAFKA_CONFIG)
            future = producer.send(TOPIC_TRAIN, value=training_request)

            # Block until the broker acknowledges the message.
            future.get(timeout=10)
            producer.flush()

            # Update local task state to mirror the stop.
            if data['request_id'] in active_threads:
                active_threads[data['request_id']].stop()
                del active_threads[data['request_id']]

            if data['request_id'] in training_tasks:
                training_tasks[data['request_id']]['status'] = 'stopped'

            return {
                'request_id': data['request_id'],
                'code': '0',
                'msg': 'Stop command sent successfully',
                'data': None
            }, 200

        except Exception as e:
            return {
                'request_id': data['request_id'],
                'code': '2',
                'msg': f'Failed to send stop command: {str(e)}',
                'data': None
            }, 500
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def manager_simulate_process(ModelCode):
    """Report to the management callback whether *ModelCode* is available.

    Codes "000" and "032" are reported as existing (``isExist: 1``),
    everything else as ``isExist: 2``. The result is POSTed to the
    aiCallBack endpoint; failures are logged, never raised.

    BUG FIX: the original assigned ``training_tasks = callback_data``, which
    only rebound a *local* name shadowing the module-level dict and had no
    effect at all; the misleading statement has been removed.
    """
    # The callback endpoint requires this exact payload shape.
    if str(ModelCode) in ["000", "032"]:
        callback_data = {"isExist": 1}
    else:
        callback_data = {"isExist": 2}

    # POST the result to the management callback.
    callback_url = "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
    headers = {
        'Content-Type': 'application/json',
    }

    try:
        response = requests.post(
            callback_url,
            data=json.dumps(callback_data),
            headers=headers
        )

        if response.status_code == 200:
            print(f"回调成功: {response.json()}")
        else:
            print(f"回调失败,状态码: {response.status_code}, 响应: {response.text}")
    except Exception as e:
        print(f"回调请求异常: {str(e)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def stop_simulate_process(request_id, process_type, version="1.0"):
    """Mark *request_id* as finished (status 4) and notify the callback URL.

    The result is also mirrored into the module-level ``training_tasks``
    dict. Callback failures are logged, never raised.
    """
    # The callback endpoint requires this exact payload shape.
    callback_data = {
        "id": request_id,
        "status": 4
    }

    # Mirror the final state locally.
    training_tasks[request_id] = callback_data

    callback_url = "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
    headers = {
        'Content-Type': 'application/json',
        # Add an Authorization header here if the endpoint requires it.
    }

    body = json.dumps(callback_data)
    try:
        resp = requests.post(callback_url, data=body, headers=headers)
        if resp.status_code == 200:
            print(f"回调成功: {resp.json()}")
        else:
            print(f"回调失败,状态码: {resp.status_code}, 响应: {resp.text}")
    except Exception as e:
        print(f"回调请求异常: {str(e)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# √
|
|||
|
|
def deploy_simulate_process(request_id, process_type, version="1.0"):
    """Report a deployment for *request_id* as finished (status 4).

    Identical flow to ``stop_simulate_process``: record the final state in
    ``training_tasks`` and POST it to the aiCallBack endpoint, logging (but
    never raising) any failure.
    """
    # The callback endpoint requires this exact payload shape.
    callback_data = {
        "id": request_id,
        "status": 4
    }

    # Mirror the final state locally.
    training_tasks[request_id] = callback_data

    callback_url = "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
    headers = {
        'Content-Type': 'application/json',
        # Add an Authorization header here if the endpoint requires it.
    }

    body = json.dumps(callback_data)
    try:
        resp = requests.post(callback_url, data=body, headers=headers)
        if resp.status_code == 200:
            print(f"回调成功: {resp.json()}")
        else:
            print(f"回调失败,状态码: {resp.status_code}, 响应: {resp.text}")
    except Exception as e:
        print(f"回调请求异常: {str(e)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# √
|
|||
|
|
def check_simulate_process(input_json):
    """Run YOLOv5 inference for a check task and POST the result back.

    Delegates the actual inference to ``check_task_api.infer_yolov5``,
    stores the result under the task's request_id and reports it to the
    check-task callback endpoint. Failures are logged, never raised.
    """
    callback_data = check_task_api.infer_yolov5(input_json)
    print("*******line366", callback_data)

    request_id = input_json.get("request_id", "")

    # Keep the latest result for this request locally.
    training_tasks[request_id] = callback_data

    callback_url = "http://172.18.0.129:8084/api/admin/checkTask/callback"
    headers = {
        'Content-Type': 'application/json',
        # Add an Authorization header here if the endpoint requires it.
    }

    body = json.dumps(callback_data)
    try:
        resp = requests.post(callback_url, data=body, headers=headers)
        if resp.status_code == 200:
            print(f"回调成功: {resp.json()}")
        else:
            print(f"回调失败,状态码: {resp.status_code}, 响应: {resp.text}")
    except Exception as e:
        print(f"回调请求异常: {str(e)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TrainingTask_Check(Resource):
    """Starts a model-check (inference) job in a background thread."""

    def post(self):
        """Launch ``check_simulate_process`` asynchronously for the posted job.

        BUG FIX: ``data['request_id']`` was indexed without validation, so a
        body missing the field crashed with a KeyError (HTTP 500); it now
        returns 400. The redundant second read of ``request.json`` was also
        removed.
        """
        input_json = request.json

        print("###line587", input_json)

        if not input_json or 'request_id' not in input_json:
            return {
                'code': '1',
                'msg': 'Missing required field: request_id',
                'data': None
            }, 400

        request_id = input_json['request_id']

        # Start the inference in a new thread so the request returns fast.
        thread = threading.Thread(
            target=check_simulate_process,
            args=(input_json,)
        )
        thread.start()

        return {
            'id': request_id,
            'code': '0',
            'msg': 'Model check started',
            'data': None
        }, 201
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TrainingTask_Manager(Resource):
    """Reports whether a model code falls in the supported range (0-32)."""

    def post(self):
        """Return ``{"isExist": 1}`` for codes 0..32, otherwise ``{"isExist": 0}``.

        BUG FIX: a non-numeric ModelCode (e.g. "abc") previously made
        ``int()`` raise ValueError and the request fail with HTTP 500; it is
        now treated as an unknown model (isExist 0).
        """
        data = request.json
        ModelCode = data.get("ModelCode")

        # Normalise to a zero-padded 3-digit string (5 / "5" -> "005").
        if isinstance(ModelCode, int):
            model_code_str = f"{ModelCode:03d}"
        else:
            model_code_str = str(ModelCode).zfill(3)

        try:
            exists = 0 <= int(model_code_str) <= 32
        except (TypeError, ValueError):
            exists = False

        callback_data = {"isExist": 1 if exists else 0}
        return callback_data, 201
|
|||
|
|
|
|||
|
|
|
|||
|
|
# NOTE(review): this re-declares model_dict with exactly the same contents as
# the definition near the top of the file; the redefinition shadows the
# earlier one at import time. Consider deleting one copy.
model_dict = {
    "001": "river",
    "002": "forest2",
    "003": "highWay2",
    "006": "vehicle",
    "007": "pedestrian",
    "008": "smogfire",
    "009": "AnglerSwimmer",
    "010": "countryRoad",
    "011": "ship2",
    "013": "channelEmergency",
    "014": "forest2",
    "015": "river2",
    "016": "cityMangement2",
    "017": "drowning",
    "018": "noParking",
    "019": "illParking",
    "020": "cityRoad",
    "023": "pothole",
    "024": "channel2",
    "025": "riverT",
    "026": "forestCrowd",
    "027": "highWay2T",
    "028": "smartSite",
    "029": "rubbish",
    "030": "firework"
}

# MinIO connection settings.
# SECURITY NOTE(review): access/secret keys are hard-coded in source; move
# them to environment variables or a secrets store.
ENDPOINT = "minio-jndsj.t-aaron.com:2443"
ACCESS_KEY = "PJM0c2qlauoXv5TMEHm2"
SECRET_KEY = "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI"
BUCKET_NAME = "algorithm"

# Re-declared defaults; shadow the identical values defined earlier.
PREFIX = "weights/026/v1.0/"
LOCAL_DIR = "downloaded_weights"  # local directory for downloaded weights

# Remote server configuration
# NOTE(review): this overwrites the earlier REMOTE_HOST (.150) with .151 and
# a different base path — confirm which server/path combination is intended.
REMOTE_HOST = "172.16.225.151"
REMOTE_USER = "root"
REMOTE_PASSWORD = "P#mnuJ6r4A"
REMOTE_BASE_PATH = "/home/th/jcq/AI_AutoPlat/"
|
|||
|
|
|
|||
|
|
def get_model_name(code):
    """Map a model code to its internal model name.

    NOTE(review): byte-identical duplicate of the earlier get_model_name
    definition; this redefinition shadows it at import time. Consider
    removing one copy.
    """
    # Ensure code is a 3-digit string
    code_str = f"{int(code):03d}" if str(code).isdigit() else code
    return model_dict.get(code_str, f"Unknown code: {code}")
|
|||
|
|
|
|||
|
|
def parse_minio_path(model_dir):
    """Split a MinIO object URL into ``(bucket, prefix, filename)``.

    Raises ValueError when the host is not the expected MinIO endpoint or
    when the path has fewer than three components (bucket + at least one
    prefix segment + filename).
    """
    parsed = urllib.parse.urlparse(model_dir)
    if not parsed.netloc.startswith('minio-jndsj.t-aaron.com:2443'):
        raise ValueError("Invalid MinIO URL")

    parts = parsed.path.lstrip('/').split('/')
    if len(parts) < 3:
        raise ValueError("Invalid MinIO path format")

    # First component is the bucket, last is the filename, everything in
    # between forms the object prefix.
    bucket, *middle, filename = parts
    prefix = '/'.join(middle)

    return bucket, prefix, filename
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def download_minio_files(model_dir, local_dir='downloads'):
    """Download a single object from MinIO based on the *model_dir* URL.

    :param model_dir: full MinIO object URL (parsed by parse_minio_path)
    :param local_dir: local directory to store the downloaded file
    :return: local file path on success, None on any error
    """
    try:
        # MinIO connection settings.
        # SECURITY NOTE(review): hard-coded credentials — move to
        # environment variables or a secrets store.
        ENDPOINT = "minio-jndsj.t-aaron.com:2443"
        ACCESS_KEY = "PJM0c2qlauoXv5TMEHm2"
        SECRET_KEY = "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI"

        BUCKET_NAME = "algorithm"

        bucket, prefix, filename = parse_minio_path(model_dir)

        # Create MinIO client
        client = Minio(
            ENDPOINT,
            access_key=ACCESS_KEY,
            secret_key=SECRET_KEY,
            secure=False
        )

        # Ensure local directory exists
        os.makedirs(local_dir, exist_ok=True)

        # BUG FIX: the object path previously ended in a literal
        # "(unknown)" placeholder instead of the parsed filename, so a
        # non-existent object key was requested. Use the filename, which is
        # also what the local path below uses.
        object_path = f"{prefix}/{filename}"

        # Local file path
        local_path = os.path.join(local_dir, filename)

        # Download the file
        client.fget_object(bucket, object_path, local_path)

        print(f"Downloaded: {object_path} -> {local_path}")
        return local_path

    except S3Error as err:
        print(f"MinIO Error occurred: {err}")
        return None
    except Exception as e:
        print(f"Error occurred: {e}")
        return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def download_minio_deploy(endpoint, access_key, secret_key, bucket_name, prefix='', local_dir='downloaded_weights'):
    """Download every object under *prefix* from a MinIO bucket.

    :return: local path of the LAST file downloaded (callers treat it as
             "a file inside the downloaded folder"), or None when nothing
             was downloaded or an error occurred.
    """
    # Create the MinIO client (HTTPS).
    client = Minio(
        endpoint,
        access_key=access_key,
        secret_key=secret_key,
        secure=True  # use HTTPS
    )

    # BUG FIX: local_path was only bound inside the loop, so an empty
    # object listing raised UnboundLocalError at the return statement.
    local_path = None

    try:
        # Ensure the local directory exists.
        os.makedirs(local_dir, exist_ok=True)

        # List all objects below the prefix.
        objects = client.list_objects(bucket_name, prefix=prefix, recursive=True)

        for obj in objects:
            # Build the local file path, preserving the relative layout.
            local_path = os.path.join(local_dir, os.path.relpath(obj.object_name, prefix))

            # Make sure the subdirectory exists.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)

            # Download the file.
            print(f"正在下载: {obj.object_name} -> {local_path}")
            client.fget_object(bucket_name, obj.object_name, local_path)

        print("所有文件下载完成!")

        return local_path

    except S3Error as e:
        print(f"MinIO 错误: {e}")
    except Exception as e:
        print(f"发生错误: {e}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def deploy_to_remote(local_path, model_code):
    """Copy *local_path* to the remote training server over SCP.

    :return: remote path on success, None on failure.
    """
    ssh = None
    scp = None
    try:
        model_name = get_model_name(model_code)
        remote_path = os.path.join(REMOTE_BASE_PATH, model_name, os.path.basename(local_path))

        # Create SSH client
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(REMOTE_HOST, username=REMOTE_USER, password=REMOTE_PASSWORD)

        # Create SCP client and upload the file
        scp = SCPClient(ssh.get_transport())
        scp.put(local_path, remote_path)
        print(f"Successfully deployed {local_path} to {remote_path}")

        return remote_path

    except Exception as e:
        print(f"Deployment failed: {e}")
        return None

    finally:
        # BUG FIX: connections were only closed on the success path,
        # leaking the SSH session whenever connect/put raised.
        if scp is not None:
            scp.close()
        if ssh is not None:
            ssh.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def deploy_to_local(local_path, model_code):
    """Move *local_path* into the local deployment directory for the model.

    Despite the naming pattern shared with ``deploy_to_remote``, this
    variant moves the file on the LOCAL filesystem (no SSH/SCP). The large
    block of commented-out SSH code from the original has been removed.

    :return: destination path on success, None on failure.
    """
    try:
        model_name = get_model_name(model_code)

        # Local deployment root; deliberately shadows the module-level
        # REMOTE_BASE_PATH inside this function.
        REMOTE_BASE_PATH = "/home/th/jcq/AIlib/Deployment"
        remote_path = os.path.join(REMOTE_BASE_PATH, model_name, os.path.basename(local_path))

        print("###line858", REMOTE_BASE_PATH, model_name, local_path)

        # Ensure the target directory exists, then move the file into place.
        os.makedirs(os.path.dirname(remote_path), exist_ok=True)
        shutil.move(local_path, remote_path)

        print(f"Successfully deployed {local_path} to {remote_path}")

        return remote_path

    except Exception as e:
        print(f"Deployment failed: {e}")
        return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def deployment_worker(model_dir, model_code, request_id):
    """Worker function that runs in a separate process to handle deployment.

    Downloads the model weights from MinIO, moves them into the local
    deployment directory, and ALWAYS reports the outcome to the callback
    endpoint (status 3 = success, 4 = failure).

    BUG FIXES:
      * cleanup previously tested ``'local_path' in locals()`` — a name
        never defined in this function — so the downloaded file was never
        removed, and ``os.path.exists(model_path)`` could crash with
        NameError when the download itself failed;
      * the download log printed a literal ``(unknown)`` instead of the
        parsed filename.

    :return: True on successful deployment, False otherwise.
    """
    callback_url = "http://172.18.0.129:8084/api/admin/modelApplication/aiCallBack"
    success = False
    message = ""
    remote_path = ""
    model_path = None

    try:
        # Parse the path to get bucket and prefix.
        bucket, prefix, filename = parse_minio_path(model_dir)
        print(f"[Deployment Worker] Downloading from bucket: {bucket}, prefix: {prefix}, filename: {filename}")

        # MinIO connection settings.
        # SECURITY NOTE(review): hard-coded credentials — move to env vars.
        ENDPOINT = "minio-jndsj.t-aaron.com:2443"
        ACCESS_KEY = "PJM0c2qlauoXv5TMEHm2"
        SECRET_KEY = "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI"

        # 1. Download the whole weights folder from MinIO.
        model_path = download_minio_deploy(
            endpoint=ENDPOINT,
            access_key=ACCESS_KEY,
            secret_key=SECRET_KEY,
            bucket_name=bucket,
            prefix=prefix,
            local_dir=LOCAL_DIR
        )

        print("###line890", model_path)

        # 2. Deploy to the local deployment directory (the SCP-based
        # deploy_to_remote variant exists but is currently unused).
        remote_path = deploy_to_local(model_path, model_code)
        if not remote_path:
            message = "Failed to deploy model to remote server"
            print(f"[Deployment Worker] {message}")
            raise Exception(message)

        # If we get here, deployment was successful.
        success = True
        message = "操作成功"
        print(f"[Deployment Worker] Deployment successful: {remote_path}")

    except Exception as e:
        message = f"Deployment failed: {str(e)}"
        print(f"[Deployment Worker] {message}")
        success = False

    finally:
        # Clean up the downloaded file if it is still around (deploy_to_local
        # normally moves it away, so this only fires on failures).
        if model_path and os.path.exists(model_path):
            try:
                os.remove(model_path)
            except OSError as e:
                print(f"[Deployment Worker] Warning: Failed to clean up local file: {e}")

        # Always send the final-status callback.
        try:
            if success:
                callback_data = {
                    "id": request_id,
                    "status": 3,
                    "msg": "操作成功",
                    "data": None  # Using None instead of "null" string
                }
            else:
                callback_data = {
                    "id": request_id,
                    "status": 4,
                    "msg": "未查询到数据",
                    "data": None  # Using None instead of "null" string
                }

            response = requests.post(
                callback_url,
                json=callback_data,
                timeout=10
            )

            if response.status_code != 200:
                print(f"[Deployment Worker] Callback failed with status {response.status_code}: {response.text}")
                print("line956", callback_data)
            else:
                print("[Deployment Worker] Callback sent successfully")
                print("line959", callback_data)

        except Exception as e:
            print(f"[Deployment Worker] Failed to send callback: {str(e)}")

    return success
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.route('/api/train/deploy', methods=['POST'])
def deploy_model():
    """Kick off model deployment in a background process and return 202."""
    payload = request.get_json()

    print("###line879", payload)

    # Both the MinIO directory and the model code are mandatory.
    if not payload or 'ModelDir' not in payload or 'ModelCode' not in payload:
        return jsonify({
            "status": "error",
            "message": "Missing required fields (ModelDir or ModelCode)"
        }), 400

    # The weights file is expected to be named yolov5.pt inside ModelDir.
    model_dir = os.path.join(payload['ModelDir'], 'yolov5.pt')
    model_code = payload['ModelCode']

    # Generate a unique request ID if not provided.
    request_id = payload.get('request_id', str(uuid.uuid4()))

    try:
        # Run the deployment in its own process so the request returns fast.
        worker = multiprocessing.Process(
            target=deployment_worker,
            args=(model_dir, model_code, request_id)
        )
        worker.start()

        # Deployment continues in the background; report acceptance now.
        return jsonify({
            "request_id": request_id,
            "status": 4
        }), 202  # 202 Accepted status code

    except Exception as e:
        return jsonify({
            "status": "error",
            "message": f"Failed to start deployment: {str(e)}",
            "request_id": request_id
        }), 500
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Kafka configuration
# NOTE(review): defined *after* TrainingTask_Stop, which reads KAFKA_CONFIG
# and TOPIC_TRAIN at request time — this works because requests only arrive
# after import completes, but the ordering is fragile.
KAFKA_CONFIG = {
    'bootstrap_servers': '172.18.0.126:9094',
    # Serialize message values as UTF-8 encoded JSON.
    'value_serializer': lambda v: json.dumps(v).encode('utf-8')
}

# Kafka topics
TOPIC_TRAIN = 'training-tasks'

# NOTE(review): third identical re-declaration of model_dict in this file;
# it shadows the earlier copies at import time. Consider keeping one.
model_dict = {
    "001": "river",
    "002": "forest2",
    "003": "highWay2",
    "006": "vehicle",
    "007": "pedestrian",
    "008": "smogfire",
    "009": "AnglerSwimmer",
    "010": "countryRoad",
    "011": "ship2",
    "013": "channelEmergency",
    "014": "forest2",
    "015": "river2",
    "016": "cityMangement2",
    "017": "drowning",
    "018": "noParking",
    "019": "illParking",
    "020": "cityRoad",
    "023": "pothole",
    "024": "channel2",
    "025": "riverT",
    "026": "forestCrowd",
    "027": "highWay2T",
    "028": "smartSite",
    "029": "rubbish",
    "030": "firework"
}
|
|||
|
|
|
|||
|
|
def validate_minio_url(url):
    """Validate that *url* looks like a full MinIO object/prefix URL.

    A valid URL has an http(s) scheme, a host, and a path of at least
    ``/<bucket>/<prefix>`` (i.e. two components after the host).

    :param url: full MinIO path, e.g. "http://host:port/bucket/prefix"
    :return: the ``urllib.parse.ParseResult`` for *url*
    :raises ValueError: if the URL is malformed
    """
    try:
        parsed = urllib.parse.urlparse(url)
    except Exception as e:  # urlparse can raise on non-string / malformed input
        raise ValueError(f"URL validation failed: {str(e)}")

    # Require scheme, host, and a "/bucket/prefix" path (>= 3 split parts).
    # Raised outside the try block so our own error is not re-caught and
    # double-wrapped ("URL validation failed: Invalid MinIO URL format"),
    # which the original code did.
    if parsed.scheme not in ('http', 'https') or not parsed.netloc or len(parsed.path.split('/')) < 3:
        raise ValueError("Invalid MinIO URL format")
    return parsed
|
|||
|
|
|
|||
|
|
def download_minio_directory(minio_url, local_base_dir):
    """
    Download a MinIO "directory" (object prefix) to the local filesystem.

    :param minio_url: full MinIO path (e.g. "http://172.18.0.129:8084/th-dsp/027123")
    :param local_base_dir: local base directory (e.g. "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets")
    :return: (success, message) tuple — on success ``message`` is the local
             directory path, on failure it is an error description
    """
    try:
        # Validate and split the URL into endpoint / bucket / object prefix.
        parsed = validate_minio_url(minio_url)
        endpoint = parsed.netloc
        bucket_name = parsed.path.split('/')[1]
        prefix = '/'.join(parsed.path.split('/')[2:])

        # Credentials come from the environment when set (recommended).
        # SECURITY NOTE(review): the fallback literals below are real-looking
        # keys committed to source — rotate them and remove the defaults.
        access_key = os.getenv("MINIO_ACCESS_KEY", "PJM0c2qlauoXv5TMEHm2")  # default for testing only
        secret_key = os.getenv("MINIO_SECRET_KEY", "Wr69Dm3ZH39M3GCSeyB3eFLynLPuGCKYfphixZuI")

        # Create the local target directory, named after the last prefix segment.
        dir_name = os.path.basename(prefix) or "minio_download"
        local_dir = os.path.join(local_base_dir, dir_name)
        os.makedirs(local_dir, exist_ok=True)

        # Initialize the MinIO client (TLS only for https URLs).
        client = Minio(
            endpoint,
            access_key=access_key,
            secret_key=secret_key,
            secure=parsed.scheme == 'https'
        )

        print(f"Starting download from {minio_url} to {local_dir}")

        # Recursively list and download every object under the prefix.
        objects = client.list_objects(bucket_name, prefix=prefix, recursive=True)
        downloaded_files = 0

        for obj in objects:
            # Build the local path, preserving the relative directory structure.
            relative_path = os.path.relpath(obj.object_name, prefix)
            local_path = os.path.join(local_dir, relative_path)

            # Ensure the parent directory exists before writing.
            os.makedirs(os.path.dirname(local_path), exist_ok=True)

            # Download the object to disk.
            client.fget_object(bucket_name, obj.object_name, local_path)
            downloaded_files += 1

            print(f"Downloaded [{downloaded_files}]: {obj.object_name} -> {local_path}")

        print(f"Download completed! Total files: {downloaded_files}")
        return True, local_dir

    except S3Error as e:
        # bucket_name is always bound here: S3Error can only be raised by the
        # client calls, which run after the bucket name is parsed.
        error_msg = f"MinIO Error: {str(e)}"
        if "NoSuchBucket" in str(e):
            error_msg = f"Bucket not found: {bucket_name}"
        return False, error_msg
    except ValueError as e:
        return False, f"Invalid URL: {str(e)}"
    except Exception as e:
        return False, f"Unexpected error: {str(e)}"
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def split_dataset(data_dir, train_ratio=0.8):
    """Split a dataset directory into training and validation file lists.

    Only images (jpg/jpeg/png, case-insensitive) that have a matching
    ``.txt`` label file in the same directory are considered.

    Args:
        data_dir (str): Directory containing images and label files.
        train_ratio (float): Fraction assigned to the training set (default 0.8).

    Returns:
        tuple: (train_files, val_files) lists of image filenames.
    """
    image_exts = ('.jpg', '.jpeg', '.png')

    def has_label(name):
        # An image is usable only when "<stem>.txt" exists next to it.
        stem, _ = os.path.splitext(name)
        return os.path.exists(os.path.join(data_dir, stem + '.txt'))

    candidates = [
        name for name in os.listdir(data_dir)
        if name.lower().endswith(image_exts) and has_label(name)
    ]

    # Randomize order, then cut the list into the two subsets.
    random.shuffle(candidates)
    cut = int(len(candidates) * train_ratio)
    return candidates[:cut], candidates[cut:]
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def write_file_list(file_list, output_file, base_dir):
    """Write one path per line: *base_dir* joined with each name in *file_list*."""
    lines = [os.path.join(base_dir, name) + '\n' for name in file_list]
    with open(output_file, 'w') as handle:
        handle.writelines(lines)
|
|||
|
|
|
|||
|
|
def copy_files_to_folders(files, src_dir, dest_dir):
    """Copy each image in *files* plus its ``.txt`` label into *dest_dir*.

    The label filename is derived with ``os.path.splitext`` rather than
    substring replacement of '.jpg'/'.png', which previously failed for
    ``.jpeg`` images (accepted by split_dataset) — their labels were never
    copied and the image itself was copied in their place.

    :param files: iterable of image filenames (no directory component)
    :param src_dir: directory containing both images and labels
    :param dest_dir: destination directory (created if missing)
    """
    os.makedirs(dest_dir, exist_ok=True)
    for file in files:
        # Copy the image itself.
        copyfile(os.path.join(src_dir, file), os.path.join(dest_dir, file))

        # Copy the matching label: "<stem>.txt", regardless of image extension.
        label = os.path.splitext(file)[0] + '.txt'
        copyfile(os.path.join(src_dir, label), os.path.join(dest_dir, label))
|
|||
|
|
|
|||
|
|
|
|||
|
|
def update_yaml(yaml_file, train_txt, val_txt):
    """Rewrite *yaml_file* so its 'train'/'val' keys point at the given list files.

    All other keys in the YAML document are preserved unchanged.
    """
    with open(yaml_file, 'r') as handle:
        config = yaml.safe_load(handle)

    # Only the two dataset paths change.
    config['train'] = train_txt
    config['val'] = val_txt

    with open(yaml_file, 'w') as handle:
        yaml.dump(config, handle, default_flow_style=False)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_model_name(code):
    """Look up the model name for *code*, normalizing numeric codes to 3 digits."""
    text = str(code)
    # Numeric codes are zero-padded ("26" -> "026"); non-numeric pass through.
    key = f"{int(text):03d}" if text.isdigit() else code
    return model_dict.get(key, f"Unknown code: {code}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_request_id(length=30):
    """Generate a pseudo-random request ID: 'bb' followed by *length* letters.

    Uses ``random.choices`` (sampling with replacement) instead of
    ``random.sample``, which forbade repeated letters and raised
    ``ValueError`` for any length > 52 (len(ascii_letters)).

    :param length: number of random letters after the 'bb' prefix
    :return: a string of length ``length + 2``
    """
    # NOTE: not cryptographically secure; switch to the `secrets` module if
    # these IDs ever need to be unguessable.
    return 'bb' + ''.join(random.choices(string.ascii_letters, k=length))
|
|||
|
|
|
|||
|
|
def transform_message_format(input_data):
    """Transform an incoming training request into the Kafka message schema.

    Side effects: downloads the training images from MinIO, splits them into
    train/val folders, writes train.txt / val.txt list files, and updates the
    model's dataset YAML — all under hard-coded server paths.

    :param input_data: dict parsed from the client's JSON request
    :return: dict matching the schema checked by validate_training_request
    """
    transformed = {
        "Scene": input_data.get("ModelScenetype", "train").capitalize(),
        "Command": "start",  # Default command
        "Request_ID": input_data.get("request_id", generate_request_id()),
        "Version": input_data.get("version"),
        "Name": "channel2",  # Default name if not provided
        "Model": input_data.get("code", input_data.get("Code", "026")),
    }

    # Handle TrainParameter construction
    params = input_data.get("parameters", {})
    batch_size = params.get("batch_size")
    img_size = params.get("img_size")
    epochs = params.get("epochs")

    # Resolve model name and the dataset YAML path for that model.
    model_name = get_model_name(input_data.get("code", input_data.get("Code", "026")))

    RAW_YAML_PATH = "/home/th/jcq/AI_AutoPlat/AI_web_dsj/config/yaml"

    yaml_path = os.path.join(RAW_YAML_PATH,f"{model_name}.yaml")

    print("####line137",yaml_path)

    # Fetch the MinIO URL of the training images.
    # Example: minio_url = "http://172.18.0.129:8084/th-dsp/027123"
    minio_url = input_data.get("TrainImageDir")
    if not minio_url:
        print("Error: TrainImageDir not specified in input_data")
    else:
        # Local base directory for downloaded training datasets.
        local_base_dir = "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets"

        # Perform the download.
        success, result = download_minio_directory(minio_url, local_base_dir)

        if success:
            print(f"Successfully downloaded to: {result}")
        else:
            print(f"Download failed: {result}")

    # NOTE(review): if TrainImageDir was missing, `success` and `result` are
    # unbound here and the following lines raise NameError (caught by the
    # caller's broad except). Confirm whether an early return was intended.
    if success:
        print("line277",f"Files downloaded successfully to: {result}")
    else:
        print(f"Download failed: {result}")

    # NOTE(review): used even when the download failed — `result` is then an
    # error message, not a directory path.
    TrainImageDir = result

    # Create the train/ and val/ folders.
    train_dir = os.path.join(TrainImageDir, "train")
    val_dir = os.path.join(TrainImageDir, "val")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(val_dir, exist_ok=True)

    # Split the dataset into train/val file lists.
    train_files, val_files = split_dataset(TrainImageDir)

    # Copy files into the train/ and val/ folders.
    copy_files_to_folders(train_files, TrainImageDir, train_dir)
    copy_files_to_folders(val_files, TrainImageDir, val_dir)

    # Write the train/val list files.
    train_txt = os.path.join(TrainImageDir, "train.txt")
    val_txt = os.path.join(TrainImageDir, "val.txt")

    write_file_list(train_files, train_txt, train_dir)
    write_file_list(val_files, val_txt, val_dir)

    # Update the pre-existing dataset YAML to point at the new list files.
    # NOTE(review): nc/names below are never used — update_yaml only rewrites
    # 'train'/'val'. Dead code, or a missing class-list update?
    nc = 5
    names = ['flag', 'buoy', 'shipname', 'ship', 'uncover']

    update_yaml(yaml_path, train_txt, val_txt)

    print(f"数据集划分完成,训练集: {len(train_files)} 张,验证集: {len(val_files)} 张")
    print(f"训练集列表已写入: {train_txt}")
    print(f"验证集列表已写入: {val_txt}")
    print(f"YAML文件已更新: {yaml_path}")

    # Resolve the initial weights checkpoint for this model.
    RAW_WEIGHTS_PATH = "/home/th/jcq/AIlib/weights_init"

    model_path = os.path.join(RAW_WEIGHTS_PATH,model_name,"yolov5.pt")
    # model_path = os.path.join(RAW_WEIGHTS_PATH,model_name,"best.pt")
    print("####line98",model_path)

    # TrainParameter is serialized as the string form of a Python list.
    transformed["TrainParameter"] = str([
        batch_size,
        img_size,
        epochs,
        yaml_path,
        model_path
    ])

    return transformed
|
|||
|
|
|
|||
|
|
def validate_training_request(data):
    """Check that a transformed training message has the expected shape.

    Returns:
        tuple: (True, "") when valid, otherwise (False, <reason>).
    """
    # Field order matters: the first missing field is the one reported.
    schema = (
        ("Scene", str),
        ("Command", str),
        ("Version", str),
        ("Name", str),
        ("Model", str),
        ("TrainParameter", str),
    )

    for field, expected in schema:
        if field not in data:
            return False, f"Missing required field: {field}"
        if not isinstance(data[field], expected):
            return False, f"Field {field} must be {expected.__name__}"

    if data["Scene"] not in ["Train"]:
        return False, "Invalid Scene value"

    if data["Command"] not in ["start", "stop"]:
        return False, "Invalid Command value"

    return True, ""
|
|||
|
|
|
|||
|
|
def send_to_kafka(topic, message):
    """Publish *message* to the Kafka *topic*, blocking until delivery.

    Returns:
        tuple: (True, "") on success, (False, <error text>) on failure.
    """
    producer = None
    try:
        producer = KafkaProducer(**KAFKA_CONFIG)
        # send() is asynchronous; block on the future (10s) so delivery
        # errors surface here rather than being lost.
        metadata = producer.send(topic, value=message).get(timeout=10)
        logger.info(f"Message sent to partition {metadata.partition}, offset {metadata.offset}")
        return True, ""
    except Exception as e:
        logger.error(f"Failed to send message: {e}")
        return False, str(e)
    finally:
        # Always release the producer's network resources.
        if producer:
            producer.flush()
            producer.close()
|
|||
|
|
|
|||
|
|
@app.route('/api/train', methods=['POST'])
def handle_training_request():
    """HTTP endpoint for training requests (non-blocking).

    Validates the JSON payload, spawns a daemon process to run the full
    training pipeline, and immediately returns 202 with the request id.

    Returns:
        400 for malformed JSON, 202 on accepted start, 500 on internal error.
    """
    try:
        # Parse and validate basic request structure.
        data = request.get_json()
        if not data or not isinstance(data, dict):
            return jsonify({"status": "error", "message": "Invalid JSON data"}), 400

        # Fall back to a generated ID so the response and worker logs always
        # carry a usable request_id (previously this could be None when the
        # client omitted the field, despite the "Generate request ID" intent).
        request_id = data.get("request_id") or generate_request_id()

        # Daemon process: exits with the main process, never blocks the response.
        p = Process(
            target=process_training_task,
            args=(data, request_id),
            daemon=True
        )
        p.start()

        # Immediate response while the pipeline runs in the background.
        return jsonify({
            "status": "success",
            "message": "Training task started",
            "request_id": request_id,
            "process_id": p.pid
        }), 202  # 202 Accepted

    except Exception as e:
        logger.error(f"Request handling error: {str(e)}")
        return jsonify({
            "status": "error",
            "message": "Internal server error"
        }), 500
|
|||
|
|
|
|||
|
|
def process_training_task(data, request_id):
    """Background-process entry point: transform, validate, and publish a task.

    Any failure is logged rather than raised — there is no caller left to
    notify once the HTTP response has been sent.
    """
    try:
        logger.info(f"[{request_id}] Starting background task (PID: {os.getpid()})")
        started = time.time()

        # 1. Convert the incoming payload into the Kafka message schema.
        message = transform_message_format(data)
        message["Request_ID"] = request_id

        # 2. Reject malformed messages before they reach the queue.
        ok, reason = validate_training_request(message)
        if not ok:
            raise ValueError(reason)

        # 3. Publish to the training topic.
        sent, error = send_to_kafka(TOPIC_TRAIN, message)
        if not sent:
            raise RuntimeError(error)

        logger.info(f"[{request_id}] Task completed in {time.time() - started:.2f}s")

    except Exception as e:
        logger.error(f"[{request_id}] Task failed: {str(e)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# REST resources for controlling and inspecting training tasks.
# NOTE(review): TrainingTask_Stop / TrainingTask_Check / TrainingTask_Manager
# must be defined earlier in this file (outside this chunk) — confirm.
api.add_resource(TrainingTask_Stop, '/api/train/control')
api.add_resource(TrainingTask_Check, '/api/train/check')
api.add_resource(TrainingTask_Manager, '/api/train/manager')
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Bind address/port are env-overridable; the defaults preserve the
    # original hard-coded values so existing deployments are unaffected.
    host = os.getenv("APP_HOST", "172.16.225.150")
    port = int(os.getenv("APP_PORT", "5001"))
    # WARNING: debug=True enables the Werkzeug debugger/reloader and must
    # not be left on in production.
    app.run(host=host, port=port, debug=True)
|