Parcourir la source

更新更新

tags/V2.4.0
chenyukun il y a 2 ans
Parent
révision
eee294dd83
13 fichiers modifiés avec 316 ajouts et 133 suppressions
  1. +4
    -5
      concurrency/FileUpdateThread.py
  2. +16
    -17
      concurrency/IntelligentRecognitionProcess.py
  3. +37
    -1
      dsp_application.yml
  4. +3
    -3
      entity/FeedBack.py
  5. +0
    -57
      enums/AnalysisLabelEnum.py
  6. +2
    -0
      enums/ExceptionEnum.py
  7. +0
    -16
      enums/ModelTypeEnum.py
  8. +0
    -24
      service/Dispatcher.py
  9. +9
    -0
      test/mysqltest.py
  10. +8
    -1
      util/Cv2Utils.py
  11. +8
    -8
      util/KafkaUtils.py
  12. +227
    -0
      util/MyConnectionPool.py
  13. +2
    -1
      util/YmlUtils.py

+ 4
- 5
concurrency/FileUpdateThread.py Voir le fichier

from util import TimeUtils from util import TimeUtils
import uuid import uuid
from entity import FeedBack from entity import FeedBack
from enums.AnalysisTypeEnum import AnalysisType
from enums.AnalysisStatusEnum import AnalysisStatus from enums.AnalysisStatusEnum import AnalysisStatus
import numpy as np import numpy as np
from PIL import Image from PIL import Image
or_image_name = self.build_image_name(self.msg.get('results_base_dir'), time_now, or_image_name = self.build_image_name(self.msg.get('results_base_dir'), time_now,
str(image_dict.get("current_frame")), str(image_dict.get("current_frame")),
str(image_dict.get("last_frame")), str(image_dict.get("last_frame")),
image_dict.get("question_descrition"),
image_dict.get("model_detection_code"),
random_num, random_num,
image_dict.get("mode_service"), image_dict.get("mode_service"),
self.msg.get('request_id'), "OR") self.msg.get('request_id'), "OR")
ai_image_name = self.build_image_name(self.msg.get('results_base_dir'), time_now, ai_image_name = self.build_image_name(self.msg.get('results_base_dir'), time_now,
str(image_dict.get("current_frame")), str(image_dict.get("current_frame")),
str(image_dict.get("last_frame")), str(image_dict.get("last_frame")),
image_dict.get("question_descrition"),
image_dict.get("model_detection_code"),
random_num, random_num,
image_dict.get("mode_service"), image_dict.get("mode_service"),
self.msg.get('request_id'), "AI") self.msg.get('request_id'), "AI")
self.sendResult(FeedBack.message_feedback(self.msg.get('request_id'), AnalysisStatus.RUNNING.value, self.sendResult(FeedBack.message_feedback(self.msg.get('request_id'), AnalysisStatus.RUNNING.value,
self.mode_service, "", "", image_dict.get("progress"), self.mode_service, "", "", image_dict.get("progress"),
or_image_name, or_image_name,
ai_image_name, image_dict.get("question_code"),
image_dict.get("question_descrition"),
ai_image_name, image_dict.get("model_type_code"),
image_dict.get("model_detection_code"),
TimeUtils.now_date_to_str())) TimeUtils.now_date_to_str()))
except Exception as e: except Exception as e:
logger.error("requestId:{}, 图片上传异常:", self.msg.get("request_id")) logger.error("requestId:{}, 图片上传异常:", self.msg.get("request_id"))

+ 16
- 17
concurrency/IntelligentRecognitionProcess.py Voir le fichier

