import logging
import os

import torch

logger = logging.getLogger(__name__)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    # Verbosity is taken from the LOGLEVEL environment variable (default INFO).
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)


def get_pinjie(img: torch.Tensor, shift) -> torch.Tensor:
    """Offset the first two entries of every row and flatten to one batch.

    The per-group offsets in ``shift`` are added to ``img[..., :2]`` and the
    result is regrouped into rows of 7 values with a leading batch axis of 1.

    Unlike an in-place ``+=`` on ``img``, the caller's tensor is left
    untouched, so calling this function repeatedly on the same input does not
    accumulate offsets.

    Args:
        img: Tensor whose axis 1 indexes boxes and whose total element count
            is a multiple of 7.
            NOTE(review): presumably shaped (groups, boxes, 7) -- confirm
            against callers.
        shift: Offsets to add, as a NumPy array, tensor, or nested sequence.
            Axis 0 must line up with axis 0 of ``img`` and the last axis must
            broadcast against the two shifted entries.
            NOTE(review): presumably (groups, 2) x/y offsets -- confirm.

    Returns:
        A tensor of shape ``(1, N, 7)`` with the dtype and device of ``img``.
    """
    # as_tensor accepts ndarrays (zero-copy on CPU), tensors and sequences,
    # and places the result on the same device as ``img``.
    offsets = torch.as_tensor(shift, device=img.device)

    # Work on a copy so the caller's tensor is not modified.
    shifted = img.clone()

    # A singleton box axis lets broadcasting apply one offset to every box of
    # a group; the in-place add keeps the dtype of ``img``.
    shifted[..., :2] += offsets.unsqueeze(1)

    # reshape (rather than view) also handles non-contiguous inputs.
    return shifted.reshape(1, -1, 7)