diff --git a/AI_example.py b/AI_example.py index 290bbf7..9ca4b4d 100644 --- a/AI_example.py +++ b/AI_example.py @@ -8,6 +8,9 @@ from DMPRUtils.DMPR_process import DMPR_process, plot_points from DMPRUtils.model.detector import DirectionalPointDetector from DMPRUtils.yolo_net import Model from DMPR_YOLO.jointUtil import dmpr_yolo +from STDCUtils.STDC_process import STDC_process +from STDCUtils.models.model_stages import BiSeNet +from STDC_YOLO.yolo_stdc_joint import stdc_yolo from conf import config from models.experimental import attempt_load from models.yolo_process import yolo_process @@ -20,9 +23,9 @@ def main(): device_ = '1' ##选定模型,可选 cpu,'0','1' ##以下参数目前不可改 - Detweights = 'weights/urbanManagement/yolo/best.pt' + Detweights = 'weights/urbanManagement/yolo/best1023.pt' seg_nclass = 2 - DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299.pth" + DMPRweights = "weights/urbanManagement/DMPR/dp_detector_299_1023.pth" conf_thres, iou_thres, classes = 0.25, 0.45, 3 labelnames = "weights/yolov5/class5/labelnames.json" rainbows = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 0, 255], [255, 255, 0], [255, 129, 0], [255, 0, 127], @@ -40,6 +43,17 @@ def main(): if half: model.half() + # STDC model + args = config.get_parser_for_inference().parse_args() + + # STDC model + STDC_model = BiSeNet(backbone=args.backbone, n_classes=args.n_classes, + use_boundary_2=args.use_boundary_2, use_boundary_4=args.use_boundary_4, + use_boundary_8=args.use_boundary_8, use_boundary_16=args.use_boundary_16, + use_conv_last=args.use_conv_last).to(device) + STDC_model.load_state_dict(torch.load(args.respth)) + STDC_model.eval() + # DMPR model args = config.get_parser_for_inference().parse_args() # DMPRmodel = DirectionalPointDetector(3, args.depth_factor, config.NUM_FEATURE_MAP_CHANNEL).to(device) @@ -66,6 +80,12 @@ def main(): t_yolo = time.time() print(f't_yolo. ({t_yolo - t_start:.3f}s)') + # STDC process + det2 = STDC_process(img0, STDC_model, device, args.n_classes, args.stdc_scale) + + # STDC joint yolo + det0 = stdc_yolo(det2, det0) + # plot所有box # for *xyxy, conf, cls in reversed(det0): # label = f'{int(cls)} {conf:.2f}' diff --git a/DMPR_YOLO/jointUtil.py b/DMPR_YOLO/jointUtil.py index 21cb5c7..172bd64 100644 --- a/DMPR_YOLO/jointUtil.py +++ b/DMPR_YOLO/jointUtil.py @@ -101,6 +101,5 @@ def dmpr_yolo(dmpr_det, yolo_det, img_shape, cls:int, scale_ratio): yolo_det_clone[yolo_det_clone[:, -2] == cls, -1] = res - return yolo_det_clone, new_yolo_det diff --git a/conf/config.py b/conf/config.py index 7689197..f8583fd 100644 --- a/conf/config.py +++ b/conf/config.py @@ -53,5 +53,17 @@ def get_parser_for_inference(): help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--dmpr-thresh', type=float, default=0.3, help="Detection threshold.") + # STDC + parser.add_argument('--n-classes', type=int, default=2, help='number of classes for segment') + parser.add_argument('--backbone', type=str, default='STDCNet813', help='STDC backbone') + parser.add_argument('--respth', type=str, default='weights/urbanManagement/STDC/model_final.pth', + help='The weights of STDC') + parser.add_argument('--stdc-scale', type=float, default=0.75, help='The scale of STDC') + parser.add_argument('--use-boundary-2', type=bool, default=False, help='') + parser.add_argument('--use-boundary-4', type=bool, default=False, help='') + parser.add_argument('--use-boundary-8', type=bool, default=False, help='') + parser.add_argument('--use-boundary-16', type=bool, default=False, help='') + parser.add_argument('--use-conv-last', type=bool, default=False, help='') + add_common_arguments(parser) return parser \ No newline at end of file diff --git a/weights/urbanManagement/yolo/best.pt b/weights/urbanManagement/yolo/best.pt deleted file mode 100644 index 53efa68..0000000 Binary files a/weights/urbanManagement/yolo/best.pt and /dev/null differ diff --git a/weights/urbanManagement/yolo/last.pt b/weights/urbanManagement/yolo/last.pt deleted file mode 100644 index b7d3b04..0000000 Binary files a/weights/urbanManagement/yolo/last.pt and /dev/null differ