from enums.AnalysisStatusEnum import AnalysisStatus from enums.AnalysisStatusEnum import AnalysisStatus
from enums.AnalysisTypeEnum import AnalysisType from enums.AnalysisTypeEnum import AnalysisType
from enums.ExceptionEnum import ExceptionType from enums.ExceptionEnum import ExceptionType
from enums.AnalysisLabelEnum import AnalysisLabel, LCAnalysisLabel
from enums.ModelTypeEnum import ModelType from enums.ModelTypeEnum import ModelType
from util import LogUtils, TimeUtils from util import LogUtils, TimeUtils
from util.Cv2Utils import Cv2Util from util.Cv2Utils import Cv2Util
code = model.get("code") code = model.get("code")
needed_objectsIndex = [int(category.get("id")) for category in model.get("categories")] needed_objectsIndex = [int(category.get("id")) for category in model.get("categories")]
if code == ModelType.WATER_SURFACE_MODEL.value[1]: if code == ModelType.WATER_SURFACE_MODEL.value[1]:
return ModelUtils.SZModel(gpuId, needed_objectsIndex), AnalysisLabel
return ModelUtils.SZModel(gpuId, needed_objectsIndex), code
elif code == ModelType.FOREST_FARM_MODEL.value[1]: elif code == ModelType.FOREST_FARM_MODEL.value[1]:
return ModelUtils.LCModel(gpuId, needed_objectsIndex), LCAnalysisLabel
return ModelUtils.LCModel(gpuId, needed_objectsIndex), code
else: else:
logger.error("未匹配到对应的模型") logger.error("未匹配到对应的模型")
raise ServiceException(ExceptionType.AI_MODEL_MATCH_EXCEPTION.value[0], raise ServiceException(ExceptionType.AI_MODEL_MATCH_EXCEPTION.value[0],
try: try:
# 加载模型 # 加载模型
logger.info("开始加载算法模型, requestId: {}", self.msg.get("request_id")) logger.info("开始加载算法模型, requestId: {}", self.msg.get("request_id"))
mod, analyseLable = self.get_model(str(self.gpu_ids[0]), self.msg["models"])
mod, model_type_code = self.get_model(str(self.gpu_ids[0]), self.msg["models"])
logger.info("加载算法模型完成, requestId: {}", self.msg.get("request_id")) logger.info("加载算法模型完成, requestId: {}", self.msg.get("request_id"))
# 定义原视频、AI视频保存名称 # 定义原视频、AI视频保存名称
randomStr = str(uuid.uuid1().hex) randomStr = str(uuid.uuid1().hex)
try: try:
cv2tool.getP().stdin.write(p_result[1].tostring()) cv2tool.getP().stdin.write(p_result[1].tostring())
cv2tool.getOrVideoFile().write(frame) cv2tool.getOrVideoFile().write(frame)
cv2tool.getAiVideoFile().write(p_result[1])
frame_merge = cv2tool.video_merge(copy.deepcopy(frame), copy.deepcopy(p_result[1]))
cv2tool.getAiVideoFile().write(frame_merge)
except Exception as e: except Exception as e:
logger.error("requestId:{}, 写流异常:", self.msg.get("request_id")) logger.error("requestId:{}, 写流异常:", self.msg.get("request_id"))
logger.exception(e) logger.exception(e)
ai_analyse_results = p_result[2] ai_analyse_results = p_result[2]
for ai_analyse_result in ai_analyse_results: for ai_analyse_result in ai_analyse_results:
order = str(int(ai_analyse_result[0])) order = str(int(ai_analyse_result[0]))
label = analyseLable.getLabel(order)
high_result = high_score_image.get(order) high_result = high_score_image.get(order)
conf_c = ai_analyse_result[5] conf_c = ai_analyse_result[5]
if high_result is None: if high_result is None:
"last_frame": current_frame + step, "last_frame": current_frame + step,
"progress": "", "progress": "",
"mode_service": "online", "mode_service": "online",
"question_code": label.value[2],
"question_descrition": label.value[1],
"model_type_code": model_type_code,
"model_detection_code": order,
"socre": conf_c "socre": conf_c
} }
else: else:
"last_frame": current_frame + step, "last_frame": current_frame + step,
"progress": "", "progress": "",
"mode_service": "online", "mode_service": "online",
"question_code": label.value[2],
"question_descrition": label.value[1],
"model_type_code": model_type_code,
"model_detection_code": order,
"socre": conf_c "socre": conf_c
} }
if current_frame % int(self.content["service"]["frame_step"]) == 0 and len(high_score_image) > 0: if current_frame % int(self.content["service"]["frame_step"]) == 0 and len(high_score_image) > 0:
try: try:
# 加载模型 # 加载模型
logger.info("开始加载算法模型, requestId:{}", self.msg.get("request_id")) logger.info("开始加载算法模型, requestId:{}", self.msg.get("request_id"))
mod, analyseLable = self.get_model(str(self.gpu_ids[0]), self.msg["models"])
mod, model_type_code = self.get_model(str(self.gpu_ids[0]), self.msg["models"])
# mod = ModelUtils.SZModel([0,1,2,3]) # mod = ModelUtils.SZModel([0,1,2,3])
logger.info("加载算法模型完成, requestId:{}", self.msg.get("request_id")) logger.info("加载算法模型完成, requestId:{}", self.msg.get("request_id"))
# 定义原视频、AI视频保存名称 # 定义原视频、AI视频保存名称
# logger.info("算法模型调度时间:{}s", int(time11-time00)) # logger.info("算法模型调度时间:{}s", int(time11-time00))
# 原视频保存本地、AI视频保存本地 # 原视频保存本地、AI视频保存本地
try: try:
cv2tool.getAiVideoFile().write(p_result[1])
frame_merge = cv2tool.video_merge(copy.deepcopy(frame), copy.deepcopy(p_result[1]))
cv2tool.getAiVideoFile().write(frame_merge)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
# # 问题图片加入队列, 暂时写死,后期修改为真实问题 # # 问题图片加入队列, 暂时写死,后期修改为真实问题
ai_analyse_results = p_result[2] ai_analyse_results = p_result[2]
for ai_analyse_result in ai_analyse_results: for ai_analyse_result in ai_analyse_results:
order = str(int(ai_analyse_result[0])) order = str(int(ai_analyse_result[0]))
label = analyseLable.getLabel(order)
high_result = high_score_image.get(order) high_result = high_score_image.get(order)
conf_c = ai_analyse_result[5] conf_c = ai_analyse_result[5]
if high_result is None: if high_result is None:
"last_frame": current_frame + step, "last_frame": current_frame + step,
"progress": "", "progress": "",
"mode_service": "offline", "mode_service": "offline",
"question_code": label.value[2],
"question_descrition": label.value[1],
"model_type_code": model_type_code,
"model_detection_code": order,
"socre": conf_c "socre": conf_c
} }
else: else:
"last_frame": current_frame + step, "last_frame": current_frame + step,
"progress": "", "progress": "",
"mode_service": "offline", "mode_service": "offline",
"question_code": label.value[2],
"question_descrition": label.value[1],
"model_type_code": model_type_code,
"model_detection_code": order,
"socre": conf_c "socre": conf_c
} }
if current_frame % int(self.content["service"]["frame_step"]) == 0 and len(high_score_image) > 0: if current_frame % int(self.content["service"]["frame_step"]) == 0 and len(high_score_image) > 0:

