You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

150 lines
5.9KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Download utils
  4. """
  5. import os
  6. import platform
  7. import subprocess
  8. import time
  9. import urllib
  10. from pathlib import Path
  11. import requests
  12. import torch
  13. def gsutil_getsize(url=''):
  14. # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du
  15. s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8')
  16. return eval(s.split(' ')[0]) if len(s) else 0 # bytes
  17. def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
  18. # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
  19. file = Path(file)
  20. assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
  21. try: # url1
  22. print(f'Downloading {url} to {file}...')
  23. torch.hub.download_url_to_file(url, str(file))
  24. assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
  25. except Exception as e: # url2
  26. file.unlink(missing_ok=True) # remove partial downloads
  27. print(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
  28. os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
  29. finally:
  30. if not file.exists() or file.stat().st_size < min_bytes: # check
  31. file.unlink(missing_ok=True) # remove partial downloads
  32. print(f"ERROR: {assert_msg}\n{error_msg}")
  33. print('')
  34. def attempt_download(file, repo='ultralytics/yolov5'): # from utils.downloads import *; attempt_download()
  35. # Attempt file download if does not exist
  36. file = Path(str(file).strip().replace("'", ''))
  37. if not file.exists():
  38. # URL specified
  39. name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
  40. if str(file).startswith(('http:/', 'https:/')): # download
  41. url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
  42. name = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
  43. safe_download(file=name, url=url, min_bytes=1E5)
  44. return name
  45. # GitHub assets
  46. file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
  47. try:
  48. response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
  49. assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...]
  50. tag = response['tag_name'] # i.e. 'v1.0'
  51. except: # fallback plan
  52. assets = ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt',
  53. 'yolov5s6.pt', 'yolov5m6.pt', 'yolov5l6.pt', 'yolov5x6.pt']
  54. try:
  55. tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
  56. except:
  57. tag = 'v5.0' # current release
  58. if name in assets:
  59. safe_download(file,
  60. url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
  61. # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup url (optional)
  62. min_bytes=1E5,
  63. error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
  64. return str(file)
  65. def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
  66. # Downloads a file from Google Drive. from yolov5.utils.downloads import *; gdrive_download()
  67. t = time.time()
  68. file = Path(file)
  69. cookie = Path('cookie') # gdrive cookie
  70. print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
  71. file.unlink(missing_ok=True) # remove existing file
  72. cookie.unlink(missing_ok=True) # remove existing cookie
  73. # Attempt file download
  74. out = "NUL" if platform.system() == "Windows" else "/dev/null"
  75. os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}')
  76. if os.path.exists('cookie'): # large file
  77. s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}'
  78. else: # small file
  79. s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
  80. r = os.system(s) # execute, capture return
  81. cookie.unlink(missing_ok=True) # remove existing cookie
  82. # Error check
  83. if r != 0:
  84. file.unlink(missing_ok=True) # remove partial
  85. print('Download error ') # raise Exception('Download error')
  86. return r
  87. # Unzip if archive
  88. if file.suffix == '.zip':
  89. print('unzipping... ', end='')
  90. os.system(f'unzip -q {file}') # unzip
  91. file.unlink() # remove zip to free space
  92. print(f'Done ({time.time() - t:.1f}s)')
  93. return r
  94. def get_token(cookie="./cookie"):
  95. with open(cookie) as f:
  96. for line in f:
  97. if "download" in line:
  98. return line.split()[-1]
  99. return ""
  100. # Google utils: https://cloud.google.com/storage/docs/reference/libraries ----------------------------------------------
  101. #
  102. #
  103. # def upload_blob(bucket_name, source_file_name, destination_blob_name):
  104. # # Uploads a file to a bucket
  105. # # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python
  106. #
  107. # storage_client = storage.Client()
  108. # bucket = storage_client.get_bucket(bucket_name)
  109. # blob = bucket.blob(destination_blob_name)
  110. #
  111. # blob.upload_from_filename(source_file_name)
  112. #
  113. # print('File {} uploaded to {}.'.format(
  114. # source_file_name,
  115. # destination_blob_name))
  116. #
  117. #
  118. # def download_blob(bucket_name, source_blob_name, destination_file_name):
  119. # # Downloads a blob from a bucket
  120. # storage_client = storage.Client()
  121. # bucket = storage_client.get_bucket(bucket_name)
  122. # blob = bucket.blob(source_blob_name)
  123. #
  124. # blob.download_to_filename(destination_file_name)
  125. #
  126. # print('Blob {} downloaded to {}.'.format(
  127. # source_blob_name,
  128. # destination_file_name))