You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

predict.py 1.3KB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536
  1. from models_725.segWaterBuilding import SegModel
  2. from PIL import Image
  3. import numpy as np
  4. import cv2
  5. import os
  6. from cv2 import getTickCount, getTickFrequency
  7. import time
  8. def predict_lunkuo(impth=None):
  9. pred, probs = segmodel.eval(image=impth)#####
  10. preds_squeeze = pred.squeeze(0)
  11. preds_squeeze[preds_squeeze != 0] = 255
  12. preds_squeeze = np.array(preds_squeeze.cpu())
  13. preds_squeeze = np.uint8(preds_squeeze)
  14. _, binary = cv2.threshold(preds_squeeze,220,255,cv2.THRESH_BINARY)
  15. contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
  16. img_n = cv2.cvtColor(impth,cv2.COLOR_RGB2BGR)
  17. img2 = cv2.drawContours(img_n,contours,-1,(0,0,255),8)
  18. return img2
  19. if __name__ == '__main__':
  20. impth = 'images/examples'
  21. outpth= 'images/results'
  22. folders = os.listdir(impth)
  23. #segmodel = SegModel(device='cuda:0')
  24. segmodel = SegModel(device='cpu')
  25. for i in range(len(folders)):
  26. imgpath = os.path.join(impth, folders[i])
  27. time00 = time.time()
  28. img = Image.open(imgpath).convert('RGB')
  29. img = np.array(img)
  30. time11 = time.time()
  31. img=predict_lunkuo(impth=img)
  32. cv2.imwrite( os.path.join( outpth,folders[i] ) ,img )
  33. print('----all_process', (time.time() - time11) * 1000)