urban_management/STDCUtils/STDC_process.py

46 lines
1.5 KiB
Python

import time
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import transforms
# Per-channel normalisation statistics applied to the RGB input. These are the
# standard ImageNet mean/std values used throughout torchvision; presumably the
# model was trained with the same values -- confirm against the training config.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _sync_for_timing(tensor):
    """Block until queued CUDA work on ``tensor``'s device has finished.

    CUDA kernels are launched asynchronously, so wall-clock timestamps taken
    without a synchronisation point measure launch overhead rather than the
    GPU work itself. This is a no-op for CPU tensors.
    """
    if tensor.is_cuda:
        torch.cuda.synchronize(tensor.device)


def STDC_process(img0, model, device, new_hw=None, *, verbose=True):
    """Segment one image with an STDC model and return the per-pixel class map.

    Pipeline:
    1. Convert the image BGR -> RGB.
    2. Apply ``ToTensor`` and normalise with ``_NORM_MEAN`` / ``_NORM_STD``.
    3. Move the tensor to ``device``.
    4. Bilinearly resize it to ``new_hw`` and run the model.
    5. Resize the model's first output back to the input resolution.
    6. Reduce it with ``argmax`` over dim 1.

    Args:
        img0: Input image. ``cv2.cvtColor(..., COLOR_BGR2RGB)`` is applied, so
            a 3-channel BGR numpy array is expected (presumably ``uint8`` of
            shape (H, W, 3) as returned by ``cv2.imread`` -- confirm at caller).
        model: Callable taking a (1, 3, new_h, new_w) tensor. ``model(x)[0]``
            must be a 4-D tensor, because bilinear ``F.interpolate`` requires
            4-D input. It is assumed to be (1, num_classes, h, w) logits.
        device: Target for ``Tensor.to`` (e.g. ``'cuda:0'`` or ``'cpu'``).
        new_hw: ``[height, width]`` the input is resized to before inference.
            Defaults to ``[360, 640]``.
        verbose: Keyword-only. When True (the default, matching the previous
            behaviour), per-stage timings are printed to stdout. CUDA work is
            synchronised before each timestamp so the numbers reflect the
            actual GPU cost.

    Returns:
        numpy array with the predicted class index for each pixel of ``img0``.
        Its shape is (H, W) given a batch-1 model output, and its dtype is
        int64 (from ``torch.argmax``).
    """
    if new_hw is None:
        new_hw = [360, 640]

    def log(msg):
        # Timing output stays on by default for backward compatibility.
        if verbose:
            print(msg)

    t_start = time.perf_counter()
    img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
    t_rgb = time.perf_counter()
    log(f't_bgr2rgb. ({t_rgb - t_start:.3f}s)')

    # ToTensor: HWC -> CHW float (scaled to [0, 1] for uint8 input).
    img = transforms.ToTensor()(img0)
    img = transforms.Normalize(_NORM_MEAN, _NORM_STD)(img)
    t_trans = time.perf_counter()
    log(f't_trans. ({t_trans - t_rgb:.3f}s)')

    t_togpu = time.perf_counter()
    img = img.to(device)
    if verbose:
        _sync_for_timing(img)
    t_togpu2 = time.perf_counter()
    log(f't_togpu. ({t_togpu2 - t_togpu:.3f}s)')

    # Remember the original (H, W) so the prediction can be mapped back to it.
    size = img.shape[-2:]
    img = F.interpolate(img.unsqueeze(0), new_hw, mode='bilinear', align_corners=True)
    if verbose:
        _sync_for_timing(img)
    t_pro = time.perf_counter()
    log(f't_interpolate. ({t_pro - t_togpu2:.3f}s)')
    log(f't_pro. ({t_pro - t_start:.3f}s)')

    # Inference only: the result is turned into integer labels and a numpy
    # array, so building an autograd graph would just waste memory and time.
    with torch.no_grad():
        logits = model(img)[0]
        if verbose:
            _sync_for_timing(logits)
        t_inf = time.perf_counter()
        log(f't_inf. ({t_inf - t_pro:.3f}s)')

        logits = F.interpolate(logits, size=size, mode='bilinear', align_corners=True)
        # Softmax is monotonic, so argmax over the raw logits yields the same
        # labels (up to floating-point ties) without an extra full pass.
        preds = torch.argmax(logits, dim=1).squeeze(0)

    # .cpu() waits for any pending device work, so t_post covers its real cost.
    preds_np = preds.cpu().numpy()
    t_end = time.perf_counter()
    log(f't_post. ({t_end - t_inf:.3f}s)')
    return preds_np