您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

102 行
4.9KB

  1. # xView 2018 dataset https://challenge.xviewdataset.org
  2. # ----> NOTE: DOWNLOAD DATA MANUALLY from URL above and unzip to /datasets/xView before running train command below
  3. # Train command: python train.py --data xView.yaml
  4. # Default dataset location is next to YOLOv5:
  5. # /parent
  6. # /datasets/xView
  7. # /yolov5
  8. # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
  9. path: ../datasets/xView # dataset root dir
  10. train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images
  11. val: images/autosplit_val.txt # val images (relative to 'path') 10% of 847 train images
  12. # Classes
  13. nc: 60 # number of classes
  14. names: [ 'Fixed-wing Aircraft', 'Small Aircraft', 'Cargo Plane', 'Helicopter', 'Passenger Vehicle', 'Small Car', 'Bus',
  15. 'Pickup Truck', 'Utility Truck', 'Truck', 'Cargo Truck', 'Truck w/Box', 'Truck Tractor', 'Trailer',
  16. 'Truck w/Flatbed', 'Truck w/Liquid', 'Crane Truck', 'Railway Vehicle', 'Passenger Car', 'Cargo Car',
  17. 'Flat Car', 'Tank car', 'Locomotive', 'Maritime Vessel', 'Motorboat', 'Sailboat', 'Tugboat', 'Barge',
  18. 'Fishing Vessel', 'Ferry', 'Yacht', 'Container Ship', 'Oil Tanker', 'Engineering Vehicle', 'Tower crane',
  19. 'Container Crane', 'Reach Stacker', 'Straddle Carrier', 'Mobile Crane', 'Dump Truck', 'Haul Truck',
  20. 'Scraper/Tractor', 'Front loader/Bulldozer', 'Excavator', 'Cement Mixer', 'Ground Grader', 'Hut/Tent', 'Shed',
  21. 'Building', 'Aircraft Hangar', 'Damaged Building', 'Facility', 'Construction Site', 'Vehicle Lot', 'Helipad',
  22. 'Storage Tank', 'Shipping container lot', 'Shipping Container', 'Pylon', 'Tower' ] # class names
  23. # Download script/URL (optional) ---------------------------------------------------------------------------------------
  24. download: |
  import json
  import os
  import shutil
  from pathlib import Path

  import numpy as np
  from PIL import Image
  from tqdm import tqdm

  from utils.datasets import autosplit
  from utils.general import download, xyxy2xywhn
  def convert_labels(fname=Path('xView/xView_train.geojson')):
      """Convert xView geoJSON labels to YOLO-format label files.

      Reads the geoJSON at ``fname`` and writes one ``<image_id>.txt`` file per image into
      ``<fname.parent>/labels/train``. Each row is ``cls`` followed by the box as returned by
      ``xyxy2xywhn`` (presumably normalized center-x, center-y, width, height -- see
      utils.general), formatted to 6 decimals. ``cls`` is the xView ``type_id`` remapped to
      the 0-59 indices used by ``names``.

      Any existing ``labels/train`` directory is deleted first, because label files are opened
      in append mode below. Features with empty ``bounds_imcoords`` and features whose image
      is not present in ``<fname.parent>/train_images`` are skipped silently; features that
      fail to parse or map to a class are skipped with a printed warning.

      Args:
          fname: path (``Path`` or ``str``) to the xView geoJSON label file.
      """
      fname = Path(fname)  # also accept str paths
      path = fname.parent
      with open(fname) as f:
          print(f'Loading {fname}...')
          data = json.load(f)

      # Make dirs: start from an empty labels dir (removed without a shell, so it also works on Windows)
      labels = path / 'labels' / 'train'
      shutil.rmtree(labels, ignore_errors=True)
      labels.mkdir(parents=True, exist_ok=True)

      # xView classes 11-94 to 0-59 (-1 marks an xView type_id with no YOLO class)
      xview_class2index = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, -1, 3, -1, 4, 5, 6, 7, 8, -1, 9, 10, 11,
                           12, 13, 14, 15, -1, -1, 16, 17, 18, 19, 20, 21, 22, -1, 23, 24, 25, -1, 26, 27, -1, 28, -1,
                           29, 30, 31, 32, 33, 34, 35, 36, 37, -1, 38, 39, 40, 41, 42, 43, 44, 45, -1, -1, -1, -1, 46,
                           47, 48, 49, -1, 50, 51, -1, 52, -1, -1, -1, 53, 54, -1, 55, -1, -1, 56, -1, 57, -1, 58, 59]

      shapes = {}  # image_id -> PIL (width, height), cached so each image is opened only once
      for feature in tqdm(data['features'], desc=f'Converting {fname}'):
          p = feature['properties']
          if not p['bounds_imcoords']:
              continue
          image_id = p['image_id']
          file = path / 'train_images' / image_id
          if not file.exists():  # 1395.tif missing
              continue
          try:
              box = np.array([int(num) for num in p['bounds_imcoords'].split(",")])
              # Explicit raises instead of asserts so validation still runs under `python -O`
              if box.shape[0] != 4:
                  raise ValueError(f'incorrect box shape {box.shape[0]}')
              type_id = int(p['type_id'])
              # Bounds check first: a negative type_id would otherwise wrap around the list
              if not 0 <= type_id < len(xview_class2index):
                  raise ValueError(f'unknown xView type_id {type_id}')
              cls = xview_class2index[type_id]  # xView class to 0-59
              if not 0 <= cls <= 59:
                  raise ValueError(f'incorrect class index {cls}')

              # Write YOLO label
              if image_id not in shapes:
                  with Image.open(file) as im:  # close the file handle as soon as the size is read
                      shapes[image_id] = im.size
              w, h = shapes[image_id]
              # astype(float), not np.float: that alias was removed in NumPy 1.24
              box = xyxy2xywhn(box[None].astype(float), w=w, h=h, clip=True)
              with open((labels / image_id).with_suffix('.txt'), 'a') as f:
                  f.write(f"{cls} {' '.join(f'{x:.6f}' for x in box[0])}\n")  # write label.txt
          except Exception as e:  # best effort: skip just this label and keep converting the rest
              print(f'WARNING: skipping one label for {file}: {e}')
  # Download manually from https://challenge.xviewdataset.org
  # NOTE(review): `yaml` is not defined in this script; presumably the loader that runs this
  # `download` snippet injects the parsed dataset dict under that name -- confirm against the caller.
  root = Path(yaml['path'])  # dataset root dir (named `root` to avoid shadowing the builtin `dir`)
  # urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip',  # train labels
  #         'https://d307kc0mrhucc3.cloudfront.net/train_images.zip',  # 15G, 847 train images
  #         'https://d307kc0mrhucc3.cloudfront.net/val_images.zip']  # 5G, 282 val images (no labels)
  # download(urls, dir=root, delete=False)

  # Convert labels. Must run before the image move below: convert_labels reads <root>/train_images.
  convert_labels(root / 'xView_train.geojson')

  # Move images into the images/train and images/val layout
  images = root / 'images'
  images.mkdir(parents=True, exist_ok=True)
  (root / 'train_images').rename(images / 'train')
  (root / 'val_images').rename(images / 'val')

  # Split: presumably writes the images/autosplit_*.txt lists referenced by `train`/`val` above
  autosplit(images / 'train')