22 lines
490 B
Python
22 lines
490 B
Python
|
|
import logging
|
||
|
|
import os
|
||
|
|
|
||
|
|
import torch
|
||
|
|
|
||
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Configure the root logger at import time. Verbosity is taken from the
# LOGLEVEL environment variable and falls back to INFO when it is unset.
logging.basicConfig(
    level=os.getenv("LOGLEVEL", "INFO").upper(),
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
|
||
|
|
|
||
|
|
|
||
|
|
def get_pinjie(img, shift):
    """Add a per-image offset to the first two channels and flatten the batch.

    For every index ``i`` along dimension 0, ``shift[i]`` is added to
    ``img[i, :, :2]``. The addition is performed **in place** on ``img``,
    and the returned tensor is a view of ``img`` (it shares storage).

    Args:
        img: Tensor whose dimension 1 indexes boxes and whose total element
            count is a multiple of 7. It must be view-compatible (e.g.
            contiguous) because the result is produced with ``Tensor.view``.
            Presumably shaped ``(num_images, nbox, 7)`` -- TODO confirm
            against callers.
        shift: 2-D per-image offsets, one row per image. May be a NumPy
            array, a tensor, or a nested sequence; it is moved to
            ``img.device`` before use.

    Returns:
        ``img`` reinterpreted as a tensor of shape ``(1, -1, 7)``.
    """
    nbox = img.shape[1]

    # as_tensor shares memory with a NumPy input exactly like from_numpy,
    # but additionally accepts tensors and plain Python sequences.
    offsets = torch.as_tensor(shift, device=img.device)

    # Broadcast each image's offset across its boxes without materialising
    # a repeated copy; keeping nbox explicit rejects mis-shaped offsets.
    offsets = offsets.unsqueeze(1).expand(-1, nbox, -1)

    # In-place on purpose: existing callers may rely on img being updated.
    img[..., :2] += offsets

    return img.view(1, -1, 7)
|