AIlib2/obbUtils/func_utils.py

113 lines
4.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os,sys
import torch
import numpy as np
sys.path.extend(['../AIlib2/obbUtils'])
#import datasets.DOTA_devkit.ResultMerge_multi_process
#from datasets.DOTA_devkit.ResultMerge_multi_process import py_cpu_nms_poly_fast, py_cpu_nms_poly
from dotadevkit.ops.ResultMerge import py_cpu_nms_poly_fast, py_cpu_nms_poly
import time
def decode_prediction(predictions, category, model_size, down_ratio, ori_image):
    """Convert raw decoded network output into per-category polygons in
    original-image pixel coordinates.

    Parameters
    ----------
    predictions : array-like, shape (1, N, 12)
        Decoded detections. Columns 0-1 are the box centre, columns 2-9 the
        top/right/bottom/left edge midpoints (x, y pairs) in feature-map
        coordinates, column 10 the confidence score, column 11 the class index.
    category : sequence of str
        Class names; list position must match the class index in column 11.
    model_size : (int, int)
        Network input size as (width, height).
    down_ratio : float
        Stride between the feature map and the network input.
    ori_image : np.ndarray, shape (H, W, C)
        Original image; only its shape is read, to rescale coordinates.

    Returns
    -------
    (pts0, scores0)
        Two dicts keyed by category name: lists of (4, 2) float32 corner
        arrays ordered [tr, br, bl, tl], and the matching list of scores.
    """
    predictions = predictions[0, :, :]
    h, w, c = ori_image.shape
    pts0 = {cat: [] for cat in category}
    scores0 = {cat: [] for cat in category}
    for pred in predictions:
        cen_pt = np.asarray([pred[0], pred[1]], np.float32)
        tt = np.asarray([pred[2], pred[3]], np.float32)
        rr = np.asarray([pred[4], pred[5]], np.float32)
        bb = np.asarray([pred[6], pred[7]], np.float32)
        ll = np.asarray([pred[8], pred[9]], np.float32)
        # Reconstruct the four corners from the centre point and the four
        # edge-midpoint vectors (e.g. top-left = top + left - centre).
        tl = tt + ll - cen_pt
        bl = bb + ll - cen_pt
        tr = tt + rr - cen_pt
        br = bb + rr - cen_pt
        score = pred[10]
        clse = pred[11]
        pts = np.asarray([tr, br, bl, tl], np.float32)
        # Map feature-map coordinates back to original-image pixels:
        # feature map -> network input (x down_ratio) -> original (x w or h).
        pts[:, 0] = pts[:, 0] * down_ratio / model_size[0] * w
        pts[:, 1] = pts[:, 1] * down_ratio / model_size[1] * h
        pts0[category[int(clse)]].append(pts)
        scores0[category[int(clse)]].append(score)
    return pts0, scores0
def non_maximum_suppression(pts, scores, thresh=0.1):
    """Run polygon non-maximum suppression over one category's detections.

    Parameters
    ----------
    pts : np.ndarray, shape (N, 4, 2)
        Corner points for N quadrilaterals.
    scores : np.ndarray, shape (N,)
        Confidence score per quadrilateral.
    thresh : float, optional
        IoU threshold for the polygon NMS. Defaults to 0.1, the previously
        hard-coded value, so existing callers are unaffected.

    Returns
    -------
    np.ndarray, shape (K, 9), float64
        Kept detections as rows [x1, y1, x2, y2, x3, y3, x4, y4, score].
    """
    # Flatten each 4x2 polygon into [x1, y1, ..., x4, y4] and append the
    # score column - the 9-column layout py_cpu_nms_poly_fast expects.
    nms_item = np.concatenate([pts.reshape(-1, 8),
                               scores[:, np.newaxis]], axis=1)
    nms_item = np.asarray(nms_item, np.float64)
    keep_index = py_cpu_nms_poly_fast(dets=nms_item, thresh=thresh)
    return nms_item[keep_index]
def write_results(args,
                  model,
                  dsets,
                  down_ratio,
                  device,
                  decoder,
                  result_path,
                  print_ps=False):
    """Run `model` over every image in `dsets`, apply per-category polygon
    NMS, and write DOTA-style result files (one `Task1_<cat>.txt` per
    category) into `result_path`.

    Each output line is: `img_id score x1 y1 x2 y2 x3 y3 x4 y4`.
    The 'background' category is skipped both for NMS and for file output.
    """
    # results[cat][img_id] -> list of kept 9-element rows (8 coords + score).
    results = {cat: {img_id: [] for img_id in dsets.img_ids} for cat in dsets.category}
    for index in range(len(dsets)):
        data_dict = dsets.__getitem__(index)
        image = data_dict['image'].to(device)
        img_id = data_dict['img_id']
        image_w = data_dict['image_w']  # NOTE(review): read but never used below
        image_h = data_dict['image_h']  # NOTE(review): read but never used below
        with torch.no_grad():
            pr_decs = model(image)
        decoded_pts = []
        decoded_scores = []
        # NOTE(review): unconditional CUDA sync will raise on CPU-only runs -
        # confirm a CUDA device is guaranteed here.
        torch.cuda.synchronize(device)
        predictions = decoder.ctdet_decode(pr_decs)
        # NOTE(review): this call uses the OLD decode_prediction signature
        # (predictions, dsets, args, img_id, down_ratio) - see the commented
        # line above the definition - while decode_prediction now takes
        # (predictions, category, model_size, down_ratio, ori_image).
        # As written this will misbehave or raise; confirm which version of
        # the function this caller is meant to match.
        pts0, scores0 = decode_prediction(predictions, dsets, args, img_id, down_ratio)
        decoded_pts.append(pts0)
        decoded_scores.append(scores0)
        # nms: gather this image's detections per category and suppress.
        for cat in dsets.category:
            if cat == 'background':
                continue
            pts_cat = []
            scores_cat = []
            for pts0, scores0 in zip(decoded_pts, decoded_scores):
                pts_cat.extend(pts0[cat])
                scores_cat.extend(scores0[cat])
            pts_cat = np.asarray(pts_cat, np.float32)
            scores_cat = np.asarray(scores_cat, np.float32)
            if pts_cat.shape[0]:
                nms_results = non_maximum_suppression(pts_cat, scores_cat)
                results[cat][img_id].extend(nms_results)
        if print_ps:
            print('testing {}/{} data {}'.format(index+1, len(dsets), img_id))
    # Write one Task1 file per category: score first, then the 8 corner coords.
    for cat in dsets.category:
        if cat == 'background':
            continue
        with open(os.path.join(result_path, 'Task1_{}.txt'.format(cat)), 'w') as f:
            for img_id in results[cat]:
                for pt in results[cat][img_id]:
                    f.write('{} {:.12f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.format(
                        img_id, pt[8], pt[0], pt[1], pt[2], pt[3], pt[4], pt[5], pt[6], pt[7]))