You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

298 lines
11KB

  1. import copy
  2. import subprocess as sp
  3. from enum import Enum, unique
  4. from PIL import Image
  5. import time
  6. import cv2
  7. import sys
  8. sys.path.extend(['..','../AIlib' ])
  9. from AI import AI_process, AI_process_forest, get_postProcess_para
  10. import cv2,os,time
  11. from segutils.segmodel import SegModel
  12. from models.experimental import attempt_load
  13. from utils.torch_utils import select_device
  14. from utilsK.queRiver import get_labelnames,get_label_arrays
  15. import numpy as np
  16. import torch
  17. from utilsK.masterUtils import get_needed_objectsIndex
  18. # 异常枚举
  19. @unique
  20. class ModelType(Enum):
  21. WATER_SURFACE_MODEL = ("1", "001", "水面模型")
  22. FOREST_FARM_MODEL = ("2", "002", "森林模型")
  23. TRAFFIC_FARM_MODEL = ("3", "003", "交通模型")
  24. def checkCode(code):
  25. for model in ModelType:
  26. if model.value[1] == code:
  27. return True
  28. return False
  29. class ModelConfig():
  30. def __init__(self):
  31. postFile = '../AIlib/conf/para.json'
  32. self.conf_thres, self.iou_thres, self.classes, self.rainbows = get_postProcess_para(postFile)
  33. class SZModelConfig(ModelConfig):
  34. def __init__(self):
  35. super(SZModelConfig, self).__init__()
  36. labelnames = "../AIlib/weights/yolov5/class8/labelnames.json" ##对应类别表
  37. self.names = get_labelnames(labelnames)
  38. self.label_arraylist = get_label_arrays(self.names, self.rainbows, outfontsize=40,
  39. fontpath="../AIlib/conf/platech.ttf")
  40. class LCModelConfig(ModelConfig):
  41. def __init__(self):
  42. super(LCModelConfig, self).__init__()
  43. labelnames = "../AIlib/weights/forest/labelnames.json"
  44. self.names = get_labelnames(labelnames)
  45. self.label_arraylist = get_label_arrays(self.names, self.rainbows, outfontsize=40, fontpath="../AIlib/conf/platech.ttf")
  46. class RFModelConfig(ModelConfig):
  47. def __init__(self):
  48. super(RFModelConfig, self).__init__()
  49. labelnames = "../AIlib/weights/road/labelnames.json"
  50. self.names = get_labelnames(labelnames)
  51. imageW = 1536
  52. outfontsize=int(imageW/1920*40)
  53. self.label_arraylist = get_label_arrays(self.names, self.rainbows, outfontsize=outfontsize, fontpath="../AIlib/conf/platech.ttf")
  54. class Model():
  55. def __init__(self, device, allowedList=None):
  56. ##预先设置的参数
  57. self.device_ = device ##选定模型,可选 cpu,'0','1'
  58. self.allowedList = allowedList
  59. # 水面模型
  60. class SZModel(Model):
  61. def __init__(self, device, allowedList=None):
  62. super().__init__(device, allowedList)
  63. self.device = select_device(self.device_)
  64. self.half = self.device.type != 'cpu'
  65. self.model = attempt_load("../AIlib/weights/yolov5/class8/bestcao.pt", map_location=self.device)
  66. if self.half:
  67. self.model.half()
  68. self.segmodel = SegModel(nclass=2, weights='../AIlib/weights/STDC/model_maxmIOU75_1720_0.946_360640.pth',
  69. device=self.device)
  70. # names, label_arraylist, rainbows, conf_thres, iou_thres
  71. def process(self, frame, config):
  72. return AI_process([frame], self.model, self.segmodel, config[0], config[1],
  73. config[2], self.half, self.device, config[3], config[4],
  74. self.allowedList)
  75. # 森林模型
  76. class LCModel(Model):
  77. def __init__(self, device, allowedList=None):
  78. super().__init__(device, allowedList)
  79. self.device = select_device(self.device_)
  80. self.half = self.device.type != 'cpu' # half precision only supported on CUDA
  81. self.model = attempt_load("../AIlib/weights/forest/best.pt", map_location=self.device) # load FP32 model
  82. if self.half:
  83. self.model.half()
  84. self.segmodel = None
  85. # names, label_arraylist, rainbows, conf_thres, iou_thres
  86. def process(self, frame, config):
  87. return AI_process_forest([frame], self.model, self.segmodel, config[0], config[1], config[2],
  88. self.half, self.device, config[3], config[4], self.allowedList)
  89. # 交通模型
  90. class RFModel(Model):
  91. def __init__(self, device, allowedList=None):
  92. super().__init__(device, allowedList)
  93. self.device = select_device(self.device_)
  94. self.half = self.device.type != 'cpu' # half precision only supported on CUDA
  95. self.model = attempt_load("../AIlib/weights/road/best.pt", map_location=self.device) # load FP32 model
  96. if self.half:
  97. self.model.half()
  98. self.segmodel = None
  99. # names, label_arraylist, rainbows, conf_thres, iou_thres
  100. def process(self, frame, config):
  101. return AI_process_forest([frame], self.model, self.segmodel, config[0], config[1], config[2],
  102. self.half, self.device, config[3], config[4], self.allowedList)
  103. def get_model(args):
  104. for model in args[2]:
  105. try:
  106. code = '001'
  107. needed_objectsIndex = [int(category.get("id")) for category in model.get("categories")]
  108. if code == ModelType.WATER_SURFACE_MODEL.value[1]:
  109. return SZModel(args[1], needed_objectsIndex), code, args[0].get("sz")
  110. elif code == ModelType.FOREST_FARM_MODEL.value[1]:
  111. return LCModel(args[1], needed_objectsIndex), code, args[0].get("lc")
  112. elif code == ModelType.TRAFFIC_FARM_MODEL.value[1]:
  113. return RFModel(args[1], needed_objectsIndex), code, args[0].get("rf")
  114. else:
  115. raise Exception("11111")
  116. except Exception as e:
  117. raise Exception("22222")
  118. class PictureWaterMark():
  119. def common_water(self, image, logo):
  120. width, height = image.shape[1], image.shape[0]
  121. mark_width, mark_height = logo.shape[1], logo.shape[0]
  122. rate = int(width * 0.2) / mark_width
  123. logo_new = cv2.resize(logo, None, fx=rate, fy=rate, interpolation=cv2.INTER_NEAREST)
  124. position = (int(width * 0.95 - logo_new.shape[1]), int(height * 0.95 - logo_new.shape[0]))
  125. b = Image.new('RGBA', (width, height), (0, 0, 0, 0)) # 创建新图像:透明'
  126. a = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
  127. watermark = Image.fromarray(cv2.cvtColor(logo_new, cv2.COLOR_BGRA2RGBA))
  128. # 图片旋转
  129. # watermark = watermark.rotate(45)
  130. b.paste(a, (0, 0))
  131. b.paste(watermark, position, mask=watermark)
  132. return cv2.cvtColor(np.asarray(b), cv2.COLOR_BGR2RGB)
  133. def common_water_1(self, image, logo, alpha=1):
  134. h, w = image.shape[0], image.shape[1]
  135. if w >= h:
  136. rate = int(w * 0.1) / logo.shape[1]
  137. else:
  138. rate = int(h * 0.1) / logo.shape[0]
  139. mask = cv2.resize(logo, None, fx=rate, fy=rate, interpolation=cv2.INTER_NEAREST)
  140. mask_h, mask_w = mask.shape[0], mask.shape[1]
  141. mask_channels = cv2.split(mask)
  142. dst_channels = cv2.split(image)
  143. # b, g, r, a = cv2.split(mask)
  144. # 计算mask在图片的坐标
  145. ul_points = (int(h * 0.95) - mask_h, int(w - h * 0.05 - mask_w))
  146. dr_points = (int(h * 0.95), int(w - h * 0.05))
  147. for i in range(3):
  148. dst_channels[i][ul_points[0]: dr_points[0], ul_points[1]: dr_points[1]] = dst_channels[i][
  149. ul_points[0]: dr_points[0],
  150. ul_points[1]: dr_points[1]] * (
  151. 255.0 - mask_channels[3] * alpha) / 255
  152. dst_channels[i][ul_points[0]: dr_points[0], ul_points[1]: dr_points[1]] += np.array(
  153. mask_channels[i] * (mask_channels[3] * alpha / 255), dtype=np.uint8)
  154. dst_img = cv2.merge(dst_channels)
  155. return dst_img
  156. def video_merge(frame1, frame2, width, height):
  157. frameLeft = cv2.resize(frame1, (width, height), interpolation=cv2.INTER_LINEAR)
  158. frameRight = cv2.resize(frame2, (width, height), interpolation=cv2.INTER_LINEAR)
  159. frame_merge = np.hstack((frameLeft, frameRight))
  160. # frame_merge = np.hstack((frame1, frame2))
  161. return frame_merge
  162. cap = cv2.VideoCapture("/home/DATA/chenyukun/4.mp4")
  163. # Get video information
  164. fps = int(cap.get(cv2.CAP_PROP_FPS))
  165. print(fps)
  166. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  167. print(width)
  168. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  169. print(height)
  170. # command = ['ffmpeg',
  171. # '-y', # 不经过确认,输出时直接覆盖同名文件。
  172. # '-f', 'rawvideo',
  173. # '-vcodec', 'rawvideo',
  174. # '-pix_fmt', 'bgr24',
  175. # # '-s', "{}x{}".format(self.width * 2, self.height),
  176. # '-s', "{}x{}".format(width, height),
  177. # '-r', str(15),
  178. # '-i', '-', # 指定输入文件
  179. # '-g', '15',
  180. # '-sc_threshold', '0', # 使得GOP的插入更加均匀
  181. # '-b:v', '3000k', # 指定码率
  182. # '-tune', 'zerolatency', # 加速编码速度
  183. # '-c:v', 'libx264', # 指定视频编码器
  184. # '-pix_fmt', 'yuv420p',
  185. # "-an",
  186. # '-preset', 'ultrafast', # 指定输出的视频质量,会影响文件的生成速度,有以下几个可用的值 ultrafast,
  187. # # superfast, veryfast, faster, fast, medium, slow, slower, veryslow。
  188. # '-f', 'flv',
  189. # "rtmp://live.push.t-aaron.com/live/THSAk"]
  190. #
  191. # # 管道配置
  192. # p = sp.Popen(command, stdin=sp.PIPE, shell=False)
  193. sz = SZModelConfig()
  194. lc = LCModelConfig()
  195. rf = RFModelConfig()
  196. config = {
  197. "sz": (sz.names, sz.label_arraylist, sz.rainbows, sz.conf_thres, sz.iou_thres),
  198. "lc": (lc.names, lc.label_arraylist, lc.rainbows, lc.conf_thres, lc.iou_thres),
  199. "rf": (rf.names, rf.label_arraylist, rf.rainbows, rf.conf_thres, rf.iou_thres),
  200. }
  201. model = {
  202. "models": [
  203. {
  204. "code": "001",
  205. "categories": [
  206. {
  207. "id": "0",
  208. "config": {}
  209. },
  210. {
  211. "id": "1",
  212. "config": {}
  213. },
  214. {
  215. "id": "2",
  216. "config": {}
  217. },
  218. {
  219. "id": "3",
  220. "config": {}
  221. },
  222. {
  223. "id": "4",
  224. "config": {}
  225. },
  226. {
  227. "id": "5",
  228. "config": {}
  229. },
  230. {
  231. "id": "6",
  232. "config": {}
  233. },
  234. {
  235. "id": "7",
  236. "config": {}
  237. }
  238. ]
  239. }]
  240. }
  241. mod, model_type_code, modelConfig = get_model((config, str(0), model.get("models")))
  242. pic = PictureWaterMark()
  243. logo = cv2.imread("./image/logo.png", -1)
  244. ai_video_file = cv2.VideoWriter("/home/DATA/chenyukun/aa/2.mp4", cv2.VideoWriter_fourcc(*'mp4v'), fps, (width*2, height))
  245. while(cap.isOpened()):
  246. start =time.time()
  247. ret, frame = cap.read()
  248. # cap.grab()
  249. if not ret:
  250. print("Opening camera is failed")
  251. break
  252. p_result, timeOut = mod.process(copy.deepcopy(frame), modelConfig)
  253. frame = pic.common_water_1(frame, logo)
  254. p_result[1] = pic.common_water_1(p_result[1], logo)
  255. frame_merge = video_merge(frame, p_result[1], width, height)
  256. ai_video_file.write(frame_merge)
  257. print(time.time()-start)
  258. ai_video_file.release()
  259. cap.release()