+ 37
- 1
dsp_application.yml Voir le fichier

dsp:
active: dev
kafka: kafka:
topic: topic:
dsp-alg-online-tasks-topic: dsp-alg-online-tasks dsp-alg-online-tasks-topic: dsp-alg-online-tasks
dsp-alg-results-topic: dsp-alg-task-results dsp-alg-results-topic: dsp-alg-task-results
dsp-alg-task-results: dsp-alg-task-results:
partition: [0] partition: [0]
active: dev
local: local:
bootstrap_servers: ['192.168.10.11:9092'] bootstrap_servers: ['192.168.10.11:9092']
producer: producer:
enqueue: True enqueue: True
# 编码格式 # 编码格式
encoding: utf8 encoding: utf8
#mysql:
# # 数据库信息
# dev:
# host: 192.168.11.13
# port: 3306
# dbname: tuheng_dsp
# username: root
# password: idontcare
# test:
# host: 192.168.11.242
# port: 3306
# dbname: tuheng_dsp
# username: root
# password: idontcare
# prod:
# host: 172.16.1.22
# port: 3306
# dbname: tuheng_dsp
# username: root
# password: TH22#2022
# db_charset: utf8
# # mincached : 启动时开启的闲置连接数量(缺省值 0 开始时不创建连接)
# db_min_cached: 0
# # maxcached : 连接池中允许的闲置的最多连接数量(缺省值 0 代表不闲置连接池大小)
# db_max_cached: 10
# # maxshared : 共享连接数允许的最大数量(缺省值 0 代表所有连接都是专用的)如果达到了最大数量,被请求为共享的连接将会被共享使用
# db_max_shared: 10
# # maxconnecyions : 创建连接池的最大数量(缺省值 0 代表不限制)
# db_max_connecyions: 20
# # maxusage : 单个连接的最大允许复用次数(缺省值 0 或 False 代表不限制的复用).当达到最大数时,连接会自动重新连接(关闭和重新打开)
# db_blocking: True
# # maxusage : 单个连接的最大允许复用次数(缺省值 0 或 False 代表不限制的复用).当达到最大数时,连接会自动重新连接(关闭和重新打开)
# db_max_usage: 0
# # setsession : 一个可选的SQL命令列表用于准备每个会话,如["set datestyle to german", ...]
# db_set_session: None

