城管三模型代码
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

46 lines
1.5KB

  1. import time
  2. import cv2
  3. import numpy as np
  4. import torch
  5. import torch.nn.functional as F
  6. from torchvision.transforms import transforms
  7. def STDC_process(img0, model, device, new_hw=None):
  8. if new_hw is None:
  9. new_hw = [360, 640]
  10. t_start = time.time()
  11. img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
  12. t2 = time.time()
  13. print(f't_bgr2rgb. ({t2 - t_start:.3f}s)')
  14. # img0 = img0[..., ::-1]
  15. img = transforms.ToTensor()(img0)
  16. img = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(img)
  17. t3 = time.time()
  18. print(f't_trans. ({t3 - t2:.3f}s)')
  19. t_togpu = time.time()
  20. img = img.to(device)
  21. t_togpu2 = time.time()
  22. print(f't_togpu. ({t_togpu2 - t_togpu:.3f}s)')
  23. C, H, W = img.shape
  24. size = img.size()[-2:]
  25. # new_hw = [int(H * scale), int(W * scale)]
  26. # new_hw = [360, 640]
  27. img = img.unsqueeze(0)
  28. img = F.interpolate(img, new_hw, mode='bilinear', align_corners=True)
  29. t_pro = time.time()
  30. print(f't_interpolate. ({t_pro - t_togpu2:.3f}s)')
  31. print(f't_pro. ({t_pro - t_start:.3f}s)')
  32. logits = model(img)[0]
  33. t_inf = time.time()
  34. print(f't_inf. ({t_inf - t_pro:.3f}s)')
  35. logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
  36. probs = torch.softmax(logits, dim=1)
  37. preds = torch.argmax(probs, dim=1)
  38. preds_squeeze = preds.squeeze(0)
  39. preds_squeeze_predict = np.array(preds_squeeze.cpu())
  40. t_end = time.time()
  41. print(f't_post. ({t_end - t_inf:.3f}s)')
  42. return preds_squeeze_predict