
FP16 inference update

Glenn Jocher · 4 years ago · branch 5.0 · commit 260b1729f0
4 changed files with 28 additions and 23 deletions
  1. detect.py (+4, -10)
  2. requirements.txt (+3, -3)
  3. test.py (+17, -9)
  4. utils/utils.py (+4, -1)

detect.py (+4, -10)

@@ -14,6 +14,7 @@ def detect(save_img=False):
     if os.path.exists(out):
         shutil.rmtree(out)  # delete output folder
     os.makedirs(out)  # make new output folder
+    half &= device.type != 'cpu'  # half precision only supported on CUDA
 
     # Load model
     google_utils.attempt_download(weights)
@@ -21,6 +22,8 @@ def detect(save_img=False):
     # torch.save(torch.load(weights, map_location=device), weights)  # update model if SourceChangeWarning
     # model.fuse()
     model.to(device).eval()
+    if half:
+        model.half()  # to FP16
 
     # Second-stage classifier
     classify = False
@@ -29,11 +32,6 @@ def detect(save_img=False):
         modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model'])  # load weights
         modelc.to(device).eval()
 
-    # Half precision
-    half = half and device.type != 'cpu'  # half precision only supported on CUDA
-    if half:
-        model.half()
-
     # Set Dataloader
     vid_path, vid_writer = None, None
     if webcam:
@@ -51,7 +49,7 @@ def detect(save_img=False):
     # Run inference
     t0 = time.time()
     img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
-    _ = model(img.half() if half else img.float()) if device.type != 'cpu' else None  # run once
+    _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
     for path, img, im0s, vid_cap in dataset:
         img = torch.from_numpy(img).to(device)
         img = img.half() if half else img.float()  # uint8 to fp16/32
@@ -63,10 +61,6 @@ def detect(save_img=False):
         t1 = torch_utils.time_synchronized()
         pred = model(img, augment=opt.augment)[0]
 
-        # to float
-        if half:
-            pred = pred.float()
-
         # Apply NMS
         pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres,
                                    fast=True, classes=opt.classes, agnostic=opt.agnostic_nms)
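The detect.py changes above follow the standard PyTorch FP16 inference pattern: enable half precision only on CUDA, convert the model once, run a warm-up pass, and cast each input image to the matching dtype. Below is a minimal standalone sketch of that pattern, not this repository's code: a torchvision ResNet stands in for the detection model, and the gating mirrors the half &= device.type != 'cpu' guard in the diff.

import torch
import torchvision

# Minimal sketch (not repository code): FP16 inference with a stand-in model.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'  # half precision only supported on CUDA

model = torchvision.models.resnet18().to(device).eval()  # stand-in for the detector
if half:
    model.half()  # weights and buffers to FP16

img = torch.zeros((1, 3, 640, 640), device=device)  # init img
if device.type != 'cpu':
    _ = model(img.half() if half else img)  # warm-up, run once

with torch.no_grad():
    pred = model(img.half() if half else img.float())
print(pred.dtype)  # torch.float16 on CUDA, torch.float32 on CPU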

requirements.txt (+3, -3)

@@ -1,12 +1,12 @@
 # pip install -U -r requirements.txt
 Cython
-numpy
+numpy==1.17
 opencv-python
-torch >= 1.5
+torch>=1.5
 matplotlib
 pillow
 tensorboard
-pyyaml >= 5.3
+PyYAML>=5.3
 torchvision
 scipy
 tqdm
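A quick way to confirm an environment matches these pins before enabling half precision (an illustration only, not part of the commit):

import numpy as np
import torch

print('numpy', np.__version__)                       # pinned to 1.17 above
print('torch', torch.__version__)                    # requirement is torch>=1.5
print('CUDA available:', torch.cuda.is_available())  # FP16 inference needs a CUDA device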

test.py (+17, -9)

@@ -20,10 +20,12 @@ def test(data,
          model=None,
          dataloader=None,
          fast=False,
-         verbose=False):  # 0 fast, 1 accurate
+         verbose=False,
+         half=False):  # FP16
     # Initialize/load model and set device
     if model is None:
         device = torch_utils.select_device(opt.device, batch_size=batch_size)
+        half &= device.type != 'cpu'  # half precision only supported on CUDA
 
         # Remove previous
         for f in glob.glob('test_batch*.jpg'):
@@ -35,6 +37,8 @@ def test(data,
         torch_utils.model_info(model)
         # model.fuse()
         model.to(device)
+        if half:
+            model.half()  # to FP16
 
         if device.type != 'cpu' and torch.cuda.device_count() > 1:
             model = nn.DataParallel(model)
@@ -72,24 +76,27 @@
 
     seen = 0
     model.eval()
-    _ = model(torch.zeros((1, 3, imgsz, imgsz), device=device)) if device.type != 'cpu' else None  # run once
+    img = torch.zeros((1, 3, imgsz, imgsz), device=device)  # init img
+    _ = model(img.half() if half else img) if device.type != 'cpu' else None  # run once
     names = model.names if hasattr(model, 'names') else model.module.names
     coco91class = coco80_to_coco91_class()
     s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
     p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
     loss = torch.zeros(3, device=device)
     jdict, stats, ap, ap_class = [], [], [], []
-    for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
-        imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0
+    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
+        img = img.to(device)
+        img = img.half() if half else img.float()  # uint8 to fp16/32
+        img /= 255.0  # 0 - 255 to 0.0 - 1.0
         targets = targets.to(device)
-        nb, _, height, width = imgs.shape  # batch size, channels, height, width
+        nb, _, height, width = img.shape  # batch size, channels, height, width
         whwh = torch.Tensor([width, height, width, height]).to(device)
 
         # Disable gradients
         with torch.no_grad():
             # Run model
             t = torch_utils.time_synchronized()
-            inf_out, train_out = model(imgs, augment=augment)  # inference and training outputs
+            inf_out, train_out = model(img, augment=augment)  # inference and training outputs
             t0 += torch_utils.time_synchronized() - t
 
             # Compute loss
@@ -125,7 +132,7 @@ def test(data,
                 # [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ...
                 image_id = int(Path(paths[si]).stem.split('_')[-1])
                 box = pred[:, :4].clone()  # xyxy
-                scale_coords(imgs[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape
+                scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1])  # to original shape
                 box = xyxy2xywh(box)  # xywh
                 box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
                 for p, b in zip(pred.tolist(), box.tolist()):
@@ -168,9 +175,9 @@
         # Plot images
         if batch_i < 1:
             f = 'test_batch%g_gt.jpg' % batch_i  # filename
-            plot_images(imgs, targets, paths, f, names)  # ground truth
+            plot_images(img, targets, paths, f, names)  # ground truth
             f = 'test_batch%g_pred.jpg' % batch_i
-            plot_images(imgs, output_to_target(output, width, height), paths, f, names)  # predictions
+            plot_images(img, output_to_target(output, width, height), paths, f, names)  # predictions
 
     # Compute statistics
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
@@ -241,6 +248,7 @@ if __name__ == '__main__':
     parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file')
     parser.add_argument('--task', default='val', help="'val', 'test', 'study'")
     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--half', action='store_true', help='half precision FP16 inference')
     parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
     parser.add_argument('--augment', action='store_true', help='augmented inference')
     parser.add_argument('--verbose', action='store_true', help='report mAP by class')
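test.py now mirrors detect.py: a half argument (exposed as --half on the command line) converts the model once, and each uint8 batch from the dataloader is cast to FP16 or FP32 before being scaled to 0-1. The sketch below reproduces only that per-batch conversion, using a hypothetical dataloader of random uint8 images; it illustrates the pattern and is not the repository's evaluation loop.

import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
half = device.type != 'cpu'  # half precision only supported on CUDA

# Hypothetical stand-in for the test dataloader: random uint8 images, dummy targets.
images = torch.randint(0, 256, (8, 3, 64, 64), dtype=torch.uint8)
targets = torch.zeros(8)
dataloader = DataLoader(TensorDataset(images, targets), batch_size=4)

for img, _ in dataloader:
    img = img.to(device)
    img = img.half() if half else img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    # ... run the model here and keep metrics in FP32 ...
    print(img.dtype, float(img.max()))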

utils/utils.py (+4, -1)

@@ -504,6 +504,9 @@ def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, c
     Returns detections with shape:
         nx6 (x1, y1, x2, y2, conf, cls)
     """
+    if prediction.dtype is torch.float16:
+        prediction = prediction.float()  # to FP32
+
     nc = prediction[0].shape[1] - 5  # number of classes
     xc = prediction[..., 4] > conf_thres  # candidates
 
@@ -902,7 +905,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
         return None
 
     if isinstance(images, torch.Tensor):
-        images = images.cpu().numpy()
+        images = images.cpu().float().numpy()
 
     if isinstance(targets, torch.Tensor):
         targets = targets.cpu().numpy()
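Both utils.py edits cast back to FP32 at the boundary where half-precision tensors leave the model: non_max_suppression() promotes FP16 predictions before filtering, and plot_images() promotes before the NumPy conversion, presumably so box post-processing and plotting never operate on float16. A small sketch of the same cast-at-the-boundary idea, with simulated tensors standing in for real model outputs:

import torch

# Simulated FP16 model output: (batch, num_boxes, 5 + num_classes), the layout NMS expects.
prediction = torch.rand(1, 100, 85).half()

# Cast back to FP32 before post-processing, mirroring the non_max_suppression() change.
if prediction.dtype is torch.float16:
    prediction = prediction.float()  # to FP32

# Cast to FP32 before handing tensors to NumPy-based plotting, mirroring plot_images().
images = torch.rand(1, 3, 64, 64).half()
images_np = images.cpu().float().numpy()

print(prediction.dtype, images_np.dtype)  # torch.float32 float32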
