# -*- coding: utf-8 -*-
# YOLOv5-based detection utilities: color palette, box drawing, Detector wrapper.
import torch
|
|
import numpy as np
|
|
from models.experimental import attempt_load
|
|
from utils.general import non_max_suppression, scale_coords
|
|
from utils.BaseDetector import baseDet
|
|
from utils.torch_utils import select_device
|
|
from utils.datasets import letterbox
|
|
import random
|
|
import cv2
|
|
|
|
class Colors:
    """Ultralytics color palette (https://ultralytics.com/).

    Holds a fixed 20-color RGB palette. Calling the instance with an index
    returns that index's color, wrapping modulo the palette size.
    """

    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        # NOTE: renamed local from `hex` to `hexs` to avoid shadowing the builtin hex().
        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A',
                '92CC17', '3DDB86', '1A9334', '00D4BB', '2C99A8', '00C2FF',
                '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF',
                'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb('#' + c) for c in hexs]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        """Return the palette color for index *i* as an (R, G, B) tuple,
        or (B, G, R) when *bgr* is True (OpenCV ordering). Wraps modulo n."""
        c = self.palette[int(i) % self.n]
        return (c[2], c[1], c[0]) if bgr else c

    @staticmethod
    def hex2rgb(h):
        """Convert a '#RRGGBB' hex string to an (R, G, B) tuple of ints."""
        return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
|
|
|
|
|
|
# Module-level singleton so callers can look up per-class colors via colors(i).
colors = Colors()  # create instance for 'from utils.plots import colors'
|
|
def plot_one_box(x, img, color=None, label=None, line_thickness=3):
    """Draw a single bounding box (and optional label) on img in place.

    x: box coordinates (x1, y1, x2, y2); img: image array (modified in place);
    color: color triple (random if not given); label: caption text;
    line_thickness: line/font thickness (derived from image size when falsy).
    """
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        text_w, text_h = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        # filled background for the caption, anchored above the box's top-left corner
        bg_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
        cv2.rectangle(img, top_left, bg_corner, color, -1, cv2.LINE_AA)
        cv2.putText(img, label, (top_left[0], top_left[1] - 2), 0, tl / 3, [0, 0, 0],
                    thickness=tf, lineType=cv2.LINE_AA)
|
|
|
|
class Detector(baseDet):
    """YOLOv5 detector wrapper: loads weights, preprocesses frames, runs NMS,
    and returns per-frame bounding boxes."""

    def __init__(self, dete_weights):
        super().__init__()
        self.init_model(dete_weights)
        self.build_config()  # provided by baseDet; sets self.img_size / self.threshold

    def init_model(self, dete_weights):
        """Load the YOLOv5 model from *dete_weights* onto the best available device.

        BUGFIX: half precision is now enabled only on CUDA. The original code
        called model.half() unconditionally, but fp16 convolution kernels are
        not implemented on CPU, so CPU inference would fail.
        """
        # self.weights = 'weights/best_luoshui20230608.pt'
        self.weights = dete_weights
        self.device = select_device('0' if torch.cuda.is_available() else 'cpu')
        self.half_precision = self.device.type != 'cpu'  # fp16 only on CUDA

        model = attempt_load(self.weights, map_location=self.device)
        model.to(self.device).eval()
        if self.half_precision:
            model.half()
        # torch.save(model, 'test.pt')
        self.m = model
        # DataParallel-wrapped models keep class names on .module
        self.names = model.module.names if hasattr(model, 'module') else model.names

    def preprocess(self, img):
        """Convert a BGR HWC frame into a normalized model input tensor.

        Returns (img0, tensor): img0 is an untouched copy of the input frame;
        tensor is the letterboxed, RGB, CHW, batched input on self.device,
        matching the model's dtype (fp16 on CUDA, fp32 on CPU).
        """
        img0 = img.copy()
        img = letterbox(img, new_shape=self.img_size)[0]
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR -> RGB, HWC -> CHW
        img = np.ascontiguousarray(img)
        img = torch.from_numpy(img).to(self.device)
        img = img.half() if self.half_precision else img.float()  # match model dtype
        img /= 255.0  # normalize to [0, 1]
        if img.ndimension() == 3:
            img = img.unsqueeze(0)  # add batch dimension
        return img0, img

    def detect(self, im):
        """Run detection on a single frame.

        Returns (im, pred_boxes): the input frame unchanged and a list of
        (x1, y1, x2, y2, label, confidence) tuples in original-image
        coordinates. Boxes are drawn on an internal copy and written to
        'test_result_1.png' (debug side effect, kept for compatibility).
        """
        im0, img = self.preprocess(im)

        pred = self.m(img, augment=False)[0]
        pred = pred.float()  # NMS in fp32 regardless of inference dtype
        pred = non_max_suppression(pred, self.threshold, 0.4)

        pred_boxes = []
        for det in pred:
            if det is not None and len(det):
                # rescale boxes from letterboxed coords back to the original image
                det[:, :4] = scale_coords(
                    img.shape[2:], det[:, :4], im0.shape).round()
                for *x, conf, cls_id in det:
                    lbl = self.names[int(cls_id)]
                    x1, y1 = int(x[0]), int(x[1])
                    x2, y2 = int(x[2]), int(x[3])
                    pred_boxes.append((x1, y1, x2, y2, lbl, conf))
                    c = int(cls_id)  # integer class
                    plot_one_box(x, im0, label=lbl, color=colors(c, True), line_thickness=3)

        # NOTE(review): hardcoded debug dump — consider parameterizing or removing.
        cv2.imwrite('test_result_1.png', im0)

        return im, pred_boxes
|
|
|