Browse Source

STDC-v1

master
thsw 2 years ago
parent
commit
7cce7cb72a
4 changed files with 114 additions and 0 deletions
  1. BIN
      model_maxmIOU75_1720_0.946_360640.pth
  2. +33
    -0
      predict.py
  3. +61
    -0
      predict_guolv.py
  4. +20
    -0
      shiyan.py

BIN
model_maxmIOU75_1720_0.946_360640.pth View File


+ 33
- 0
predict.py View File

@@ -0,0 +1,33 @@
from models_725.segWaterBuilding import SegModel
from PIL import Image
import numpy as np
import cv2
import os
from cv2 import getTickCount, getTickFrequency
import time

def predict_lunkuo(impth=None):
pred, probs = segmodel.eval(image=impth)#####
preds_squeeze = pred.squeeze(0)
preds_squeeze[preds_squeeze != 0] = 255
preds_squeeze = np.array(preds_squeeze.cpu())
preds_squeeze = np.uint8(preds_squeeze)
_, binary = cv2.threshold(preds_squeeze,220,255,cv2.THRESH_BINARY)
contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
img_n = cv2.cvtColor(impth,cv2.COLOR_RGB2BGR)
img2 = cv2.drawContours(img_n,contours,-1,(0,0,255),8)
return img2

if __name__ == '__main__':
impth = 'images/examples'
folders = os.listdir(impth)
segmodel = SegModel()
for i in range(len(folders)):
imgpath = os.path.join(impth, folders[i])
time00 = time.time()
img = Image.open(imgpath).convert('RGB')
img = np.array(img)
time11 = time.time()
predict_lunkuo(impth=img)
print('----all_process', (time.time() - time11) * 1000)

+ 61
- 0
predict_guolv.py View File

@@ -0,0 +1,61 @@
import torch
from models_725.segWaterBuilding import SegModel
from PIL import Image
import numpy as np
import cv2
import os
from cv2 import getTickCount, getTickFrequency
import time

def filer_boundary(probs=None, contours=None, img_n=None):
contours2 = contours[:]
n = len(contours)
band = 0.8
cimg = np.zeros_like(img_n)
area = map(cv2.contourArea, contours)
area_list = list(area)
if n != 0:
for i in range(n):
cimg[:, :, :] = 255
cv2.drawContours(cimg, contours2, i, color=(0, 0, 0), thickness=-1)
torch_cimg = torch.from_numpy(cimg).cuda()
torch_cimg = torch_cimg[:, :, 0] * 1 / 255
torch_cimg = abs(torch_cimg - 1)
cimg_logits = torch_cimg.mul(probs)
sum = torch.sum(cimg_logits)
sum_pix = sum / area_list[i]
if sum_pix <= band:
contours.remove(contours2[i])
return contours

def predict_lunkuo(impth=None, segmodel=None, filter=False):
pred, probs = segmodel.eval(image=impth)#####
preds_squeeze = pred.squeeze(0)
preds_squeeze[preds_squeeze != 0] = 255
preds_squeeze = np.array(preds_squeeze.cpu())
preds_squeeze = np.uint8(preds_squeeze)
_, binary = cv2.threshold(preds_squeeze,220,255,cv2.THRESH_BINARY)
contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
img_n = cv2.cvtColor(impth,cv2.COLOR_RGB2BGR)
if filter:
time_guolv = time.time()
contours = filer_boundary(probs=probs, contours=contours, img_n=img_n)
print('++++guolv', (time.time() - time_guolv) * 1000)
img2 = cv2.drawContours(img_n,contours,-1,(0,0,255),8)
# save_path = 'demo_lunkuo_1700_360640/' + name + '.png'
# cv2.imwrite(save_path, img2)
return img2

if __name__ == '__main__':
impth = '/home/data/lijiwen/wurenjiqifei/images/'
folders = os.listdir(impth)
segmodel = SegModel()
for i in range(len(folders)):
imgpath = os.path.join(impth, folders[i])
# name = imgpath.split('/')[-1].split('.')[0]
time00 = time.time()
img = Image.open(imgpath).convert('RGB')
img = np.array(img)
time11 = time.time()
img_out = predict_lunkuo(impth=img, segmodel=segmodel, filter=True)
print('----all_process', (time.time() - time11) * 1000)

+ 20
- 0
shiyan.py View File

@@ -0,0 +1,20 @@
import os
import time
from PIL import Image
import numpy as np
from models_725.segWaterBuilding import SegModel
from predict_guolv import predict_lunkuo


impth = '/home/data/lijiwen/wurenjiqifei/images/'
folders = os.listdir(impth)
segmodel = SegModel()
for i in range(len(folders)):
imgpath = os.path.join(impth, folders[i])
# name = imgpath.split('/')[-1].split('.')[0]
time00 = time.time()
img = Image.open(imgpath).convert('RGB')
img = np.array(img)
time11 = time.time()
img_out = predict_lunkuo(impth=img, segmodel=segmodel, filter=False)
print('----all_process', (time.time() - time11) * 1000)

Loading…
Cancel
Save