No puede seleccionar más de 25 temas Los temas deben comenzar con una letra o número, pueden incluir guiones ('-') y pueden tener hasta 35 caracteres de largo.

103 líneas
5.0KB

  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. # DIUx xView 2018 Challenge https://challenge.xviewdataset.org by U.S. National Geospatial-Intelligence Agency (NGA)
  3. # -------- DOWNLOAD DATA MANUALLY, then extract val_images.zip into 'datasets/xView' with 'jar xf' before running the train command! --------
  4. # Example usage: python train.py --data xView.yaml
  5. # parent
  6. # ├── yolov5
  7. # └── datasets
  8. #     └── xView  ← downloads here
  9. # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
  10. path: ../datasets/xView # dataset root dir
  11. train: images/autosplit_train.txt # train images (relative to 'path') 90% of 847 train images
  12. val: images/autosplit_val.txt # val images (relative to 'path') 10% of 847 train images
  13. # Classes
  14. nc: 60 # number of classes
  15. names: ['Fixed-wing Aircraft', 'Small Aircraft', 'Cargo Plane', 'Helicopter', 'Passenger Vehicle', 'Small Car', 'Bus',
  16. 'Pickup Truck', 'Utility Truck', 'Truck', 'Cargo Truck', 'Truck w/Box', 'Truck Tractor', 'Trailer',
  17. 'Truck w/Flatbed', 'Truck w/Liquid', 'Crane Truck', 'Railway Vehicle', 'Passenger Car', 'Cargo Car',
  18. 'Flat Car', 'Tank car', 'Locomotive', 'Maritime Vessel', 'Motorboat', 'Sailboat', 'Tugboat', 'Barge',
  19. 'Fishing Vessel', 'Ferry', 'Yacht', 'Container Ship', 'Oil Tanker', 'Engineering Vehicle', 'Tower crane',
  20. 'Container Crane', 'Reach Stacker', 'Straddle Carrier', 'Mobile Crane', 'Dump Truck', 'Haul Truck',
  21. 'Scraper/Tractor', 'Front loader/Bulldozer', 'Excavator', 'Cement Mixer', 'Ground Grader', 'Hut/Tent', 'Shed',
  22. 'Building', 'Aircraft Hangar', 'Damaged Building', 'Facility', 'Construction Site', 'Vehicle Lot', 'Helipad',
  23. 'Storage Tank', 'Shipping container lot', 'Shipping Container', 'Pylon', 'Tower'] # class names
  24. # Download script/URL (optional) ---------------------------------------------------------------------------------------
  25. download: |
  26. import json
  27. import os
  28. from pathlib import Path
  29. import numpy as np
  30. from PIL import Image
  31. from tqdm import tqdm
  32. from utils.datasets import autosplit
  33. from utils.general import download, xyxy2xywhn
  def convert_labels(fname=Path('xView/xView_train.geojson')):
      """Convert xView GeoJSON labels to YOLO-format .txt label files.

      Reads the GeoJSON at ``fname``. For every feature that has non-empty
      ``bounds_imcoords`` and whose image exists in ``<fname.parent>/train_images``,
      it appends one ``<cls> <x> <y> <w> <h>`` line (6 decimals) to
      ``<fname.parent>/labels/train/<image_id stem>.txt``. Box coordinates go through
      the project helper ``xyxy2xywhn`` (assumed: pixel xyxy -> normalized, clipped
      xywh; confirm in utils.general). Features with a malformed box or an unmapped
      class are skipped with a printed warning instead of aborting the conversion.

      Args:
          fname: path (Path or str) to the xView GeoJSON annotation file.
      """
      import shutil  # function-scope import keeps this block self-contained

      fname = Path(fname)  # also accept plain strings
      path = fname.parent
      with open(fname, encoding='utf-8') as f:
          print(f'Loading {fname}...')
          data = json.load(f)

      # Make dirs. Label files are opened in append mode below, so previous output
      # must be wiped first; rmtree avoids building a shell command from a path.
      labels = path / 'labels' / 'train'
      shutil.rmtree(labels, ignore_errors=True)
      labels.mkdir(parents=True, exist_ok=True)

      # xView type_id (list index 0-94) -> YOLO class 0-59; -1 marks ids with no class
      xview_class2index = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, -1, 3, -1, 4, 5, 6, 7, 8, -1, 9, 10, 11,
                           12, 13, 14, 15, -1, -1, 16, 17, 18, 19, 20, 21, 22, -1, 23, 24, 25, -1, 26, 27, -1, 28, -1,
                           29, 30, 31, 32, 33, 34, 35, 36, 37, -1, 38, 39, 40, 41, 42, 43, 44, 45, -1, -1, -1, -1, 46,
                           47, 48, 49, -1, 50, 51, -1, 52, -1, -1, -1, 53, 54, -1, 55, -1, -1, 56, -1, 57, -1, 58, 59]
      n_types = len(xview_class2index)

      shapes = {}  # image_id -> (width, height); each image is opened at most once
      for feature in tqdm(data['features'], desc=f'Converting {fname}'):
          p = feature['properties']
          if not p['bounds_imcoords']:
              continue
          image_id = p['image_id']
          file = path / 'train_images' / image_id
          if not file.exists():  # 1395.tif missing
              continue
          try:
              box = np.array([int(num) for num in p['bounds_imcoords'].split(",")])
              # Explicit raises instead of assert: asserts vanish under `python -O`
              if box.shape[0] != 4:
                  raise ValueError(f'incorrect box shape {box.shape[0]}')
              type_id = int(p['type_id'])
              # Bounds check first: a negative id would otherwise wrap around the list
              if not 0 <= type_id < n_types:
                  raise ValueError(f'unknown xView type_id {type_id}')
              cls = xview_class2index[type_id]  # xView class to 0-59
              if not 0 <= cls <= 59:
                  raise ValueError(f'incorrect class index {cls}')
              # Write YOLO label
              if image_id not in shapes:
                  with Image.open(file) as im:  # close the image file handle promptly
                      shapes[image_id] = im.size
              w, h = shapes[image_id]
              # builtin float (== the old np.float alias, which NumPy 1.24 removed)
              box = xyxy2xywhn(box[None].astype(float), w=w, h=h, clip=True)
              with open((labels / image_id).with_suffix('.txt'), 'a') as f:
                  f.write(f"{cls} {' '.join(f'{x:.6f}' for x in box[0])}\n")  # write label.txt
          except Exception as e:
              print(f'WARNING: skipping one label for {file}: {e}')
  # The xView data must be downloaded manually from https://challenge.xviewdataset.org.
  # NOTE(review): `yaml` is not defined in this script; presumably the runner that
  # executes it injects the parsed dataset dict under that name. Confirm with the caller.
  root = Path(yaml['path'])  # dataset root dir
  # Direct links, kept for reference only (automatic download is disabled):
  # urls = ['https://d307kc0mrhucc3.cloudfront.net/train_labels.zip',  # train labels
  #         'https://d307kc0mrhucc3.cloudfront.net/train_images.zip',  # 15G, 847 train images
  #         'https://d307kc0mrhucc3.cloudfront.net/val_images.zip']  # 5G, 282 val images (no labels)
  # download(urls, dir=root, delete=False)

  # Step 1: write YOLO labels while the raw images are still in train_images/
  convert_labels(root / 'xView_train.geojson')

  # Step 2: move the raw image folders to <root>/images/train and <root>/images/val
  images = root / 'images'
  images.mkdir(parents=True, exist_ok=True)
  for src_name, dst_name in (('train_images', 'train'), ('val_images', 'val')):
      (root / src_name).rename(images / dst_name)

  # Step 3: split the training images (presumably into the autosplit_*.txt lists
  # that the train/val keys of this file point at)
  autosplit(images / 'train')