Browse Source

Add class filtering to `LoadImagesAndLabels()` dataloader (#5172)

* Add train class filter feature to datasets.py

Allows training on a subset of the dataset's classes when the `include_class` list in datasets.py (L448) is non-empty:
```python
        include_class = []  # filter labels to include only these classes (optional)
```

* segments fix
modifyDataloader
Glenn Jocher GitHub 3 years ago
parent
commit
a346926996
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 14 additions and 4 deletions
  1. +14
    -4
      utils/datasets.py

+ 14
- 4
utils/datasets.py View File

self.shapes = np.array(shapes, dtype=np.float64) self.shapes = np.array(shapes, dtype=np.float64)
self.img_files = list(cache.keys()) # update self.img_files = list(cache.keys()) # update
self.label_files = img2label_paths(cache.keys()) # update self.label_files = img2label_paths(cache.keys()) # update
if single_cls:
for x in self.labels:
x[:, 0] = 0

n = len(shapes) # number of images n = len(shapes) # number of images
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
nb = bi[-1] + 1 # number of batches nb = bi[-1] + 1 # number of batches
self.n = n self.n = n
self.indices = range(n) self.indices = range(n)


# Update labels
include_class = [] # filter labels to include only these classes (optional)
include_class_array = np.array(include_class).reshape(1, -1)
for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
if include_class:
j = (label[:, 0:1] == include_class_array).any(1)
self.labels[i] = label[j]
if segment:
self.segments[i] = segment[j]
if single_cls: # single-class training, merge all classes into 0
self.labels[i][:, 0] = 0
if segment:
self.segments[i][:, 0] = 0

# Rectangular Training # Rectangular Training
if self.rect: if self.rect:
# Sort by aspect ratio # Sort by aspect ratio

Loading…
Cancel
Save