|
- import logging
- import os
-
- import torch
-
# Configure logging at import time. The level is read from the LOGLEVEL
# environment variable (upper-cased) and falls back to "INFO" when unset.
logging.basicConfig(
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Logger named after this module.
logger = logging.getLogger(__name__)
-
-
def get_pinjie(img, shift):
    """Shift each batch entry's boxes by a per-entry offset and merge the batches.

    The first two channels of every box in ``img[i]`` are incremented by
    ``shift[i]``, then all entries are flattened into one sequence of
    7-channel boxes.

    NOTE(review): the first two channels are presumably x/y coordinates and
    ``shift`` a per-tile offset used when stitching tiles back together
    ("pinjie" is pinyin for "stitch") -- confirm against the caller.

    Args:
        img: Tensor laid out as ``(batch, nbox, 7)``; the total element count
            must be divisible by 7. **Modified in place.**
        shift: 2-D per-entry offsets of shape ``(batch, 2)``. A NumPy array,
            or anything else ``torch.as_tensor`` accepts (tensor, nested
            list).

    Returns:
        Tensor of shape ``(1, batch * nbox, 7)``. It shares storage with
        ``img`` when ``img`` is contiguous; otherwise it is a copy.

    Raises:
        ValueError: If ``shift`` is not 2-D.
    """
    # as_tensor behaves like from_numpy for ndarrays but also accepts tensors
    # and lists; the offsets are moved to img's device before the addition.
    offsets = torch.as_tensor(shift, device=img.device)
    if offsets.dim() != 2:
        raise ValueError(
            f"shift must be 2-D (batch, 2), got {offsets.dim()} dimension(s)"
        )

    # (batch, 2) -> (batch, 1, 2): broadcasting applies the same offset to
    # every box of an entry without materializing a repeated copy.
    img[..., :2] += offsets.unsqueeze(1)

    # reshape (rather than view) also handles non-contiguous inputs.
    return img.reshape(1, -1, 7)
|