|
|
@@ -8,6 +8,9 @@ from DMPRUtils.DMPR_process import DMPR_process, plot_points |
|
|
|
from DMPRUtils.model.detector import DirectionalPointDetector |
|
|
|
from DMPRUtils.yolo_net import Model |
|
|
|
from DMPR_YOLO.jointUtil import dmpr_yolo |
|
|
|
from STDCUtils.STDC_process import STDC_process |
|
|
|
from STDCUtils.models.model_stages import BiSeNet |
|
|
|
from STDC_YOLO.yolo_stdc_joint import stdc_yolo |
|
|
|
from conf import config |
|
|
|
from models.experimental import attempt_load |
|
|
|
from models.yolo_process import yolo_process |
|
|
@@ -20,9 +23,9 @@ def main(): |
|
|
|
device_ = '1' ##选定模型,可选 cpu,'0','1' |
|
|
|
|
|
|
|
##以下参数目前不可改 |
|
|
|
Detweights = 'weights/urbanManagement/yolo/best.pt' |
|
|
|
Detweights = 'weights/urbanManagement/yolo/best1023.pt' |
|
|
|
seg_nclass = 2 |
|
|
|
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299.pth" |
|
|
|
DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299_1023.pth" |
|
|
|
conf_thres, iou_thres, classes = 0.25, 0.45, 3 |
|
|
|
labelnames = "weights/yolov5/class5/labelnames.json" |
|
|
|
rainbows = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127], |
|
|
@@ -40,6 +43,17 @@ def main(): |
|
|
|
if half: |
|
|
|
model.half() |
|
|
|
|
|
|
|
# STDC model |
|
|
|
args = config.get_parser_for_inference().parse_args() |
|
|
|
|
|
|
|
# STDC model |
|
|
|
STDC_model = BiSeNet(backbone=args.backbone, n_classes=args.n_classes, |
|
|
|
use_boundary_2=args.use_boundary_2, use_boundary_4=args.use_boundary_4, |
|
|
|
use_boundary_8=args.use_boundary_8, use_boundary_16=args.use_boundary_16, |
|
|
|
use_conv_last=args.use_conv_last).to(device) |
|
|
|
STDC_model.load_state_dict(torch.load(args.respth)) |
|
|
|
STDC_model.eval() |
|
|
|
|
|
|
|
# DMPR model |
|
|
|
args = config.get_parser_for_inference().parse_args() |
|
|
|
# DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device) |
|
|
@@ -66,6 +80,12 @@ def main(): |
|
|
|
t_yolo = time.time() |
|
|
|
print(f't_yolo. ({t_yolo - t_start:.3f}s)') |
|
|
|
|
|
|
|
# STDC process |
|
|
|
det2 = STDC_process(img0, STDC_model, device, args.n_classes, args.stdc_scale) |
|
|
|
|
|
|
|
# STDC joint yolo |
|
|
|
det0 = stdc_yolo(det2, det0) |
|
|
|
|
|
|
|
# plot所有box |
|
|
|
# for *xyxy, conf, cls in reversed(det0): |
|
|
|
# label = f'{int(cls)} {conf:.2f}' |