+ 3
- 3
entity/FeedBack.py Voir le fichier



def message_feedback(requestId, status, type, error_code="", error_msg="", progress="", original_url="", sign_url="", def message_feedback(requestId, status, type, error_code="", error_msg="", progress="", original_url="", sign_url="",
category_id="", description="", analyse_time=""):
model_type_code="", model_detection_code="", analyse_time=""):
taskfb = {} taskfb = {}
results = [] results = []
result_msg = {} result_msg = {}
taskfb["progress"] = progress taskfb["progress"] = progress
result_msg["original_url"] = original_url result_msg["original_url"] = original_url
result_msg["sign_url"] = sign_url result_msg["sign_url"] = sign_url
result_msg["category_id"] = category_id
result_msg["description"] = description
result_msg["model_type_code"] = model_type_code
result_msg["model_detection_coden"] = model_detection_code
result_msg["analyse_time"] = analyse_time result_msg["analyse_time"] = analyse_time
results.append(result_msg) results.append(result_msg)
taskfb["results"] = results taskfb["results"] = results

+ 0
- 57
enums/AnalysisLabelEnum.py Voir le fichier

from enum import Enum, unique


# 分析状态枚举
@unique
class AnalysisLabel(Enum):
VENT = ("0", "排口", "SL014")

SEWAGE_OUTLET = ("1", "水生植被", "SL013")

OTHER = ("2", "其他", "SL001")

FLOATING_OBJECTS = ("3", "漂浮物", "SL001")

AQUATIC_VEGETATION = ("4", "污染排口", "SL011")

VEGETABLE_FIELD = ("5", "菜地", "SL007")

NON_CONFORMING_BUILDING = ("6", "违建", "SL010")

BANK_SLOPE_GARBAGE = ("7", "岸坡垃圾", "SL009")

def checkLabel(id):
for label in AnalysisLabel:
if label.value[0] == id:
return True
return False

def getLabel(order):
for label in AnalysisLabel:
if label.value[0] == order:
return label
return None


# 林场
@unique
class LCAnalysisLabel(Enum):
PATTERN_SPOT = ("0", "林斑", "LC001")

DEAD_TREE = ("1", "病死树", "LC002")

PERSONNER_ACTIVITIES = ("2", "人员活动", "LC003")

FIRE_IMPLICATION = ("3", "火灾隐含", "LC004")

def checkLabel(id):
for label in LCAnalysisLabel:
if label.value[0] == id:
return True
return False

def getLabel(order):
for label in LCAnalysisLabel:
if label.value[0] == order:
return label
return None

+ 2
- 0
enums/ExceptionEnum.py Voir le fichier



