You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

151 line
6.0KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Download utils
  4. """
  5. import os
  6. import platform
  7. import subprocess
  8. import time
  9. import urllib
  10. from pathlib import Path
  11. from zipfile import ZipFile
  12. import requests
  13. import torch
  14. def gsutil_getsize(url=''):
  15. # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du
  16. s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8')
  17. return eval(s.split(' ')[0]) if len(s) else 0 # bytes
  18. def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
  19. # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
  20. file = Path(file)
  21. assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
  22. try: # url1
  23. print(f'Downloading {url} to {file}...')
  24. torch.hub.download_url_to_file(url, str(file))
  25. assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
  26. except Exception as e: # url2
  27. file.unlink(missing_ok=True) # remove partial downloads
  28. print(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
  29. os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
  30. finally:
  31. if not file.exists() or file.stat().st_size < min_bytes: # check
  32. file.unlink(missing_ok=True) # remove partial downloads
  33. print(f"ERROR: {assert_msg}\n{error_msg}")
  34. print('')
  35. def attempt_download(file, repo='ultralytics/yolov5'): # from utils.downloads import *; attempt_download()
  36. # Attempt file download if does not exist
  37. file = Path(str(file).strip().replace("'", ''))
  38. if not file.exists():
  39. # URL specified
  40. name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc.
  41. if str(file).startswith(('http:/', 'https:/')): # download
  42. url = str(file).replace(':/', '://') # Pathlib turns :// -> :/
  43. name = name.split('?')[0] # parse authentication https://url.com/file.txt?auth...
  44. safe_download(file=name, url=url, min_bytes=1E5)
  45. return name
  46. # GitHub assets
  47. file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required)
  48. try:
  49. response = requests.get(f'https://api.github.com/repos/{repo}/releases/latest').json() # github api
  50. assets = [x['name'] for x in response['assets']] # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...]
  51. tag = response['tag_name'] # i.e. 'v1.0'
  52. except: # fallback plan
  53. assets = ['yolov5n.pt', 'yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt',
  54. 'yolov5n6.pt', 'yolov5s6.pt', 'yolov5m6.pt', 'yolov5l6.pt', 'yolov5x6.pt']
  55. try:
  56. tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1]
  57. except:
  58. tag = 'v6.0' # current release
  59. if name in assets:
  60. safe_download(file,
  61. url=f'https://github.com/{repo}/releases/download/{tag}/{name}',
  62. # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}', # backup url (optional)
  63. min_bytes=1E5,
  64. error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/')
  65. return str(file)
  66. def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
  67. # Downloads a file from Google Drive. from yolov5.utils.downloads import *; gdrive_download()
  68. t = time.time()
  69. file = Path(file)
  70. cookie = Path('cookie') # gdrive cookie
  71. print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
  72. file.unlink(missing_ok=True) # remove existing file
  73. cookie.unlink(missing_ok=True) # remove existing cookie
  74. # Attempt file download
  75. out = "NUL" if platform.system() == "Windows" else "/dev/null"
  76. os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}')
  77. if os.path.exists('cookie'): # large file
  78. s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}'
  79. else: # small file
  80. s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
  81. r = os.system(s) # execute, capture return
  82. cookie.unlink(missing_ok=True) # remove existing cookie
  83. # Error check
  84. if r != 0:
  85. file.unlink(missing_ok=True) # remove partial
  86. print('Download error ') # raise Exception('Download error')
  87. return r
  88. # Unzip if archive
  89. if file.suffix == '.zip':
  90. print('unzipping... ', end='')
  91. ZipFile(file).extractall(path=file.parent) # unzip
  92. file.unlink() # remove zip
  93. print(f'Done ({time.time() - t:.1f}s)')
  94. return r
  95. def get_token(cookie="./cookie"):
  96. with open(cookie) as f:
  97. for line in f:
  98. if "download" in line:
  99. return line.split()[-1]
  100. return ""
  101. # Google utils: https://cloud.google.com/storage/docs/reference/libraries ----------------------------------------------
  102. #
  103. #
  104. # def upload_blob(bucket_name, source_file_name, destination_blob_name):
  105. # # Uploads a file to a bucket
  106. # # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python
  107. #
  108. # storage_client = storage.Client()
  109. # bucket = storage_client.get_bucket(bucket_name)
  110. # blob = bucket.blob(destination_blob_name)
  111. #
  112. # blob.upload_from_filename(source_file_name)
  113. #
  114. # print('File {} uploaded to {}.'.format(
  115. # source_file_name,
  116. # destination_blob_name))
  117. #
  118. #
  119. # def download_blob(bucket_name, source_blob_name, destination_file_name):
  120. # # Downloads a blob from a bucket
  121. # storage_client = storage.Client()
  122. # bucket = storage_client.get_bucket(bucket_name)
  123. # blob = bucket.blob(source_blob_name)
  124. #
  125. # blob.download_to_filename(destination_file_name)
  126. #
  127. # print('Blob {} downloaded to {}.'.format(
  128. # source_blob_name,
  129. # destination_file_name))