import os
import time

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

import func_utils


def apply_mask(image, mask, alpha=0.5):
    """Blend one random colour into ``image`` wherever ``mask == 1``.

    The first three channels of ``image`` are overwritten in place and the
    same array is returned.

    Args:
        image: Array indexed as ``image[:, :, c]`` for ``c`` in 0..2.
        mask: Array compared against 1; it must broadcast against one channel.
        alpha: Weight of the colour in the blend (``1 - alpha`` is the weight
            of the original pixel).

    Returns:
        ``image`` itself, after blending.
    """
    color = np.random.rand(3)
    selected = mask == 1
    for c in range(3):
        channel = image[:, :, c]
        blended = channel * (1 - alpha) + alpha * color[c] * 255
        image[:, :, c] = np.where(selected, blended, channel)
    return image


# Directory the annotated test images are written to. It is created at import
# time; exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
saveDir = 'output'
os.makedirs(saveDir, exist_ok=True)


class TestModule(object):
    """Run a trained model over a test dataset and save annotated images."""

    # Colours of the four centre-to-edge-midpoint lines, in the order
    # top, right, bottom, left.
    _EDGE_COLORS = ((0, 0, 255), (255, 0, 255), (0, 255, 0), (255, 0, 0))

    def __init__(self, dataset, num_classes, model, decoder):
        """Store the collaborators and pick the compute device.

        Args:
            dataset: Mapping from dataset name to a dataset factory; ``test``
                looks it up with ``args.dataset``.
            num_classes: Stored on the instance; not used elsewhere in this
                class.
            model: Callable network; ``test`` loads weights into it.
            decoder: Object providing ``ctdet_decode(pr_decs)``.
        """
        torch.manual_seed(317)
        # Falls back to the CPU when CUDA is unavailable.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.dataset = dataset
        self.num_classes = num_classes
        self.model = model
        self.decoder = decoder

    def load_model(self, model, resume):
        """Load the weights stored in checkpoint file ``resume`` into ``model``.

        The checkpoint must provide the keys ``'epoch'`` and
        ``'model_state_dict'``. Loading is strict, so missing or unexpected
        parameter names raise.

        Returns:
            The same ``model`` object with the weights loaded.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        # The map_location callback returns the storage unchanged, i.e. the
        # tensors stay on the CPU until the caller moves the model.
        checkpoint = torch.load(resume, map_location=lambda storage, loc: storage)
        print('loaded weights from {}, epoch {}'.format(resume, checkpoint['epoch']))
        state_dict_ = checkpoint['model_state_dict']
        model.load_state_dict(state_dict_, strict=True)
        return model

    def map_mask_to_image(self, mask, img, color=None):
        """Replace the masked pixels of ``img`` with a flat colour.

        Args:
            mask: 2-D array; it is repeated along a new last axis to three
                channels. Presumably binary (0/1) - TODO confirm with callers.
            img: Image that broadcasts against the three-channel mask.
            color: Sequence of three scale factors that are multiplied by
                256; a random colour is drawn when omitted.

        Returns:
            A new ``uint8`` array. Where the mask is 1 the pixel becomes
            ``color * 256``; where it is 0 the pixel is unchanged.
        """
        if color is None:
            color = np.random.rand(3)
        mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
        mskd = img * mask
        # Same arithmetic as scaling each channel by color[c] * 256 in turn.
        channel_scale = np.asarray([color[0], color[1], color[2]])
        clmsk = np.ones(mask.shape) * mask * channel_scale * 256
        # Add the coloured mask and subtract the original masked pixels.
        img = img + 1. * clmsk - 1. * mskd
        return np.uint8(img)

    def imshow_heatmap(self, pr_dec, images):
        """Save the raw head outputs and the input image to ``heatmap.png``.

        Args:
            pr_dec: Mapping with tensors under ``'wh'``, ``'hm'`` and
                ``'cls_theta'``; element ``[0, k]`` of each is plotted.
            images: Input tensor with a leading axis of size 1. It is mapped
                back to pixel values with ``(x + 0.5) * 255`` - presumably the
                inverse of the dataset normalisation; TODO confirm.
        """
        def to_numpy(tensor):
            return tensor.detach().cpu().numpy()

        wh = pr_dec['wh']
        wh_w = to_numpy(wh[0, 0, :, :])
        wh_h = to_numpy(wh[0, 1, :, :])
        hm = to_numpy(pr_dec['hm'][0, 0, :, :])
        cls_theta = to_numpy(pr_dec['cls_theta'][0, 0, :, :])
        images = np.transpose((to_numpy(images.squeeze(0)) + 0.5) * 255,
                              (1, 2, 0)).astype(np.uint8)

        # cv2.resize takes (width, height). cls_theta is shown at its native
        # resolution, as in the original code.
        size = (images.shape[1], images.shape[0])
        wh_w = cv2.resize(wh_w, size)
        wh_h = cv2.resize(wh_h, size)
        hm = cv2.resize(hm, size)

        panels = (
            (1, 'width', wh_w),
            (2, 'height', wh_h),
            (3, 'center hm', hm),
            # Previously mislabelled 'input image' (copy of the panel below).
            (5, 'cls theta', cls_theta),
            (6, 'input image', images),
        )
        fig = plt.figure(1)
        for position, label, data in panels:
            ax = fig.add_subplot(2, 3, position)
            ax.set_xlabel(label)
            ax.imshow(data)
        fig.savefig('heatmap.png')
        # Close the figure so repeated calls do not stack axes on figure 1.
        plt.close(fig)

    @staticmethod
    def _nms_per_category(categories, pts_by_cat, scores_by_cat):
        """Run NMS separately for every category except ``'background'``."""
        results = {cat: [] for cat in categories}
        for cat in categories:
            if cat == 'background':
                continue
            pts_cat = np.asarray(list(pts_by_cat[cat]), np.float32)
            scores_cat = np.asarray(list(scores_by_cat[cat]), np.float32)
            if pts_cat.shape[0]:
                results[cat].extend(
                    func_utils.non_maximum_suppression(pts_cat, scores_cat))
        return results

    @classmethod
    def _draw_predictions(cls, image, categories, results):
        """Draw each kept detection on ``image`` and return the image."""
        for cat in categories:
            if cat == 'background':
                continue
            for pred in results[cat]:
                score = pred[-1]
                # pred[0:8] holds the corners as x, y pairs in the order
                # top-left, top-right, bottom-right, bottom-left.
                box = np.asarray(
                    [[pred[0], pred[1]], [pred[2], pred[3]],
                     [pred[4], pred[5]], [pred[6], pred[7]]], np.float32)
                center = np.mean(box, axis=0)
                center_xy = (int(center[0]), int(center[1]))
                # Midpoints of the top, right, bottom and left edges.
                edge_mids = (box + np.roll(box, -1, axis=0)) / 2
                for mid, line_color in zip(edge_mids, cls._EDGE_COLORS):
                    cv2.line(image, center_xy, (int(mid[0]), int(mid[1])),
                             line_color, 1, 1)
                # np.int0 was removed in NumPy 2.0; np.intp is the type it
                # aliased, so the conversion is unchanged.
                image = cv2.drawContours(image, [box.astype(np.intp)], -1,
                                         (255, 0, 255), 3, 1)
                cv2.putText(image, '{:.2f} {}'.format(score, cat),
                            (int(box[1][0]), int(box[1][1])),
                            cv2.FONT_HERSHEY_COMPLEX, 2, (78, 110, 240), 3, 1)
        return image

    @staticmethod
    def _draw_ground_truth(image, gt_anno):
        """Outline every ground-truth box of ``gt_anno['pts']`` in white."""
        for pts_4 in gt_anno['pts']:
            box = np.asarray([pts_4[0, :], pts_4[1, :], pts_4[2, :], pts_4[3, :]],
                             np.float32)
            cv2.drawContours(image, [box.astype(np.intp)], 0, (255, 255, 255), 1)

    def test(self, args, down_ratio):
        """Run inference over the test split and save one annotated JPEG per image.

        Weights are read from ``weights_<args.dataset>/<args.resume>``. For
        every sample the model output is decoded, filtered with per-category
        NMS and drawn on the image returned by ``dsets.load_image``. For the
        ``'hrsc'`` dataset the ground-truth boxes are drawn as well. Images
        are written to ``saveDir`` and timing information is printed.

        Args:
            args: Namespace providing ``dataset``, ``resume``, ``data_dir``,
                ``input_h`` and ``input_w``; it is also forwarded to
                ``func_utils.decode_prediction``.
            down_ratio: Forwarded to the dataset constructor and to
                ``func_utils.decode_prediction``.
        """
        weights_dir = 'weights_' + args.dataset
        self.model = self.load_model(self.model, os.path.join(weights_dir, args.resume))
        self.model = self.model.to(self.device)
        self.model.eval()
        use_cuda = self.device.type == 'cuda'

        t_setup_start = time.time()
        dataset_module = self.dataset[args.dataset]
        dsets = dataset_module(data_dir=args.data_dir,
                               phase='test',
                               input_h=args.input_h,
                               input_w=args.input_w,
                               down_ratio=down_ratio)
        data_loader = torch.utils.data.DataLoader(dsets,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=1,
                                                  # Pinning only helps host-to-GPU copies.
                                                  pin_memory=use_cuda)
        t_loop_start = time.time()

        for cnt, data_dict in enumerate(data_loader):
            t_image_start = time.time()
            image = data_dict['image'][0].to(self.device)
            img_id = data_dict['img_id'][0]
            print('processing {}/{} image ...'.format(cnt, len(data_loader)))

            with torch.no_grad():
                pr_decs = self.model(image)
            # __init__ falls back to the CPU, where torch.cuda.synchronize
            # raises, so only synchronise when actually running on CUDA.
            if use_cuda:
                torch.cuda.synchronize(self.device)

            predictions = self.decoder.ctdet_decode(pr_decs)
            pts_by_cat, scores_by_cat = func_utils.decode_prediction(
                predictions, dsets, args, img_id, down_ratio)
            t_decoded = time.time()

            results = self._nms_per_category(dsets.category, pts_by_cat, scores_by_cat)

            # The loader is not shuffled and uses batch_size=1, so the
            # enumerate counter is used as the dataset index.
            ori_image = dsets.load_image(cnt)
            ori_image = self._draw_predictions(ori_image, dsets.category, results)
            t_drawn = time.time()

            if args.dataset == 'hrsc':
                self._draw_ground_truth(ori_image, dsets.load_annotation(cnt))

            saveFile = os.path.join(saveDir, os.path.basename(img_id) + '.jpg')
            # JPEG quality 10 (scale 0-100) trades image quality for small files.
            written = cv2.imwrite(saveFile, ori_image, [cv2.IMWRITE_JPEG_QUALITY, 10])
            t_saved = time.time()

            # The "NMS" figure also covers loading the image and drawing the
            # detections; "save image" also covers ground-truth drawing.
            infer_ms = 1E3 * (t_decoded - t_image_start)
            nms_ms = 1E3 * (t_drawn - t_decoded)
            save_ms = 1E3 * (t_saved - t_drawn)
            print(f'one picture. ({infer_ms:.1f}ms) Inference, '
                  f'({nms_ms:.1f}ms) NMS,({save_ms:.1f}ms) save image')
            print(saveFile)
            if not written:
                # cv2.imwrite reports failure through its return value only.
                print('warning: failed to write {}'.format(saveFile))

        t_done = time.time()
        total_ms = 1E3 * (t_done - t_loop_start)
        setup_ms = 1E3 * (t_loop_start - t_setup_start)
        print(f'Done. ({total_ms:.1f}ms) Inference, ({setup_ms:.1f}ms) load data,')