You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

293 lines
11KB

  1. import copy
  2. from enum import Enum, unique
  3. from PIL import Image
  4. import sys
  5. sys.path.extend(['..','../AIlib' ])
  6. from AI import AI_process, AI_process_forest, get_postProcess_para
  7. import cv2,os,time
  8. from segutils.segmodel import SegModel
  9. from models.experimental import attempt_load
  10. from utils.torch_utils import select_device
  11. from utilsK.queRiver import get_labelnames,get_label_arrays
  12. import numpy as np
  13. # 异常枚举
  14. @unique
  15. class ModelType(Enum):
  16. WATER_SURFACE_MODEL = ("1", "001", "水面模型")
  17. FOREST_FARM_MODEL = ("2", "002", "森林模型")
  18. TRAFFIC_FARM_MODEL = ("3", "003", "交通模型")
  19. def checkCode(code):
  20. for model in ModelType:
  21. if model.value[1] == code:
  22. return True
  23. return False
  24. class ModelConfig():
  25. def __init__(self):
  26. postFile = '../AIlib/conf/para.json'
  27. self.conf_thres, self.iou_thres, self.classes, self.rainbows = get_postProcess_para(postFile)
  28. class SZModelConfig(ModelConfig):
  29. def __init__(self):
  30. super(SZModelConfig, self).__init__()
  31. labelnames = "../AIlib/weights/yolov5/class8/labelnames.json" ##对应类别表
  32. self.names = get_labelnames(labelnames)
  33. self.label_arraylist = get_label_arrays(self.names, self.rainbows, outfontsize=40,
  34. fontpath="../AIlib/conf/platech.ttf")
  35. class LCModelConfig(ModelConfig):
  36. def __init__(self):
  37. super(LCModelConfig, self).__init__()
  38. labelnames = "../AIlib/weights/forest/labelnames.json"
  39. self.names = get_labelnames(labelnames)
  40. self.label_arraylist = get_label_arrays(self.names, self.rainbows, outfontsize=40, fontpath="../AIlib/conf/platech.ttf")
  41. class RFModelConfig(ModelConfig):
  42. def __init__(self):
  43. super(RFModelConfig, self).__init__()
  44. labelnames = "../AIlib/weights/road/labelnames.json"
  45. self.names = get_labelnames(labelnames)
  46. imageW = 1536
  47. outfontsize=int(imageW/1920*40)
  48. self.label_arraylist = get_label_arrays(self.names, self.rainbows, outfontsize=outfontsize, fontpath="../AIlib/conf/platech.ttf")
  49. class Model():
  50. def __init__(self, device, allowedList=None):
  51. ##预先设置的参数
  52. self.device_ = device ##选定模型,可选 cpu,'0','1'
  53. self.allowedList = allowedList
  54. # 水面模型
  55. class SZModel(Model):
  56. def __init__(self, device, allowedList=None):
  57. super().__init__(device, allowedList)
  58. self.device = select_device(self.device_)
  59. self.half = self.device.type != 'cpu'
  60. self.model = attempt_load("../AIlib/weights/yolov5/class8/bestcao.pt", map_location=self.device)
  61. if self.half:
  62. self.model.half()
  63. self.segmodel = SegModel(nclass=2, weights='../AIlib/weights/STDC/model_maxmIOU75_1720_0.946_360640.pth',
  64. device=self.device)
  65. # names, label_arraylist, rainbows, conf_thres, iou_thres
  66. def process(self, frame, config):
  67. return AI_process([frame], self.model, self.segmodel, config[0], config[1],
  68. config[2], self.half, self.device, config[3], config[4],
  69. self.allowedList)
  70. # 森林模型
  71. class LCModel(Model):
  72. def __init__(self, device, allowedList=None):
  73. super().__init__(device, allowedList)
  74. self.device = select_device(self.device_)
  75. self.half = self.device.type != 'cpu' # half precision only supported on CUDA
  76. self.model = attempt_load("../AIlib/weights/forest/best.pt", map_location=self.device) # load FP32 model
  77. if self.half:
  78. self.model.half()
  79. self.segmodel = None
  80. # names, label_arraylist, rainbows, conf_thres, iou_thres
  81. def process(self, frame, config):
  82. return AI_process_forest([frame], self.model, self.segmodel, config[0], config[1], config[2],
  83. self.half, self.device, config[3], config[4], self.allowedList)
  84. # 交通模型
  85. class RFModel(Model):
  86. def __init__(self, device, allowedList=None):
  87. super().__init__(device, allowedList)
  88. self.device = select_device(self.device_)
  89. self.half = self.device.type != 'cpu' # half precision only supported on CUDA
  90. self.model = attempt_load("../AIlib/weights/road/best.pt", map_location=self.device) # load FP32 model
  91. if self.half:
  92. self.model.half()
  93. self.segmodel = None
  94. # names, label_arraylist, rainbows, conf_thres, iou_thres
  95. def process(self, frame, config):
  96. return AI_process_forest([frame], self.model, self.segmodel, config[0], config[1], config[2],
  97. self.half, self.device, config[3], config[4], self.allowedList)
  98. def get_model(args):
  99. for model in args[2]:
  100. try:
  101. code = '001'
  102. needed_objectsIndex = [int(category.get("id")) for category in model.get("categories")]
  103. if code == ModelType.WATER_SURFACE_MODEL.value[1]:
  104. return SZModel(args[1], needed_objectsIndex), code, args[0].get("sz")
  105. elif code == ModelType.FOREST_FARM_MODEL.value[1]:
  106. return LCModel(args[1], needed_objectsIndex), code, args[0].get("lc")
  107. elif code == ModelType.TRAFFIC_FARM_MODEL.value[1]:
  108. return RFModel(args[1], needed_objectsIndex), code, args[0].get("rf")
  109. else:
  110. raise Exception("11111")
  111. except Exception as e:
  112. raise Exception("22222")
  113. class PictureWaterMark():
  114. def common_water(self, image, logo):
  115. width, height = image.shape[1], image.shape[0]
  116. mark_width, mark_height = logo.shape[1], logo.shape[0]
  117. rate = int(width * 0.2) / mark_width
  118. logo_new = cv2.resize(logo, None, fx=rate, fy=rate, interpolation=cv2.INTER_NEAREST)
  119. position = (int(width * 0.95 - logo_new.shape[1]), int(height * 0.95 - logo_new.shape[0]))
  120. b = Image.new('RGBA', (width, height), (0, 0, 0, 0)) # 创建新图像:透明'
  121. a = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
  122. watermark = Image.fromarray(cv2.cvtColor(logo_new, cv2.COLOR_BGRA2RGBA))
  123. # 图片旋转
  124. # watermark = watermark.rotate(45)
  125. b.paste(a, (0, 0))
  126. b.paste(watermark, position, mask=watermark)
  127. return cv2.cvtColor(np.asarray(b), cv2.COLOR_BGR2RGB)
  128. def common_water_1(self, image, logo, alpha=1):
  129. h, w = image.shape[0], image.shape[1]
  130. if w >= h:
  131. rate = int(w * 0.1) / logo.shape[1]
  132. else:
  133. rate = int(h * 0.1) / logo.shape[0]
  134. mask = cv2.resize(logo, None, fx=rate, fy=rate, interpolation=cv2.INTER_NEAREST)
  135. mask_h, mask_w = mask.shape[0], mask.shape[1]
  136. mask_channels = cv2.split(mask)
  137. dst_channels = cv2.split(image)
  138. # b, g, r, a = cv2.split(mask)
  139. # 计算mask在图片的坐标
  140. ul_points = (int(h * 0.95) - mask_h, int(w - h * 0.05 - mask_w))
  141. dr_points = (int(h * 0.95), int(w - h * 0.05))
  142. for i in range(3):
  143. dst_channels[i][ul_points[0]: dr_points[0], ul_points[1]: dr_points[1]] = dst_channels[i][
  144. ul_points[0]: dr_points[0],
  145. ul_points[1]: dr_points[1]] * (
  146. 255.0 - mask_channels[3] * alpha) / 255
  147. dst_channels[i][ul_points[0]: dr_points[0], ul_points[1]: dr_points[1]] += np.array(
  148. mask_channels[i] * (mask_channels[3] * alpha / 255), dtype=np.uint8)
  149. dst_img = cv2.merge(dst_channels)
  150. return dst_img
  151. def video_merge(frame1, frame2, width, height):
  152. frameLeft = cv2.resize(frame1, (width, height), interpolation=cv2.INTER_LINEAR)
  153. frameRight = cv2.resize(frame2, (width, height), interpolation=cv2.INTER_LINEAR)
  154. frame_merge = np.hstack((frameLeft, frameRight))
  155. # frame_merge = np.hstack((frame1, frame2))
  156. return frame_merge
  157. cap = cv2.VideoCapture("/home/DATA/chenyukun/3.mp4")
  158. # Get video information
  159. fps = int(cap.get(cv2.CAP_PROP_FPS))
  160. print(fps)
  161. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  162. print(width)
  163. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  164. print(height)
  165. # command = ['ffmpeg',
  166. # '-y', # 不经过确认,输出时直接覆盖同名文件。
  167. # '-f', 'rawvideo',
  168. # '-vcodec', 'rawvideo',
  169. # '-pix_fmt', 'bgr24',
  170. # # '-s', "{}x{}".format(self.width * 2, self.height),
  171. # '-s', "{}x{}".format(width, height),
  172. # '-r', str(15),
  173. # '-i', '-', # 指定输入文件
  174. # '-g', '15',
  175. # '-sc_threshold', '0', # 使得GOP的插入更加均匀
  176. # '-b:v', '3000k', # 指定码率
  177. # '-tune', 'zerolatency', # 加速编码速度
  178. # '-c:v', 'libx264', # 指定视频编码器
  179. # '-pix_fmt', 'yuv420p',
  180. # "-an",
  181. # '-preset', 'ultrafast', # 指定输出的视频质量,会影响文件的生成速度,有以下几个可用的值 ultrafast,
  182. # # superfast, veryfast, faster, fast, medium, slow, slower, veryslow。
  183. # '-f', 'flv',
  184. # "rtmp://live.push.t-aaron.com/live/THSAk"]
  185. #
  186. # # 管道配置
  187. # p = sp.Popen(command, stdin=sp.PIPE, shell=False)
  188. sz = SZModelConfig()
  189. lc = LCModelConfig()
  190. rf = RFModelConfig()
  191. config = {
  192. "sz": (sz.names, sz.label_arraylist, sz.rainbows, sz.conf_thres, sz.iou_thres),
  193. "lc": (lc.names, lc.label_arraylist, lc.rainbows, lc.conf_thres, lc.iou_thres),
  194. "rf": (rf.names, rf.label_arraylist, rf.rainbows, rf.conf_thres, rf.iou_thres),
  195. }
  196. model = {
  197. "models": [
  198. {
  199. "code": "001",
  200. "categories": [
  201. {
  202. "id": "0",
  203. "config": {}
  204. },
  205. {
  206. "id": "1",
  207. "config": {}
  208. },
  209. {
  210. "id": "2",
  211. "config": {}
  212. },
  213. {
  214. "id": "3",
  215. "config": {}
  216. },
  217. {
  218. "id": "4",
  219. "config": {}
  220. },
  221. {
  222. "id": "5",
  223. "config": {}
  224. },
  225. {
  226. "id": "6",
  227. "config": {}
  228. },
  229. {
  230. "id": "7",
  231. "config": {}
  232. }
  233. ]
  234. }]
  235. }
  236. mod, model_type_code, modelConfig = get_model((config, str(1), model.get("models")))
  237. pic = PictureWaterMark()
  238. logo = cv2.imread("./image/logo.png", -1)
  239. ai_video_file = cv2.VideoWriter("/home/DATA/chenyukun/aa/1.mp4", cv2.VideoWriter_fourcc(*'mp4v'), fps, (width*2, height))
  240. while(cap.isOpened()):
  241. start =time.time()
  242. ret, frame = cap.read()
  243. # cap.grab()
  244. if not ret:
  245. print("Opening camera is failed")
  246. break
  247. p_result, timeOut = mod.process(copy.deepcopy(frame), modelConfig)
  248. frame = pic.common_water_1(frame, logo)
  249. p_result[1] = pic.common_water_1(p_result[1], logo)
  250. frame_merge = video_merge(frame, p_result[1], width, height)
  251. ai_video_file.write(frame_merge)
  252. print(time.time()-start)
  253. ai_video_file.release()
  254. cap.release()