Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

62 Zeilen
2.3KB

  1. import torch
  2. from models_725.segWaterBuilding import SegModel
  3. from PIL import Image
  4. import numpy as np
  5. import cv2
  6. import os
  7. from cv2 import getTickCount, getTickFrequency
  8. import time
  9. def filer_boundary(probs=None, contours=None, img_n=None):
  10. contours2 = contours[:]
  11. n = len(contours)
  12. band = 0.8
  13. cimg = np.zeros_like(img_n)
  14. area = map(cv2.contourArea, contours)
  15. area_list = list(area)
  16. if n != 0:
  17. for i in range(n):
  18. cimg[:, :, :] = 255
  19. cv2.drawContours(cimg, contours2, i, color=(0, 0, 0), thickness=-1)
  20. torch_cimg = torch.from_numpy(cimg).cuda()
  21. torch_cimg = torch_cimg[:, :, 0] * 1 / 255
  22. torch_cimg = abs(torch_cimg - 1)
  23. cimg_logits = torch_cimg.mul(probs)
  24. sum = torch.sum(cimg_logits)
  25. sum_pix = sum / area_list[i]
  26. if sum_pix <= band:
  27. contours.remove(contours2[i])
  28. return contours
  29. def predict_lunkuo(impth=None, segmodel=None, filter=False):
  30. pred, probs = segmodel.eval(image=impth)#####
  31. preds_squeeze = pred.squeeze(0)
  32. preds_squeeze[preds_squeeze != 0] = 255
  33. preds_squeeze = np.array(preds_squeeze.cpu())
  34. preds_squeeze = np.uint8(preds_squeeze)
  35. _, binary = cv2.threshold(preds_squeeze,220,255,cv2.THRESH_BINARY)
  36. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
  37. img_n = cv2.cvtColor(impth,cv2.COLOR_RGB2BGR)
  38. if filter:
  39. time_guolv = time.time()
  40. contours = filer_boundary(probs=probs, contours=contours, img_n=img_n)
  41. print('++++guolv', (time.time() - time_guolv) * 1000)
  42. img2 = cv2.drawContours(img_n,contours,-1,(0,0,255),8)
  43. # save_path = 'demo_lunkuo_1700_360640/' + name + '.png'
  44. # cv2.imwrite(save_path, img2)
  45. return img2
  46. if __name__ == '__main__':
  47. impth = '/home/data/lijiwen/wurenjiqifei/images/'
  48. folders = os.listdir(impth)
  49. segmodel = SegModel()
  50. for i in range(len(folders)):
  51. imgpath = os.path.join(impth, folders[i])
  52. # name = imgpath.split('/')[-1].split('.')[0]
  53. time00 = time.time()
  54. img = Image.open(imgpath).convert('RGB')
  55. img = np.array(img)
  56. time11 = time.time()
  57. img_out = predict_lunkuo(impth=img, segmodel=segmodel, filter=True)
  58. print('----all_process', (time.time() - time11) * 1000)