|
|
@@ -223,18 +223,18 @@ class NMS(nn.Module): |
|
|
|
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) |
|
|
|
|
|
|
|
|
|
|
|
class autoShape(nn.Module): |
|
|
|
class AutoShape(nn.Module): |
|
|
|
# input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS |
|
|
|
conf = 0.25 # NMS confidence threshold |
|
|
|
iou = 0.45 # NMS IoU threshold |
|
|
|
classes = None # (optional list) filter by class |
|
|
|
|
|
|
|
def __init__(self, model): |
|
|
|
super(autoShape, self).__init__() |
|
|
|
super(AutoShape, self).__init__() |
|
|
|
self.model = model.eval() |
|
|
|
|
|
|
|
def autoshape(self): |
|
|
|
print('autoShape already enabled, skipping... ') # model already converted to model.autoshape() |
|
|
|
print('AutoShape already enabled, skipping... ') # model already converted to model.autoshape() |
|
|
|
return self |
|
|
|
|
|
|
|
@torch.no_grad() |