AI_MODEL_MATCH_EXCEPTION = ("SP017", "The AI Model Is Not Matched!") AI_MODEL_MATCH_EXCEPTION = ("SP017", "The AI Model Is Not Matched!")


VIDEO_MERGE_EXCEPTION = ("SP018", "The Video Merge Exception!")

SERVICE_INNER_EXCEPTION = ("SP999", "系统内部异常, 请联系工程师定位处理!") SERVICE_INNER_EXCEPTION = ("SP999", "系统内部异常, 请联系工程师定位处理!")

+ 0
- 16
enums/ModelTypeEnum.py Voir le fichier

from enum import Enum, unique


# 异常枚举
@unique
class ModelType(Enum):

WATER_SURFACE_MODEL = ("1", "DSPSL000", "水面模型")

FOREST_FARM_MODEL = ("2", "DSPLC000", "林场模型")

def checkCode(code):
for model in ModelType:
if model.value[1] == code:
return True
return False

+ 0
- 24
service/Dispatcher.py Voir le fichier

# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import torch
import time import time
import GPUtil import GPUtil
from util import YmlUtils, FileUtils, LogUtils from util import YmlUtils, FileUtils, LogUtils
from loguru import logger from loguru import logger
from multiprocessing import Queue from multiprocessing import Queue
from enums.ModelTypeEnum import ModelType
from enums.AnalysisLabelEnum import AnalysisLabel, LCAnalysisLabel
from concurrency.IntelligentRecognitionProcess import OnlineIntelligentRecognitionProcess, OfflineIntelligentRecognitionProcess from concurrency.IntelligentRecognitionProcess import OnlineIntelligentRecognitionProcess, OfflineIntelligentRecognitionProcess
from concurrency.MessagePollingThread import OfflineMessagePollingThread, OnlineMessagePollingThread from concurrency.MessagePollingThread import OfflineMessagePollingThread, OnlineMessagePollingThread
from util import GPUtils from util import GPUtils

''' '''
分发服务 分发服务
''' '''
for model in models: for model in models:
if model.get("code") is None: if model.get("code") is None:
return False return False
if not ModelType.checkCode(model.get("code")):
return False
if model.get("categories") is None: if model.get("categories") is None:
return False return False
if model.get("code") == ModelType.WATER_SURFACE_MODEL.value[1]:
for category in model.get("categories"):
if not AnalysisLabel.checkLabel(category.get("id")):
return False
if model.get("code") == ModelType.FOREST_FARM_MODEL.value[1]:
for category in model.get("categories"):
if not LCAnalysisLabel.checkLabel(category.get("id")):
return False
if command == "start" and pull_url is None: if command == "start" and pull_url is None:
return False return False
if command == "start" and push_url is None: if command == "start" and push_url is None:
for model in models: for model in models:
if model.get("code") is None: if model.get("code") is None:
return False return False
if not ModelType.checkCode(model.get("code")):
return False
if model.get("categories") is None: if model.get("categories") is None:
return False return False
if model.get("code") == ModelType.WATER_SURFACE_MODEL.value[1]:
for category in model.get("categories"):
if not AnalysisLabel.checkLabel(category.get("id")):
return False
if model.get("code") == ModelType.FOREST_FARM_MODEL.value[1]:
for category in model.get("categories"):
if not LCAnalysisLabel.checkLabel(category.get("id")):
return False
if command == 'start' and original_url is None: if command == 'start' and original_url is None:
return False return False
if command == 'start' and original_type is None: if command == 'start' and original_type is None:

+ 9
- 0
test/mysqltest.py Voir le fichier

from util.MyConnectionPool import MySqLHelper
from util import YmlUtils
import json

if __name__=="__main__":
content = YmlUtils.getConfigs()
sql = MySqLHelper(content)
res = sql.selectall("select id, name, code, description, create_user, create_time, update_user, update_time, mark from dsp_model_classification where mark = %s", 1)
print(res)

+ 8
- 1
util/Cv2Utils.py Voir le fichier

# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import cv2 import cv2
import subprocess as sp import subprocess as sp
import numpy as np
from loguru import logger from loguru import logger
from exception.CustomerException import ServiceException from exception.CustomerException import ServiceException
from enums.ExceptionEnum import ExceptionType from enums.ExceptionEnum import ExceptionType
ExceptionType.OR_VIDEO_ADDRESS_EXCEPTION.value[1]) ExceptionType.OR_VIDEO_ADDRESS_EXCEPTION.value[1])


self.or_video_file = cv2.VideoWriter(self.orFilePath, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, (self.width, self.height)) self.or_video_file = cv2.VideoWriter(self.orFilePath, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, (self.width, self.height))
self.ai_video_file = cv2.VideoWriter(self.aiFilePath, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, (self.width, self.height))
self.ai_video_file = cv2.VideoWriter(self.aiFilePath, cv2.VideoWriter_fourcc(*'mp4v'), self.fps, (self.width, int(self.height/2)))
except ServiceException as s: except ServiceException as s:
raise s raise s
except Exception as e: except Exception as e:
logger.error("初始化管道失败:") logger.error("初始化管道失败:")
logger.exception(e) logger.exception(e)


def video_merge(self, frame1, frame2):
frameLeft = cv2.resize(frame1, (int(self.width / 2), int(self.height / 2)), interpolation=cv2.INTER_LINEAR)
frameRight = cv2.resize(frame2, (int(self.width / 2), int(self.height / 2)), interpolation=cv2.INTER_LINEAR)
frame_merge = np.hstack((frameLeft, frameRight))
return frame_merge

def getP(self): def getP(self):
if self.p is None: if self.p is None:
logger.error("获取管道为空!") logger.error("获取管道为空!")

+ 8
- 8
util/KafkaUtils.py Voir le fichier



