  1. # -*- coding: utf-8 -*-
  2. """
  3. oss2.resumable
  4. ~~~~~~~~~~~~~~
  5. This module contains the functions and classes for resumable (breakpoint) upload and download.
  6. """
  7. import os
  8. from loguru import logger
  9. from . import utils
  10. from .utils import b64encode_as_string, b64decode_from_string
  11. from . import iterators
  12. from . import exceptions
  13. from . import defaults
  14. from . import http
  15. from . import models
  16. from .crypto_bucket import CryptoBucket
  17. from . import Bucket
  18. from .iterators import PartIterator
  19. from .models import PartInfo
  20. from .compat import json, stringify, to_unicode, to_string
  21. from .task_queue import TaskQueue
  22. from .headers import *
  23. import functools
  24. import threading
  25. import random
  26. import string
  27. def resumable_upload(bucket, key, filename,
  28. store=None,
  29. headers=None,
  30. multipart_threshold=None,
  31. part_size=None,
  32. progress_callback=None,
  33. num_threads=None,
  34. params=None):
  35. """Upload a local file with resumable (breakpoint) support.
  36. The implementation uses multipart upload; the default concurrency is `oss2.defaults.multipart_num_threads`, and
  37. information about the parts already uploaded is saved on the local disk. If the upload is interrupted for any
  38. reason, re-uploading the same file (same source file and same destination key) only uploads the missing parts.
  39. By default, the function keeps the resumable-upload records under the user's `HOME` directory. When the local
  40. file to upload has not changed and the destination key is unchanged, the upload resumes from the recorded breakpoint.
  41. Note the following details when using this function:
  42. #. If a CryptoBucket is used, the function falls back to an ordinary upload.
  43. :param bucket: :class:`Bucket <oss2.Bucket>` or :class:`CryptoBucket <oss2.CryptoBucket>` object
  44. :param key: object key under which the file is stored in OSS
  45. :param filename: name of the local file to upload
  46. :param store: persistent store for the breakpoint information; see the :class:`ResumableStore` interface. Defaults to `ResumableStore`.
  47. :param headers: HTTP headers
  48. # passed in full to the external calls put_object and init_multipart_upload
  49. # only OSS_REQUEST_PAYER and OSS_TRAFFIC_LIMIT are currently passed to the external call upload_part
  50. # only OSS_REQUEST_PAYER and OSS_OBJECT_ACL are currently passed to the external call complete_multipart_upload
  51. :type headers: dict, preferably oss2.CaseInsensitiveDict
  52. :param multipart_threshold: files longer than this value are uploaded with multipart upload.
  53. :param part_size: size of each part of the multipart upload. If not specified, it is calculated automatically.
  54. :param progress_callback: upload progress callback; see :ref:`progress_callback`.
  55. :param num_threads: number of concurrent upload threads; defaults to `oss2.defaults.multipart_num_threads`.
  56. :param params: HTTP request parameters
  57. # only the 'sequential' parameter is passed through to the external call init_multipart_upload;
  58. # other parameters are treated as invalid and are not passed on.
  59. :type params: dict
  60. """
  61. logger.debug("Start to resumable upload, bucket: {0}, key: {1}, filename: {2}, headers: {3}, "
  62. "multipart_threshold: {4}, part_size: {5}, num_threads: {6}".format(bucket.bucket_name, to_string(key),
  63. filename, headers,
  64. multipart_threshold,
  65. part_size, num_threads))
  66. size = os.path.getsize(filename)
  67. multipart_threshold = defaults.get(multipart_threshold, defaults.multipart_threshold)
  68. logger.debug("The size of file to upload is: {0}, multipart_threshold: {1}".format(size, multipart_threshold))
  69. if size >= multipart_threshold:
  70. uploader = _ResumableUploader(bucket, key, filename, size, store,
  71. part_size=part_size,
  72. headers=headers,
  73. progress_callback=progress_callback,
  74. num_threads=num_threads,
  75. params=params)
  76. result = uploader.upload()
  77. else:
  78. with open(to_unicode(filename), 'rb') as f:
  79. result = bucket.put_object(key, f, headers=headers, progress_callback=progress_callback)
  80. return result
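# A minimal usage sketch of resumable_upload (not part of this module; the endpoint,
# credentials, bucket name and paths below are placeholders):
#
#   import oss2
#   auth = oss2.Auth('<access-key-id>', '<access-key-secret>')
#   bucket = oss2.Bucket(auth, 'https://oss-cn-hangzhou.aliyuncs.com', '<bucket-name>')
#   oss2.resumable_upload(bucket, 'remote/key.bin', '/path/to/local.bin',
#                         multipart_threshold=100 * 1024,  # multipart above 100 KB
#                         part_size=100 * 1024,
#                         num_threads=4)
#
# Files smaller than multipart_threshold fall through to the single put_object call above.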
  81. def resumable_upload1(bucket, key, filename, store=None, headers=None, multipart_threshold=None, part_size=None,
  82. progress_callback=None, num_threads=None, params=None, useData=None):
  83. size = os.path.getsize(filename)
  84. multipart_threshold = defaults.get(multipart_threshold, defaults.multipart_threshold)
  85. if size >= multipart_threshold:
  86. logger.info("Uploading file {}, size: {}, using the resumable (multipart) upload path", filename, size)
  87. uploader = _ResumableUploader1(bucket, key, filename, size, store,
  88. part_size=part_size,
  89. headers=headers,
  90. progress_callback=progress_callback,
  91. num_threads=num_threads,
  92. params=params,
  93. useData=useData)
  94. result = uploader.upload()
  95. else:
  96. logger.info("Uploading file {}, size: {}, using the single-request upload path", filename, size)
  97. with open(to_unicode(filename), 'rb') as f:
  98. result = bucket.put_object_1(key, f, headers=headers, progress_callback=progress_callback, useData=useData)
  99. return result
  100. def resumable_download(bucket, key, filename,
  101. multiget_threshold=None,
  102. part_size=None,
  103. progress_callback=None,
  104. num_threads=None,
  105. store=None,
  106. params=None,
  107. headers=None):
  108. """Resumable (breakpoint) download.
  109. The implementation works as follows:
  110. #. Create a local temporary file whose name is the original file name plus a random suffix;
  111. #. Read ranges of the OSS object concurrently via the `Range` request header and write them to the corresponding positions in the temporary file;
  112. #. Once everything is done, rename the temporary file to the target file (i.e. `filename`).
  113. During this process the breakpoint information, i.e. the ranges already completed, is saved on disk. If the download is interrupted
  114. for some reason and the same file is downloaded again later (same source and target file), the breakpoint information is read first and only the missing parts are downloaded.
  115. By default, the breakpoint information is saved in a subdirectory of the `HOME` directory. The location can be changed with the `store` parameter.
  116. Note the following details when using this function:
  117. #. For the same source and target file, avoid calling this function from multiple programs (or threads) at the same time, because the breakpoint records would overwrite each other on disk and the temporary file names could collide.
  118. #. Avoid ranges (parts) that are too small: `part_size` should not be too small, preferably greater than or equal to `oss2.defaults.multiget_part_size`.
  119. #. If the target file already exists, this function overwrites it.
  120. #. If a CryptoBucket is used, the function falls back to an ordinary download.
  121. :param bucket: :class:`Bucket <oss2.Bucket>` or :class:`CryptoBucket <oss2.CryptoBucket>` object
  122. :param str key: name of the remote object to download.
  123. :param str filename: name of the local target file.
  124. :param int multiget_threshold: objects longer than this value are downloaded with resumable download.
  125. :param int part_size: desired part size, i.e. the number of bytes fetched per request; the actual part size may differ.
  126. :param progress_callback: download progress callback; see :ref:`progress_callback`.
  127. :param num_threads: number of concurrent download threads; defaults to `oss2.defaults.multiget_num_threads`.
  128. :param store: persistent store for the breakpoint information; can be used to choose the directory where the records are kept.
  129. :type store: `ResumableDownloadStore`
  130. :param dict params: download parameters; a versionId can be passed to download a specific version of the object.
  131. :param headers: HTTP headers
  132. # only OSS_REQUEST_PAYER is currently passed to the external call head_object
  133. # OSS_REQUEST_PAYER and OSS_TRAFFIC_LIMIT are currently passed down to the external calls get_object_to_file and get_object
  134. :type headers: dict, preferably oss2.CaseInsensitiveDict
  135. :raises: :class:`NotFound <oss2.exceptions.NotFound>` if the OSS object does not exist; other exceptions raised while downloading are also possible.
  136. """
  137. logger.debug("Start to resumable download, bucket: {0}, key: {1}, filename: {2}, multiget_threshold: {3}, "
  138. "part_size: {4}, num_threads: {5}".format(bucket.bucket_name, to_string(key), filename,
  139. multiget_threshold, part_size, num_threads))
  140. multiget_threshold = defaults.get(multiget_threshold, defaults.multiget_threshold)
  141. valid_headers = _populate_valid_headers(headers, [OSS_REQUEST_PAYER, OSS_TRAFFIC_LIMIT])
  142. result = bucket.head_object(key, params=params, headers=valid_headers)
  143. logger.debug("The size of object to download is: {0}, multiget_threshold: {1}".format(result.content_length,
  144. multiget_threshold))
  145. if result.content_length >= multiget_threshold:
  146. downloader = _ResumableDownloader(bucket, key, filename, _ObjectInfo.make(result), part_size=part_size,
  147. progress_callback=progress_callback, num_threads=num_threads, store=store,
  148. params=params, headers=valid_headers)
  149. downloader.download(result.server_crc)
  150. else:
  151. bucket.get_object_to_file(key, filename, progress_callback=progress_callback, params=params,
  152. headers=valid_headers)
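# A minimal usage sketch of resumable_download (placeholders as in the upload sketch;
# params={'versionId': ...} is optional and only meaningful on versioned buckets):
#
#   oss2.resumable_download(bucket, 'remote/key.bin', '/path/to/local.bin',
#                           multiget_threshold=100 * 1024,
#                           part_size=100 * 1024,
#                           num_threads=4)
#
# The object is HEADed first; objects shorter than multiget_threshold are fetched with a
# single get_object_to_file call instead of ranged requests.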
  153. _MAX_MULTIGET_PART_COUNT = 100000
  154. def determine_part_size(total_size,
  155. preferred_size=None):
  156. """Determine the part size to use for a multipart upload.
  157. :param int total_size: total number of bytes to upload
  158. :param int preferred_size: the part size the caller would prefer. If not specified, defaults.part_size is used.
  159. :return: the part size
  160. """
  161. if not preferred_size:
  162. preferred_size = defaults.part_size
  163. return _determine_part_size_internal(total_size, preferred_size, defaults.max_part_count)
  164. def _determine_part_size_internal(total_size, preferred_size, max_count):
  165. if total_size < preferred_size:
  166. return total_size
  167. while preferred_size * max_count < total_size or preferred_size < defaults.min_part_size:
  168. preferred_size = preferred_size * 2
  169. return preferred_size
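# Worked example of the part-size growth loop above (values are illustrative and passed
# explicitly, so the result does not depend on the current defaults):
#
#   _determine_part_size_internal(total_size=50 * 1024 ** 3,   # 50 GB
#                                 preferred_size=1024 * 1024,  # 1 MB preferred
#                                 max_count=10000)
#   # 1 MB, 2 MB and 4 MB would all need more than 10000 parts, so the size doubles
#   # until 8 MB * 10000 >= 50 GB; the function returns 8388608 (8 MB).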
  170. def _split_to_parts(total_size, part_size):
  171. parts = []
  172. num_parts = utils.how_many(total_size, part_size)
  173. for i in range(num_parts):
  174. if i == num_parts - 1:
  175. start = i * part_size
  176. end = total_size
  177. else:
  178. start = i * part_size
  179. end = part_size + start
  180. parts.append(_PartToProcess(i + 1, start, end))
  181. return parts
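# Example of the split above: a 250-byte object with part_size=100 yields three
# _PartToProcess entries, the last one absorbing the remainder:
#
#   _split_to_parts(250, 100)
#   # part 1: [0, 100), part 2: [100, 200), part 3: [200, 250)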
  182. def _populate_valid_headers(headers=None, valid_keys=None):
  183. """Build an HTTP header dict that only contains the whitelisted keys.
  184. :param headers: headers to filter
  185. :type headers: dict, preferably oss2.CaseInsensitiveDict
  186. :param valid_keys: list of keys to keep
  187. :type valid_keys: list
  188. :return: HTTP headers containing only the whitelisted keys, type: oss2.CaseInsensitiveDict
  189. """
  190. if headers is None or valid_keys is None:
  191. return None
  192. headers = http.CaseInsensitiveDict(headers)
  193. valid_headers = http.CaseInsensitiveDict()
  194. for key in valid_keys:
  195. if headers.get(key) is not None:
  196. valid_headers[key] = headers[key]
  197. if len(valid_headers) == 0:
  198. valid_headers = None
  199. return valid_headers
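# Example of the whitelist filtering above (assuming the usual header names behind the
# constants, e.g. OSS_REQUEST_PAYER == 'x-oss-request-payer'):
#
#   _populate_valid_headers({'x-oss-request-payer': 'requester', 'x-oss-meta-foo': 'bar'},
#                           [OSS_REQUEST_PAYER, OSS_TRAFFIC_LIMIT])
#   # -> CaseInsensitiveDict with only 'x-oss-request-payer'; an empty result becomes None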
  200. def _filter_invalid_headers(headers=None, invalid_keys=None):
  201. """Filter blacklisted keys out of the HTTP headers.
  202. :param headers: headers to filter
  203. :type headers: dict, preferably oss2.CaseInsensitiveDict
  204. :param invalid_keys: list of keys to drop
  205. :type invalid_keys: list
  206. :return: HTTP headers with the blacklisted keys removed, type: oss2.CaseInsensitiveDict
  207. """
  208. if headers is None or invalid_keys is None:
  209. return None
  210. headers = http.CaseInsensitiveDict(headers)
  211. valid_headers = headers.copy()
  212. for key in invalid_keys:
  213. if valid_headers.get(key) is not None:
  214. valid_headers.pop(key)
  215. if len(valid_headers) == 0:
  216. valid_headers = None
  217. return valid_headers
  218. def _populate_valid_params(params=None, valid_keys=None):
  219. """Build a params dict that only contains the whitelisted keys.
  220. :param params: params to filter
  221. :type params: dict
  222. :param valid_keys: list of keys to keep
  223. :type valid_keys: list
  224. :return: params containing only the whitelisted keys
  225. """
  226. if params is None or valid_keys is None:
  227. return None
  228. valid_params = dict()
  229. for key in valid_keys:
  230. if params.get(key) is not None:
  231. valid_params[key] = params[key]
  232. if len(valid_params) == 0:
  233. valid_params = None
  234. return valid_params
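# Example of the params whitelist above (assuming Bucket.SEQUENTIAL == 'sequential'):
#
#   _populate_valid_params({'sequential': '', 'foo': 'bar'}, [Bucket.SEQUENTIAL])
#   # -> {'sequential': ''}; with no surviving keys the function returns None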
  235. class _ResumableOperation(object):
  236. def __init__(self, bucket, key, filename, size, store,
  237. progress_callback=None, versionid=None):
  238. self.bucket = bucket
  239. self.key = to_string(key)
  240. self.filename = filename
  241. self.size = size
  242. self._abspath = os.path.abspath(filename)
  243. self.__store = store
  244. if versionid is None:
  245. self.__record_key = self.__store.make_store_key(bucket.bucket_name, self.key, self._abspath)
  246. else:
  247. self.__record_key = self.__store.make_store_key(bucket.bucket_name, self.key, self._abspath, versionid)
  248. logger.debug("Init _ResumableOperation, record_key: {0}".format(self.__record_key))
  249. # protect self.__progress_callback
  250. self.__plock = threading.Lock()
  251. self.__progress_callback = progress_callback
  252. def _del_record(self):
  253. self.__store.delete(self.__record_key)
  254. def _put_record(self, record):
  255. self.__store.put(self.__record_key, record)
  256. def _get_record(self):
  257. return self.__store.get(self.__record_key)
  258. def _report_progress(self, consumed_size):
  259. if self.__progress_callback:
  260. with self.__plock:
  261. self.__progress_callback(consumed_size, self.size)
  262. class _ResumableOperation1(object):
  263. def __init__(self, bucket, key, filename, size, store, progress_callback=None, versionid=None, useData=None):
  264. self.useData = useData
  265. self.bucket = bucket
  266. self.key = to_string(key)
  267. self.filename = filename
  268. self.size = size
  269. self._abspath = os.path.abspath(filename)
  270. self.__store = store
  271. if versionid is None:
  272. self.__record_key = self.__store.make_store_key(bucket.bucket_name, self.key, self._abspath)
  273. else:
  274. self.__record_key = self.__store.make_store_key(bucket.bucket_name, self.key, self._abspath, versionid)
  275. logger.debug("Init _ResumableOperation1, record_key: {0}".format(self.__record_key))
  276. # protect self.__progress_callback
  277. self.__plock = threading.Lock()
  278. self.__progress_callback = progress_callback
  279. def _del_record(self):
  280. self.__store.delete(self.__record_key)
  281. def _put_record(self, record):
  282. self.__store.put(self.__record_key, record)
  283. def _get_record(self):
  284. return self.__store.get(self.__record_key)
  285. def _report_progress(self, consumed_size):
  286. if self.__progress_callback:
  287. with self.__plock:
  288. self.__progress_callback(self.useData, consumed_size, self.size)
  289. class _ObjectInfo(object):
  290. def __init__(self):
  291. self.size = None
  292. self.etag = None
  293. self.mtime = None
  294. @staticmethod
  295. def make(head_object_result):
  296. objectInfo = _ObjectInfo()
  297. objectInfo.size = head_object_result.content_length
  298. objectInfo.etag = head_object_result.etag
  299. objectInfo.mtime = head_object_result.last_modified
  300. return objectInfo
  301. class _ResumableDownloader(_ResumableOperation):
  302. def __init__(self, bucket, key, filename, objectInfo,
  303. part_size=None,
  304. store=None,
  305. progress_callback=None,
  306. num_threads=None,
  307. params=None,
  308. headers=None):
  309. versionid = None
  310. if params is not None and params.get('versionId') is not None:
  311. versionid = params.get('versionId')
  312. super(_ResumableDownloader, self).__init__(bucket, key, filename, objectInfo.size,
  313. store or ResumableDownloadStore(),
  314. progress_callback=progress_callback,
  315. versionid=versionid)
  316. self.objectInfo = objectInfo
  317. self.__op = 'ResumableDownload'
  318. self.__part_size = defaults.get(part_size, defaults.multiget_part_size)
  319. self.__part_size = _determine_part_size_internal(self.size, self.__part_size, _MAX_MULTIGET_PART_COUNT)
  320. self.__tmp_file = None
  321. self.__num_threads = defaults.get(num_threads, defaults.multiget_num_threads)
  322. self.__finished_parts = None
  323. self.__finished_size = None
  324. self.__params = params
  325. self.__headers = headers
  326. # protect record
  327. self.__lock = threading.Lock()
  328. self.__record = None
  329. logger.debug("Init _ResumableDownloader, bucket: {0}, key: {1}, part_size: {2}, num_thread: {3}".format(
  330. bucket.bucket_name, to_string(key), self.__part_size, self.__num_threads))
  331. def download(self, server_crc=None):
  332. self.__load_record()
  333. parts_to_download = self.__get_parts_to_download()
  334. logger.debug("Parts need to download: {0}".format(parts_to_download))
  335. # create the tmp file if it does not exist
  336. open(self.__tmp_file, 'a').close()
  337. q = TaskQueue(functools.partial(self.__producer, parts_to_download=parts_to_download),
  338. [self.__consumer] * self.__num_threads)
  339. q.run()
  340. if self.bucket.enable_crc:
  341. parts = sorted(self.__finished_parts, key=lambda p: p.part_number)
  342. object_crc = utils.calc_obj_crc_from_parts(parts)
  343. utils.check_crc('resume download', object_crc, server_crc, None)
  344. utils.force_rename(self.__tmp_file, self.filename)
  345. self._report_progress(self.size)
  346. self._del_record()
  347. def __producer(self, q, parts_to_download=None):
  348. for part in parts_to_download:
  349. q.put(part)
  350. def __consumer(self, q):
  351. while q.ok():
  352. part = q.get()
  353. if part is None:
  354. break
  355. self.__download_part(part)
  356. def __download_part(self, part):
  357. self._report_progress(self.__finished_size)
  358. with open(self.__tmp_file, 'rb+') as f:
  359. f.seek(part.start, os.SEEK_SET)
  360. headers = _populate_valid_headers(self.__headers, [OSS_REQUEST_PAYER, OSS_TRAFFIC_LIMIT])
  361. if headers is None:
  362. headers = http.CaseInsensitiveDict()
  363. headers[IF_MATCH] = self.objectInfo.etag
  364. headers[IF_UNMODIFIED_SINCE] = utils.http_date(self.objectInfo.mtime)
  365. result = self.bucket.get_object(self.key, byte_range=(part.start, part.end - 1), headers=headers,
  366. params=self.__params)
  367. utils.copyfileobj_and_verify(result, f, part.end - part.start, request_id=result.request_id)
  368. part.part_crc = result.client_crc
  369. logger.debug("down part success, add part info to record, part_number: {0}, start: {1}, end: {2}".format(
  370. part.part_number, part.start, part.end))
  371. self.__finish_part(part)
  372. def __load_record(self):
  373. record = self._get_record()
  374. logger.debug("Load record return {0}".format(record))
  375. if record and not self.__is_record_sane(record):
  376. logger.warning("The content of record is invalid, delete the record")
  377. self._del_record()
  378. record = None
  379. if record and not os.path.exists(self.filename + record['tmp_suffix']):
  380. logger.warning("Temp file: {0} does not exist, delete the record".format(
  381. self.filename + record['tmp_suffix']))
  382. self._del_record()
  383. record = None
  384. if record and self.__is_remote_changed(record):
  385. logger.warning("Object: {0} has been overwritten,delete the record and tmp file".format(self.key))
  386. utils.silently_remove(self.filename + record['tmp_suffix'])
  387. self._del_record()
  388. record = None
  389. if not record:
  390. record = {'op_type': self.__op, 'bucket': self.bucket.bucket_name, 'key': self.key,
  391. 'size': self.objectInfo.size, 'mtime': self.objectInfo.mtime, 'etag': self.objectInfo.etag,
  392. 'part_size': self.__part_size, 'file_path': self._abspath, 'tmp_suffix': self.__gen_tmp_suffix(),
  393. 'parts': []}
  394. logger.debug('Add new record, bucket: {0}, key: {1}, part_size: {2}'.format(
  395. self.bucket.bucket_name, self.key, self.__part_size))
  396. self._put_record(record)
  397. self.__tmp_file = self.filename + record['tmp_suffix']
  398. self.__part_size = record['part_size']
  399. self.__finished_parts = list(
  400. _PartToProcess(p['part_number'], p['start'], p['end'], p['part_crc']) for p in record['parts'])
  401. self.__finished_size = sum(p.size for p in self.__finished_parts)
  402. self.__record = record
  403. def __get_parts_to_download(self):
  404. assert self.__record
  405. all_set = set(_split_to_parts(self.size, self.__part_size))
  406. finished_set = set(self.__finished_parts)
  407. return sorted(list(all_set - finished_set), key=lambda p: p.part_number)
  408. def __is_record_sane(self, record):
  409. try:
  410. if record['op_type'] != self.__op:
  411. logger.error('op_type invalid, op_type in record:{0} is invalid'.format(record['op_type']))
  412. return False
  413. for key in ('etag', 'tmp_suffix', 'file_path', 'bucket', 'key'):
  414. if not isinstance(record[key], str):
  415. logger.error('{0} is not a string: {1}'.format(key, record[key]))
  416. return False
  417. for key in ('part_size', 'size', 'mtime'):
  418. if not isinstance(record[key], int):
  419. logger.error('{0} is not an integer: {1}'.format(key, record[key]))
  420. return False
  421. if not isinstance(record['parts'], list):
  422. logger.error('parts is not a list: {0}'.format(record['parts']))
  423. return False
  424. except KeyError as e:
  425. logger.error('Key not found: {0}'.format(e.args))
  426. return False
  427. return True
  428. def __is_remote_changed(self, record):
  429. return (record['mtime'] != self.objectInfo.mtime or
  430. record['size'] != self.objectInfo.size or
  431. record['etag'] != self.objectInfo.etag)
  432. def __finish_part(self, part):
  433. with self.__lock:
  434. self.__finished_parts.append(part)
  435. self.__finished_size += part.size
  436. self.__record['parts'].append({'part_number': part.part_number,
  437. 'start': part.start,
  438. 'end': part.end,
  439. 'part_crc': part.part_crc})
  440. self._put_record(self.__record)
  441. def __gen_tmp_suffix(self):
  442. return '.tmp-' + ''.join(random.choice(string.ascii_lowercase) for i in range(12))
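# Shape of the on-disk download record maintained by __load_record/__finish_part above
# (field values are illustrative):
#
#   {
#       "op_type": "ResumableDownload",
#       "bucket": "<bucket-name>", "key": "remote/key.bin",
#       "size": 10485760, "mtime": 1700000000, "etag": "<etag of the remote object>",
#       "part_size": 524288,
#       "file_path": "/abs/path/to/local.bin",
#       "tmp_suffix": ".tmp-abcdefghijkl",
#       "parts": [{"part_number": 1, "start": 0, "end": 524288, "part_crc": 123456}]
#   }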
  443. class _ResumableUploader(_ResumableOperation):
  444. """Upload a file using resumable (multipart) upload.
  445. :param bucket: :class:`Bucket <oss2.Bucket>` object
  446. :param key: object key
  447. :param filename: name of the local file to upload
  448. :param size: total length of the file
  449. :param store: persistent store used to save the progress
  450. :param headers: HTTP headers passed to `init_multipart_upload`
  451. :param part_size: part size. A user-supplied value takes precedence; otherwise a reasonable value is calculated for a new
  452. upload, while a resumed upload reuses the size of its first part.
  453. :param progress_callback: upload progress callback; see :ref:`progress_callback`.
  454. """
  455. def __init__(self, bucket, key, filename, size,
  456. store=None,
  457. headers=None,
  458. part_size=None,
  459. progress_callback=None,
  460. num_threads=None,
  461. params=None):
  462. super(_ResumableUploader, self).__init__(bucket, key, filename, size,
  463. store or ResumableStore(),
  464. progress_callback=progress_callback)
  465. self.__op = 'ResumableUpload'
  466. self.__headers = headers
  467. self.__part_size = defaults.get(part_size, defaults.part_size)
  468. self.__mtime = os.path.getmtime(filename)
  469. self.__num_threads = defaults.get(num_threads, defaults.multipart_num_threads)
  470. self.__upload_id = None
  471. self.__params = params
  472. # protect below fields
  473. self.__lock = threading.Lock()
  474. self.__record = None
  475. self.__finished_size = 0
  476. self.__finished_parts = None
  477. self.__encryption = False
  478. self.__record_upload_context = False
  479. self.__upload_context = None
  480. if isinstance(self.bucket, CryptoBucket):
  481. self.__encryption = True
  482. self.__record_upload_context = True
  483. logger.debug("Init _ResumableUploader, bucket: {0}, key: {1}, part_size: {2}, num_thread: {3}".format(
  484. bucket.bucket_name, to_string(key), self.__part_size, self.__num_threads))
  485. def upload(self):
  486. self.__load_record()
  487. parts_to_upload = self.__get_parts_to_upload(self.__finished_parts)
  488. parts_to_upload = sorted(parts_to_upload, key=lambda p: p.part_number)
  489. logger.debug("Parts need to upload: {0}".format(parts_to_upload))
  490. q = TaskQueue(functools.partial(self.__producer, parts_to_upload=parts_to_upload),
  491. [self.__consumer] * self.__num_threads)
  492. q.run()
  493. self._report_progress(self.size)
  494. headers = _populate_valid_headers(self.__headers, [OSS_REQUEST_PAYER, OSS_OBJECT_ACL])
  495. result = self.bucket.complete_multipart_upload(self.key, self.__upload_id, self.__finished_parts,
  496. headers=headers)
  497. self._del_record()
  498. return result
  499. def __producer(self, q, parts_to_upload=None):
  500. for part in parts_to_upload:
  501. q.put(part)
  502. def __consumer(self, q):
  503. while True:
  504. part = q.get()
  505. if part is None:
  506. break
  507. self.__upload_part(part)
  508. def __upload_part(self, part):
  509. with open(to_unicode(self.filename), 'rb') as f:
  510. self._report_progress(self.__finished_size)
  511. f.seek(part.start, os.SEEK_SET)
  512. headers = _populate_valid_headers(self.__headers, [OSS_REQUEST_PAYER, OSS_TRAFFIC_LIMIT])
  513. if self.__encryption:
  514. result = self.bucket.upload_part(self.key, self.__upload_id, part.part_number,
  515. utils.SizedFileAdapter(f, part.size), headers=headers,
  516. upload_context=self.__upload_context)
  517. else:
  518. result = self.bucket.upload_part(self.key, self.__upload_id, part.part_number,
  519. utils.SizedFileAdapter(f, part.size), headers=headers)
  520. logger.debug("Upload part success, add part info to record, part_number: {0}, etag: {1}, size: {2}".format(
  521. part.part_number, result.etag, part.size))
  522. self.__finish_part(PartInfo(part.part_number, result.etag, size=part.size, part_crc=result.crc))
  523. def __finish_part(self, part_info):
  524. with self.__lock:
  525. self.__finished_parts.append(part_info)
  526. self.__finished_size += part_info.size
  527. def __load_record(self):
  528. record = self._get_record()
  529. logger.debug("Load record return {0}".format(record))
  530. if record and not self.__is_record_sane(record):
  531. logger.warning("The content of record is invalid, delete the record")
  532. self._del_record()
  533. record = None
  534. if record and self.__file_changed(record):
  535. logger.warning("File: {0} has been changed, delete the record".format(self.filename))
  536. self._del_record()
  537. record = None
  538. if record and not self.__upload_exists(record['upload_id']):
  539. logger.warning('Multipart upload: {0} does not exist, delete the record'.format(record['upload_id']))
  540. self._del_record()
  541. record = None
  542. if not record:
  543. params = _populate_valid_params(self.__params, [Bucket.SEQUENTIAL])
  544. part_size = determine_part_size(self.size, self.__part_size)
  545. logger.debug("Upload File size: {0}, User-specify part_size: {1}, Calculated part_size: {2}".format(
  546. self.size, self.__part_size, part_size))
  547. if self.__encryption:
  548. upload_context = models.MultipartUploadCryptoContext(self.size, part_size)
  549. upload_id = self.bucket.init_multipart_upload(self.key, self.__headers, params,
  550. upload_context).upload_id
  551. if self.__record_upload_context:
  552. material = upload_context.content_crypto_material
  553. material_record = {'wrap_alg': material.wrap_alg, 'cek_alg': material.cek_alg,
  554. 'encrypted_key': b64encode_as_string(material.encrypted_key),
  555. 'encrypted_iv': b64encode_as_string(material.encrypted_iv),
  556. 'mat_desc': material.mat_desc}
  557. else:
  558. upload_id = self.bucket.init_multipart_upload(self.key, self.__headers, params).upload_id
  559. record = {'op_type': self.__op, 'upload_id': upload_id, 'file_path': self._abspath, 'size': self.size,
  560. 'mtime': self.__mtime, 'bucket': self.bucket.bucket_name, 'key': self.key, 'part_size': part_size}
  561. if self.__record_upload_context:
  562. record['content_crypto_material'] = material_record
  563. logger.debug('Add new record, bucket: {0}, key: {1}, upload_id: {2}, part_size: {3}'.format(
  564. self.bucket.bucket_name, self.key, upload_id, part_size))
  565. self._put_record(record)
  566. self.__record = record
  567. self.__part_size = self.__record['part_size']
  568. self.__upload_id = self.__record['upload_id']
  569. if self.__record_upload_context:
  570. if 'content_crypto_material' in self.__record:
  571. material_record = self.__record['content_crypto_material']
  572. wrap_alg = material_record['wrap_alg']
  573. cek_alg = material_record['cek_alg']
  574. if cek_alg != self.bucket.crypto_provider.cipher.alg or wrap_alg != self.bucket.crypto_provider.wrap_alg:
  575. err_msg = 'Envelope or data encryption/decryption algorithm is inconsistent'
  576. raise exceptions.InconsistentError(err_msg, self)
  577. content_crypto_material = models.ContentCryptoMaterial(self.bucket.crypto_provider.cipher,
  578. material_record['wrap_alg'],
  579. b64decode_from_string(
  580. material_record['encrypted_key']),
  581. b64decode_from_string(
  582. material_record['encrypted_iv']),
  583. material_record['mat_desc'])
  584. self.__upload_context = models.MultipartUploadCryptoContext(self.size, self.__part_size,
  585. content_crypto_material)
  586. else:
  587. err_msg = 'If the record_upload_context flag is true, content_crypto_material must be in the record'
  588. raise exceptions.InconsistentError(err_msg, self)
  589. else:
  590. if 'content_crypto_material' in self.__record:
  591. err_msg = 'content_crypto_material is present in the record, but the record_upload_context flag is false'
  592. raise exceptions.InvalidEncryptionRequest(err_msg, self)
  593. self.__finished_parts = self.__get_finished_parts()
  594. self.__finished_size = sum(p.size for p in self.__finished_parts)
  595. def __get_finished_parts(self):
  596. parts = []
  597. valid_headers = _filter_invalid_headers(self.__headers,
  598. [OSS_SERVER_SIDE_ENCRYPTION, OSS_SERVER_SIDE_DATA_ENCRYPTION])
  599. for part in PartIterator(self.bucket, self.key, self.__upload_id, headers=valid_headers):
  600. parts.append(part)
  601. return parts
  602. def __upload_exists(self, upload_id):
  603. try:
  604. valid_headers = _filter_invalid_headers(self.__headers,
  605. [OSS_SERVER_SIDE_ENCRYPTION, OSS_SERVER_SIDE_DATA_ENCRYPTION])
  606. list(iterators.PartIterator(self.bucket, self.key, upload_id, '0', max_parts=1, headers=valid_headers))
  607. except exceptions.NoSuchUpload:
  608. return False
  609. else:
  610. return True
  611. def __file_changed(self, record):
  612. return record['mtime'] != self.__mtime or record['size'] != self.size
  613. def __get_parts_to_upload(self, parts_uploaded):
  614. all_parts = _split_to_parts(self.size, self.__part_size)
  615. if not parts_uploaded:
  616. return all_parts
  617. all_parts_map = dict((p.part_number, p) for p in all_parts)
  618. for uploaded in parts_uploaded:
  619. if uploaded.part_number in all_parts_map:
  620. del all_parts_map[uploaded.part_number]
  621. return all_parts_map.values()
  622. def __is_record_sane(self, record):
  623. try:
  624. if record['op_type'] != self.__op:
  625. logger.error('op_type invalid, op_type in record:{0} is invalid'.format(record['op_type']))
  626. return False
  627. for key in ('upload_id', 'file_path', 'bucket', 'key'):
  628. if not isinstance(record[key], str):
  629. logger.error('Type Error, {0} in record is not a string type: {1}'.format(key, record[key]))
  630. return False
  631. for key in ('size', 'part_size'):
  632. if not isinstance(record[key], int):
  633. logger.error('Type Error, {0} in record is not an integer type: {1}'.format(key, record[key]))
  634. return False
  635. if not isinstance(record['mtime'], int) and not isinstance(record['mtime'], float):
  636. logger.error(
  637. 'Type Error, mtime in record is not a float or an integer type: {0}'.format(record['mtime']))
  638. return False
  639. except KeyError as e:
  640. logger.error('Key not found: {0}'.format(e.args))
  641. return False
  642. return True
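# Shape of the on-disk upload record maintained by __load_record above (illustrative;
# 'content_crypto_material' is added only for CryptoBucket uploads):
#
#   {
#       "op_type": "ResumableUpload",
#       "upload_id": "<id returned by init_multipart_upload>",
#       "file_path": "/abs/path/to/local.bin",
#       "size": 10485760, "mtime": 1700000000.0,
#       "bucket": "<bucket-name>", "key": "remote/key.bin",
#       "part_size": 1048576
#   }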
  643. class _ResumableUploader1(_ResumableOperation1):
  644. def __init__(self, bucket, key, filename, size,
  645. store=None,
  646. headers=None,
  647. part_size=None,
  648. progress_callback=None,
  649. num_threads=None,
  650. params=None,
  651. useData=None):
  652. super(_ResumableUploader1, self).__init__(bucket, key, filename, size,
  653. store or ResumableStore(),
  654. progress_callback=progress_callback,
  655. useData=useData)
  656. self.__op = 'ResumableUpload'
  657. self.__headers = headers
  658. self.__part_size = defaults.get(part_size, defaults.part_size)
  659. self.__mtime = os.path.getmtime(filename)
  660. self.__num_threads = defaults.get(num_threads, defaults.multipart_num_threads)
  661. self.__upload_id = None
  662. self.__params = params
  663. # protect below fields
  664. self.__lock = threading.Lock()
  665. self.__record = None
  666. self.__finished_size = 0
  667. self.__finished_parts = None
  668. self.__encryption = False
  669. self.__record_upload_context = False
  670. self.__upload_context = None
  671. if isinstance(self.bucket, CryptoBucket):
  672. self.__encryption = True
  673. self.__record_upload_context = True
  674. logger.debug("Init _ResumableUploader1, bucket: {0}, key: {1}, part_size: {2}, num_thread: {3}".format(
  675. bucket.bucket_name, to_string(key), self.__part_size, self.__num_threads))
  676. def upload(self):
  677. self.__load_record()
  678. parts_to_upload = self.__get_parts_to_upload(self.__finished_parts)
  679. parts_to_upload = sorted(parts_to_upload, key=lambda p: p.part_number)
  680. logger.debug("Parts need to upload: {0}".format(parts_to_upload))
  681. q = TaskQueue(functools.partial(self.__producer, parts_to_upload=parts_to_upload),
  682. [self.__consumer] * self.__num_threads)
  683. q.run()
  684. self._report_progress(self.size)
  685. headers = _populate_valid_headers(self.__headers, [OSS_REQUEST_PAYER, OSS_OBJECT_ACL])
  686. result = self.bucket.complete_multipart_upload(self.key, self.__upload_id, self.__finished_parts,
  687. headers=headers)
  688. self._del_record()
  689. return result
  690. def __producer(self, q, parts_to_upload=None):
  691. for part in parts_to_upload:
  692. q.put(part)
  693. def __consumer(self, q):
  694. while True:
  695. part = q.get()
  696. if part is None:
  697. break
  698. self.__upload_part(part)
  699. def __upload_part(self, part):
  700. with open(to_unicode(self.filename), 'rb') as f:
  701. self._report_progress(self.__finished_size)
  702. f.seek(part.start, os.SEEK_SET)
  703. headers = _populate_valid_headers(self.__headers, [OSS_REQUEST_PAYER, OSS_TRAFFIC_LIMIT])
  704. if self.__encryption:
  705. result = self.bucket.upload_part(self.key, self.__upload_id, part.part_number,
  706. utils.SizedFileAdapter(f, part.size), headers=headers,
  707. upload_context=self.__upload_context)
  708. else:
  709. result = self.bucket.upload_part(self.key, self.__upload_id, part.part_number,
  710. utils.SizedFileAdapter(f, part.size), headers=headers)
  711. logger.debug("Upload part success, add part info to record, part_number: {0}, etag: {1}, size: {2}".format(
  712. part.part_number, result.etag, part.size))
  713. self.__finish_part(PartInfo(part.part_number, result.etag, size=part.size, part_crc=result.crc))
  714. def __finish_part(self, part_info):
  715. with self.__lock:
  716. self.__finished_parts.append(part_info)
  717. self.__finished_size += part_info.size
  718. def __load_record(self):
  719. record = self._get_record()
  720. logger.debug("Load record return {0}".format(record))
  721. if record and not self.__is_record_sane(record):
  722. logger.warning("The content of record is invalid, delete the record")
  723. self._del_record()
  724. record = None
  725. if record and self.__file_changed(record):
  726. logger.warning("File: {0} has been changed, delete the record".format(self.filename))
  727. self._del_record()
  728. record = None
  729. if record and not self.__upload_exists(record['upload_id']):
  730. logger.warning('Multipart upload: {0} does not exist, delete the record'.format(record['upload_id']))
  731. self._del_record()
  732. record = None
  733. if not record:
  734. params = _populate_valid_params(self.__params, [Bucket.SEQUENTIAL])
  735. part_size = determine_part_size(self.size, self.__part_size)
  736. logger.debug("Upload File size: {0}, User-specify part_size: {1}, Calculated part_size: {2}".format(
  737. self.size, self.__part_size, part_size))
  738. if self.__encryption:
  739. upload_context = models.MultipartUploadCryptoContext(self.size, part_size)
  740. upload_id = self.bucket.init_multipart_upload(self.key, self.__headers, params,
  741. upload_context).upload_id
  742. if self.__record_upload_context:
  743. material = upload_context.content_crypto_material
  744. material_record = {'wrap_alg': material.wrap_alg, 'cek_alg': material.cek_alg,
  745. 'encrypted_key': b64encode_as_string(material.encrypted_key),
  746. 'encrypted_iv': b64encode_as_string(material.encrypted_iv),
  747. 'mat_desc': material.mat_desc}
  748. else:
  749. upload_id = self.bucket.init_multipart_upload(self.key, self.__headers, params).upload_id
  750. record = {'op_type': self.__op, 'upload_id': upload_id, 'file_path': self._abspath, 'size': self.size,
  751. 'mtime': self.__mtime, 'bucket': self.bucket.bucket_name, 'key': self.key, 'part_size': part_size}
  752. if self.__record_upload_context:
  753. record['content_crypto_material'] = material_record
  754. logger.debug('Add new record, bucket: {0}, key: {1}, upload_id: {2}, part_size: {3}'.format(
  755. self.bucket.bucket_name, self.key, upload_id, part_size))
  756. self._put_record(record)
  757. self.__record = record
  758. self.__part_size = self.__record['part_size']
  759. self.__upload_id = self.__record['upload_id']
  760. if self.__record_upload_context:
  761. if 'content_crypto_material' in self.__record:
  762. material_record = self.__record['content_crypto_material']
  763. wrap_alg = material_record['wrap_alg']
  764. cek_alg = material_record['cek_alg']
  765. if cek_alg != self.bucket.crypto_provider.cipher.alg or wrap_alg != self.bucket.crypto_provider.wrap_alg:
  766. err_msg = 'Envelope or data encryption/decryption algorithm is inconsistent'
  767. raise exceptions.InconsistentError(err_msg, self)
  768. content_crypto_material = models.ContentCryptoMaterial(self.bucket.crypto_provider.cipher,
  769. material_record['wrap_alg'],
  770. b64decode_from_string(
  771. material_record['encrypted_key']),
  772. b64decode_from_string(
  773. material_record['encrypted_iv']),
  774. material_record['mat_desc'])
  775. self.__upload_context = models.MultipartUploadCryptoContext(self.size, self.__part_size,
  776. content_crypto_material)
  777. else:
  778. err_msg = 'If the record_upload_context flag is true, content_crypto_material must be in the record'
  779. raise exceptions.InconsistentError(err_msg, self)
  780. else:
  781. if 'content_crypto_material' in self.__record:
  782. err_msg = 'content_crypto_material is present in the record, but the record_upload_context flag is false'
  783. raise exceptions.InvalidEncryptionRequest(err_msg, self)
  784. self.__finished_parts = self.__get_finished_parts()
  785. self.__finished_size = sum(p.size for p in self.__finished_parts)
  786. def __get_finished_parts(self):
  787. parts = []
  788. valid_headers = _filter_invalid_headers(self.__headers,
  789. [OSS_SERVER_SIDE_ENCRYPTION, OSS_SERVER_SIDE_DATA_ENCRYPTION])
  790. for part in PartIterator(self.bucket, self.key, self.__upload_id, headers=valid_headers):
  791. parts.append(part)
  792. return parts
  793. def __upload_exists(self, upload_id):
  794. try:
  795. valid_headers = _filter_invalid_headers(self.__headers,
  796. [OSS_SERVER_SIDE_ENCRYPTION, OSS_SERVER_SIDE_DATA_ENCRYPTION])
  797. list(iterators.PartIterator(self.bucket, self.key, upload_id, '0', max_parts=1, headers=valid_headers))
  798. except exceptions.NoSuchUpload:
  799. return False
  800. else:
  801. return True
  802. def __file_changed(self, record):
  803. return record['mtime'] != self.__mtime or record['size'] != self.size
  804. def __get_parts_to_upload(self, parts_uploaded):
  805. all_parts = _split_to_parts(self.size, self.__part_size)
  806. if not parts_uploaded:
  807. return all_parts
  808. all_parts_map = dict((p.part_number, p) for p in all_parts)
  809. for uploaded in parts_uploaded:
  810. if uploaded.part_number in all_parts_map:
  811. del all_parts_map[uploaded.part_number]
  812. return all_parts_map.values()
  813. def __is_record_sane(self, record):
  814. try:
  815. if record['op_type'] != self.__op:
  816. logger.error('op_type invalid, op_type in record:{0} is invalid'.format(record['op_type']))
  817. return False
  818. for key in ('upload_id', 'file_path', 'bucket', 'key'):
  819. if not isinstance(record[key], str):
  820. logger.error('Type Error, {0} in record is not a string type: {1}'.format(key, record[key]))
  821. return False
  822. for key in ('size', 'part_size'):
  823. if not isinstance(record[key], int):
  824. logger.error('Type Error, {0} in record is not an integer type: {1}'.format(key, record[key]))
  825. return False
  826. if not isinstance(record['mtime'], int) and not isinstance(record['mtime'], float):
  827. logger.error(
  828. 'Type Error, mtime in record is not a float or an integer type: {0}'.format(record['mtime']))
  829. return False
  830. except KeyError as e:
  831. logger.error('Key not found: {0}'.format(e.args))
  832. return False
  833. return True
  834. _UPLOAD_TEMP_DIR = '.py-oss-upload'
  835. _DOWNLOAD_TEMP_DIR = '.py-oss-download'
  836. class _ResumableStoreBase(object):
  837. def __init__(self, root, dir):
  838. logger.debug("Init ResumableStoreBase, root path: {0}, temp dir: {1}".format(root, dir))
  839. self.dir = os.path.join(root, dir)
  840. if os.path.isdir(self.dir):
  841. return
  842. utils.makedir_p(self.dir)
  843. def get(self, key):
  844. pathname = self.__path(key)
  845. logger.debug('ResumableStoreBase: get key: {0} from file path: {1}'.format(key, pathname))
  846. if not os.path.exists(pathname):
  847. logger.debug("file {0} does not exist".format(pathname))
  848. return None
  849. # json.load() always returns unicode; for Python 2 we convert it
  850. # to str.
  851. try:
  852. with open(to_unicode(pathname), 'r') as f:
  853. content = json.load(f)
  854. except ValueError:
  855. os.remove(pathname)
  856. return None
  857. else:
  858. return stringify(content)
  859. def put(self, key, value):
  860. pathname = self.__path(key)
  861. with open(to_unicode(pathname), 'w') as f:
  862. json.dump(value, f)
  863. logger.debug('ResumableStoreBase: put key: {0} to file path: {1}, value: {2}'.format(key, pathname, value))
  864. def delete(self, key):
  865. pathname = self.__path(key)
  866. os.remove(pathname)
  867. logger.debug('ResumableStoreBase: delete key: {0}, file path: {1}'.format(key, pathname))
  868. def __path(self, key):
  869. return os.path.join(self.dir, key)
  870. def _normalize_path(path):
  871. return os.path.normpath(os.path.normcase(path))
  872. class ResumableStore(_ResumableStoreBase):
  873. """Persistent store for resumable-upload breakpoint information.
  874. The information of each upload is saved in a file under `root/dir/`.
  875. :param str root: parent directory, defaults to HOME
  876. :param str dir: subdirectory, defaults to `_UPLOAD_TEMP_DIR`
  877. """
  878. def __init__(self, root=None, dir=None):
  879. super(ResumableStore, self).__init__(root or os.path.expanduser('~'), dir or _UPLOAD_TEMP_DIR)
  880. @staticmethod
  881. def make_store_key(bucket_name, key, filename):
  882. filepath = _normalize_path(filename)
  883. oss_pathname = 'oss://{0}/{1}'.format(bucket_name, key)
  884. return utils.md5_string(oss_pathname) + '--' + utils.md5_string(filepath)
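# Example of the record key built above: both halves are MD5 hex digests, so the key is
# filesystem-safe no matter what characters the object key or local path contains:
#
#   ResumableStore.make_store_key('my-bucket', 'dir/key.bin', '/tmp/local.bin')
#   # -> md5_string('oss://my-bucket/dir/key.bin') + '--' + md5_string(normalized '/tmp/local.bin')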
  885. class ResumableDownloadStore(_ResumableStoreBase):
  886. """Persistent store for resumable-download breakpoint information.
  887. The breakpoint information of each download is saved in a file under `root/dir/`.
  888. :param str root: parent directory, defaults to HOME
  889. :param str dir: subdirectory, defaults to `_DOWNLOAD_TEMP_DIR`
  890. """
  891. def __init__(self, root=None, dir=None):
  892. super(ResumableDownloadStore, self).__init__(root or os.path.expanduser('~'), dir or _DOWNLOAD_TEMP_DIR)
  893. @staticmethod
  894. def make_store_key(bucket_name, key, filename, version_id=None):
  895. filepath = _normalize_path(filename)
  896. if version_id is None:
  897. oss_pathname = 'oss://{0}/{1}'.format(bucket_name, key)
  898. else:
  899. oss_pathname = 'oss://{0}/{1}?versionid={2}'.format(bucket_name, key, version_id)
  900. return utils.md5_string(oss_pathname) + '--' + utils.md5_string(filepath)
  901. def make_upload_store(root=None, dir=None):
  902. return ResumableStore(root=root, dir=dir)
  903. def make_download_store(root=None, dir=None):
  904. return ResumableDownloadStore(root=root, dir=dir)
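# Usage sketch: keep the breakpoint records in a custom directory instead of HOME
# (directory names are placeholders):
#
#   store = make_upload_store(root='/var/cache/myapp', dir='.oss-upload-records')
#   oss2.resumable_upload(bucket, 'remote/key.bin', '/path/to/local.bin', store=store)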
  905. class _PartToProcess(object):
  906. def __init__(self, part_number, start, end, part_crc=None):
  907. self.part_number = part_number
  908. self.start = start
  909. self.end = end
  910. self.part_crc = part_crc
  911. @property
  912. def size(self):
  913. return self.end - self.start
  914. def __hash__(self):
  915. return hash(self.__key)
  916. def __eq__(self, other):
  917. return self.__key == other.__key
  918. @property
  919. def __key(self):
  920. return self.part_number, self.start, self.end