Browse Source

Code refactor (#7923)

* Code refactor for general.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update restapi.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
modifyDataloader
Glenn Jocher GitHub 2 years ago
parent
commit
cee5959c74
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 44 deletions
  1. +1
    -1
      utils/flask_rest_api/restapi.py
  2. +36
    -43
      utils/general.py

+ 1
- 1
utils/flask_rest_api/restapi.py View File

@@ -17,7 +17,7 @@ DETECTION_URL = "/v1/object-detection/yolov5s"

@app.route(DETECTION_URL, methods=["POST"])
def predict():
if not request.method == "POST":
if request.method != "POST":
return

if request.files.get("image"):

+ 36
- 43
utils/general.py View File

@@ -67,17 +67,16 @@ def is_kaggle():

def is_writeable(dir, test=False):
# Return True if directory has write permissions, test opening a file with write permissions if test=True
if test: # method 1
file = Path(dir) / 'tmp.txt'
try:
with open(file, 'w'): # open file with write permissions
pass
file.unlink() # remove file
return True
except OSError:
return False
else: # method 2
if not test:
return os.access(dir, os.R_OK) # possible issues on Windows
file = Path(dir) / 'tmp.txt'
try:
with open(file, 'w'): # open file with write permissions
pass
file.unlink() # remove file
return True
except OSError:
return False


def set_logging(name=None, verbose=VERBOSE):
@@ -244,7 +243,7 @@ def is_ascii(s=''):

def is_chinese(s='人工智能'):
# Is string composed of any Chinese characters?
return True if re.search('[\u4e00-\u9fff]', str(s)) else False
return bool(re.search('[\u4e00-\u9fff]', str(s)))


def emojis(str=''):
@@ -417,7 +416,7 @@ def check_file(file, suffix=''):
# Search/download file (if necessary) and return path
check_suffix(file, suffix) # optional
file = str(file) # convert to str()
if Path(file).is_file() or file == '': # exists
if Path(file).is_file() or not file: # exists
return file
elif file.startswith(('http:/', 'https:/')): # download
url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/
@@ -481,28 +480,26 @@ def check_dataset(data, autodownload=True):
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
if not all(x.exists() for x in val):
LOGGER.info(emojis('\nDataset not found ⚠, missing paths %s' % [str(x) for x in val if not x.exists()]))
if s and autodownload: # download script
t = time.time()
root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename
LOGGER.info(f'Downloading {s} to {f}...')
torch.hub.download_url_to_file(s, f)
Path(root).mkdir(parents=True, exist_ok=True) # create root
ZipFile(f).extractall(path=root) # unzip
Path(f).unlink() # remove zip
r = None # success
elif s.startswith('bash '): # bash script
LOGGER.info(f'Running {s} ...')
r = os.system(s)
else: # python script
r = exec(s, {'yaml': data}) # return None
dt = f'({round(time.time() - t, 1)}s)'
s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
LOGGER.info(emojis(f"Dataset download {s}"))
else:
if not s or not autodownload:
raise Exception(emojis('Dataset not found ❌'))

t = time.time()
root = path.parent if 'path' in data else '..' # unzip directory i.e. '../'
if s.startswith('http') and s.endswith('.zip'): # URL
f = Path(s).name # filename
LOGGER.info(f'Downloading {s} to {f}...')
torch.hub.download_url_to_file(s, f)
Path(root).mkdir(parents=True, exist_ok=True) # create root
ZipFile(f).extractall(path=root) # unzip
Path(f).unlink() # remove zip
r = None # success
elif s.startswith('bash '): # bash script
LOGGER.info(f'Running {s} ...')
r = os.system(s)
else: # python script
r = exec(s, {'yaml': data}) # return None
dt = f'({round(time.time() - t, 1)}s)'
s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌"
LOGGER.info(emojis(f"Dataset download {s}"))
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
return data # dictionary

@@ -531,8 +528,7 @@ def check_amp(model):
def url2file(url):
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
file = Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
return file
return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth


def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
@@ -645,10 +641,9 @@ def labels_to_class_weights(labels, nc=80):

def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# Produces image weights based on class_weights and image contents
# Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
# index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
return image_weights
return (class_weights.reshape(1, nc) * class_counts).sum(1)


def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
@@ -657,11 +652,10 @@ def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
x = [
return [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
return x


def xyxy2xywh(x):
@@ -883,7 +877,7 @@ def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_op
p.requires_grad = False
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
LOGGER.info(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
@@ -946,10 +940,9 @@ def apply_classifier(x, model, img, im0):
# Classes
pred_cls1 = d[:, 5].long()
ims = []
for j, a in enumerate(d): # per item
for a in d:
cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
im = cv2.resize(cutout, (224, 224)) # BGR
# cv2.imwrite('example%i.jpg' % j, cutout)

im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32

Loading…
Cancel
Save