|
|
@@ -0,0 +1,61 @@ |
|
|
|
import torch |
|
|
|
from models_725.segWaterBuilding import SegModel |
|
|
|
from PIL import Image |
|
|
|
import numpy as np |
|
|
|
import cv2 |
|
|
|
import os |
|
|
|
from cv2 import getTickCount, getTickFrequency |
|
|
|
import time |
|
|
|
|
|
|
|
def filer_boundary(probs=None, contours=None, img_n=None): |
|
|
|
contours2 = contours[:] |
|
|
|
n = len(contours) |
|
|
|
band = 0.8 |
|
|
|
cimg = np.zeros_like(img_n) |
|
|
|
area = map(cv2.contourArea, contours) |
|
|
|
area_list = list(area) |
|
|
|
if n != 0: |
|
|
|
for i in range(n): |
|
|
|
cimg[:, :, :] = 255 |
|
|
|
cv2.drawContours(cimg, contours2, i, color=(0, 0, 0), thickness=-1) |
|
|
|
torch_cimg = torch.from_numpy(cimg).cuda() |
|
|
|
torch_cimg = torch_cimg[:, :, 0] * 1 / 255 |
|
|
|
torch_cimg = abs(torch_cimg - 1) |
|
|
|
cimg_logits = torch_cimg.mul(probs) |
|
|
|
sum = torch.sum(cimg_logits) |
|
|
|
sum_pix = sum / area_list[i] |
|
|
|
if sum_pix <= band: |
|
|
|
contours.remove(contours2[i]) |
|
|
|
return contours |
|
|
|
|
|
|
|
def predict_lunkuo(impth=None, segmodel=None, filter=False): |
|
|
|
pred, probs = segmodel.eval(image=impth)##### |
|
|
|
preds_squeeze = pred.squeeze(0) |
|
|
|
preds_squeeze[preds_squeeze != 0] = 255 |
|
|
|
preds_squeeze = np.array(preds_squeeze.cpu()) |
|
|
|
preds_squeeze = np.uint8(preds_squeeze) |
|
|
|
_, binary = cv2.threshold(preds_squeeze,220,255,cv2.THRESH_BINARY) |
|
|
|
contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) |
|
|
|
img_n = cv2.cvtColor(impth,cv2.COLOR_RGB2BGR) |
|
|
|
if filter: |
|
|
|
time_guolv = time.time() |
|
|
|
contours = filer_boundary(probs=probs, contours=contours, img_n=img_n) |
|
|
|
print('++++guolv', (time.time() - time_guolv) * 1000) |
|
|
|
img2 = cv2.drawContours(img_n,contours,-1,(0,0,255),8) |
|
|
|
# save_path = 'demo_lunkuo_1700_360640/' + name + '.png' |
|
|
|
# cv2.imwrite(save_path, img2) |
|
|
|
return img2 |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
impth = '/home/data/lijiwen/wurenjiqifei/images/' |
|
|
|
folders = os.listdir(impth) |
|
|
|
segmodel = SegModel() |
|
|
|
for i in range(len(folders)): |
|
|
|
imgpath = os.path.join(impth, folders[i]) |
|
|
|
# name = imgpath.split('/')[-1].split('.')[0] |
|
|
|
time00 = time.time() |
|
|
|
img = Image.open(imgpath).convert('RGB') |
|
|
|
img = np.array(img) |
|
|
|
time11 = time.time() |
|
|
|
img_out = predict_lunkuo(impth=img, segmodel=segmodel, filter=True) |
|
|
|
print('----all_process', (time.time() - time11) * 1000) |