def __init__(self, content): def __init__(self, content):
self.content = content self.content = content
configs = self.content["kafka"][self.content["kafka"]["active"]]["producer"]
self.customerProducer = KafkaProducer(bootstrap_servers=self.content["kafka"][self.content["kafka"]["active"]]["bootstrap_servers"],
configs = self.content["kafka"][self.content["dsp"]["active"]]["producer"]
self.customerProducer = KafkaProducer(bootstrap_servers=self.content["kafka"][self.content["dsp"]["active"]]["bootstrap_servers"],
acks=configs["acks"], acks=configs["acks"],
retries=configs["retries"], retries=configs["retries"],
linger_ms=configs["linger_ms"], linger_ms=configs["linger_ms"],
if self.customerProducer: if self.customerProducer:
return self.customerProducer return self.customerProducer
logger.info("配置kafka生产者") logger.info("配置kafka生产者")
configs = self.content["kafka"][self.content["kafka"]["active"]]["producer"]
self.customerProducer = KafkaProducer(bootstrap_servers=self.content["kafka"][self.content["kafka"]["active"]]["bootstrap_servers"],
configs = self.content["kafka"][self.content["dsp"]["active"]]["producer"]
self.customerProducer = KafkaProducer(bootstrap_servers=self.content["kafka"][self.content["dsp"]["active"]]["bootstrap_servers"],
acks=configs["acks"], acks=configs["acks"],
retries=configs["retries"], retries=configs["retries"],
linger_ms=configs["linger_ms"], linger_ms=configs["linger_ms"],
def __init__(self, content): def __init__(self, content):
logger.info("初始化消费者") logger.info("初始化消费者")
self.content = content self.content = content
configs = self.content["kafka"][self.content["kafka"]["active"]]["consumer"]
self.customerConsumer = KafkaConsumer(bootstrap_servers=self.content["kafka"][self.content["kafka"]["active"]]["bootstrap_servers"],
configs = self.content["kafka"][self.content["dsp"]["active"]]["consumer"]
self.customerConsumer = KafkaConsumer(bootstrap_servers=self.content["kafka"][self.content["dsp"]["active"]]["bootstrap_servers"],
client_id=configs["client_id"], client_id=configs["client_id"],
group_id=configs["group_id"], group_id=configs["group_id"],
auto_offset_reset=configs["auto_offset_reset"], auto_offset_reset=configs["auto_offset_reset"],
if self.customerConsumer: if self.customerConsumer:
logger.info("获取消费者成功!") logger.info("获取消费者成功!")
return self.customerConsumer return self.customerConsumer
configs = self.content["kafka"][self.content["kafka"]["active"]]["consumer"]
self.customerConsumer = KafkaConsumer(bootstrap_servers=self.content["kafka"][self.content["kafka"]["active"]]["bootstrap_servers"],
configs = self.content["kafka"][self.content["dsp"]["active"]]["consumer"]
self.customerConsumer = KafkaConsumer(bootstrap_servers=self.content["kafka"][self.content["dsp"]["active"]]["bootstrap_servers"],
client_id=configs["client_id"], client_id=configs["client_id"],
group_id=configs["group_id"], group_id=configs["group_id"],
auto_offset_reset=configs["auto_offset_reset"], auto_offset_reset=configs["auto_offset_reset"],

+ 227
- 0
util/MyConnectionPool.py Voir le fichier

# -*- coding: UTF-8 -*-
import pymysql
from loguru import logger
from dbutils.pooled_db import PooledDB


"""
@功能:创建数据库连接池
"""


class MyConnectionPool(object):
__pool = None

def __init__(self, content):
self.conn = self.__getConn(content)
self.cursor = self.conn.cursor()

# 创建数据库连接conn和游标cursor
# def __enter__(self):
# self.conn = self.__getconn()
# self.cursor = self.conn.cursor()

# 创建数据库连接池
def __getconn(self, content):
if self.__pool is None:
self.__pool = PooledDB(
creator=pymysql,
mincached=int(content["mysql"]["db_min_cached"]),
maxcached=int(content["mysql"]["db_max_cached"]),
maxshared=int(content["mysql"]["db_max_shared"]),
maxconnections=int(content["mysql"]["db_max_connecyions"]),
blocking=content["mysql"]["db_blocking"],
maxusage=content["mysql"]["db_max_usage"],
setsession=content["mysql"]["db_set_session"],
host=content["mysql"][content["dsp"]["active"]]["host"],
port=content["mysql"][content["dsp"]["active"]]["port"],
user=content["mysql"][content["dsp"]["active"]]["username"],
passwd=content["mysql"][content["dsp"]["active"]]["password"],
db=content["mysql"][content["dsp"]["active"]]["dbname"],
use_unicode=False,
charset=content["mysql"]["db_charset"]
)
return self.__pool.connection()

# 释放连接池资源
# def __exit__(self, exc_type, exc_val, exc_tb):
# self.cursor.close()
# self.conn.close()

# 关闭连接归还给链接池
def close(self):
self.cursor.close()
self.conn.close()

# 从连接池中取出一个连接
def getconn(self, content):
conn = self.__getconn(content)
cursor = conn.cursor()
return cursor, conn


# 获取连接池,实例化
def get_my_connection(content):
return MyConnectionPool(content)


'''
执行语句查询有结果返回结果没有返回0;增/删/改返回变更数据条数,没有返回0
'''


class MySqLHelper(object):
def __init__(self, content):
logger.info("开始加载数据库连接池!")
self.db = get_my_connection(content)
logger.info("加载数据库连接池完成!")

def __new__(cls, *args, **kwargs):
if not hasattr(cls, 'inst'): # 单例
cls.inst = super(MySqLHelper, cls).__new__(cls, *args, **kwargs)
return cls.inst

# 封装执行命令
def execute(self, sql, param=None, autoclose=False):
"""
【主要判断是否有参数和是否执行完就释放连接】
:param sql: 字符串类型,sql语句
:param param: sql语句中要替换的参数"select %s from tab where id=%s" 其中的%s就是参数
:param autoclose: 是否关闭连接
:return: 返回连接conn和游标cursor
"""
cursor, conn = self.db.getconn() # 从连接池获取连接
count = 0
try:
# count : 为改变的数据条数
if param:
count = cursor.execute(sql, param)
else:
count = cursor.execute(sql)
conn.commit()
if autoclose:
self.close(cursor, conn)
except Exception as e:
pass
return cursor, conn, count

# 执行多条命令
# def executemany(self, lis):
# """
# :param lis: 是一个列表,里面放的是每个sql的字典'[{"sql":"xxx","param":"xx"}....]'
# :return:
# """
# cursor, conn = self.db.getconn()
# try:
# for order in lis:
# sql = order['sql']
# param = order['param']
# if param:
# cursor.execute(sql, param)
# else:
# cursor.execute(sql)
# conn.commit()
# self.close(cursor, conn)
# return True
# except Exception as e:
# print(e)
# conn.rollback()
# self.close(cursor, conn)
# return False

# 释放连接
def close(self, cursor, conn):
logger.info("开始释放数据库连接!")
cursor.close()
conn.close()
logger.info("释放数据库连接完成!")

# 查询所有
def selectall(self, sql, param=None):
try:
cursor, conn, count = self.execute(sql, param)
res = cursor.fetchall()
return res
except Exception as e:
logger.error("查询所有数据异常:")
logger.exception(e)
self.close(cursor, conn)
return count

# 查询单条
def selectone(self, sql, param=None):
try:
cursor, conn, count = self.execute(sql, param)
res = cursor.fetchone()
self.close(cursor, conn)
return res
except Exception as e:
logger.error("查询单条数据异常:")
logger.exception(e)
self.close(cursor, conn)
return count

# 增加
def insertone(self, sql, param):
try:
cursor, conn, count = self.execute(sql, param)
# _id = cursor.lastrowid() # 获取当前插入数据的主键id,该id应该为自动生成为好
conn.commit()
self.close(cursor, conn)
return count
# 防止表中没有id返回0
# if _id == 0:
# return True
# return _id
except Exception as e:
logger.error("新增数据异常:")
logger.exception(e)
conn.rollback()
self.close(cursor, conn)
return count

# 增加多行
def insertmany(self, sql, param):
"""
:param sql:
:param param: 必须是元组或列表[(),()]或((),())
:return:
"""
cursor, conn, count = self.db.getconn()
try:
cursor.executemany(sql, param)
conn.commit()
return count
except Exception as e:
logger.error("增加多条数据异常:")
logger.exception(e)
conn.rollback()
self.close(cursor, conn)
return count

# 删除
def delete(self, sql, param=None):
try:
cursor, conn, count = self.execute(sql, param)
self.close(cursor, conn)
return count
except Exception as e:
logger.error("删除数据异常:")
logger.exception(e)
conn.rollback()
self.close(cursor, conn)
return count

# 更新
def update(self, sql, param=None):
try:
cursor, conn, count = self.execute(sql, param)
conn.commit()
self.close(cursor, conn)
return count
except Exception as e:
logger.error("更新数据异常:")
logger.exception(e)
conn.rollback()
self.close(cursor, conn)
return count

+ 2
- 1
util/YmlUtils.py Voir le fichier

# 从配置文件读取所有配置信息 # 从配置文件读取所有配置信息
def getConfigs(): def getConfigs():
print("开始读取配置文件,获取配置消息:", Constant.APPLICATION_CONFIG) print("开始读取配置文件,获取配置消息:", Constant.APPLICATION_CONFIG)
applicationConfigPath = os.path.abspath(Constant.APPLICATION_CONFIG)
applicationConfigPath = "../dsp_application.yml"
print(applicationConfigPath)
if not os.path.exists(applicationConfigPath): if not os.path.exists(applicationConfigPath):
raise Exception("未找到配置文件:{}".format(Constant.APPLICATION_CONFIG)) raise Exception("未找到配置文件:{}".format(Constant.APPLICATION_CONFIG))
with open(applicationConfigPath, Constant.R, encoding=Constant.UTF_8) as f: with open(applicationConfigPath, Constant.R, encoding=Constant.UTF_8) as f:

Chargement…
Annuler
Enregistrer