@@ -128,7 +128,10 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
     hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create
     hub_model.load_state_dict(model.float().state_dict())  # load state_dict
     hub_model.names = model.names  # class names
-    return hub_model.autoshape() if autoshape else hub_model
+    if autoshape:
+        hub_model = hub_model.autoshape()  # for file/URI/PIL/cv2/np inputs and NMS
+    device = select_device('0' if torch.cuda.is_available() else 'cpu')  # default to GPU if available
+    return hub_model.to(device)
 
 
 if __name__ == '__main__':
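
For reference, a minimal usage sketch (not part of the diff) of how the updated custom() entry point could be exercised through torch.hub.load; the repository name 'ultralytics/yolov5', the checkpoint path 'weights/best.pt', and the sample image URL are assumptions for illustration only.

    # Minimal usage sketch; repo name and checkpoint path are hypothetical.
    import torch

    # torch.hub.load forwards keyword arguments to the 'custom' entry point above.
    model = torch.hub.load('ultralytics/yolov5', 'custom', path_or_model='weights/best.pt')

    # With autoshape=True (the default), the wrapped model accepts file paths,
    # URIs, PIL images, and OpenCV/numpy arrays, and applies NMS to the output.
    results = model('https://ultralytics.com/images/zidane.jpg')
    results.print()  # summarize detections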