Add class filtering to `LoadImagesAndLabels()` dataloader (#5172)
* Add train class filter feature to datasets.py
Allows for training on a subset of total classes if `include_class` list is defined on datasets.py L448:
```python
include_class = [] # filter labels to include only these classes (optional)
```
* segments fix
This commit is contained in:
parent
b754525e99
commit
a346926996
|
|
@ -437,10 +437,6 @@ class LoadImagesAndLabels(Dataset):
|
||||||
self.shapes = np.array(shapes, dtype=np.float64)
|
self.shapes = np.array(shapes, dtype=np.float64)
|
||||||
self.img_files = list(cache.keys()) # update
|
self.img_files = list(cache.keys()) # update
|
||||||
self.label_files = img2label_paths(cache.keys()) # update
|
self.label_files = img2label_paths(cache.keys()) # update
|
||||||
if single_cls:
|
|
||||||
for x in self.labels:
|
|
||||||
x[:, 0] = 0
|
|
||||||
|
|
||||||
n = len(shapes) # number of images
|
n = len(shapes) # number of images
|
||||||
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
|
bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
|
||||||
nb = bi[-1] + 1 # number of batches
|
nb = bi[-1] + 1 # number of batches
|
||||||
|
|
@ -448,6 +444,20 @@ class LoadImagesAndLabels(Dataset):
|
||||||
self.n = n
|
self.n = n
|
||||||
self.indices = range(n)
|
self.indices = range(n)
|
||||||
|
|
||||||
|
# Update labels
|
||||||
|
include_class = [] # filter labels to include only these classes (optional)
|
||||||
|
include_class_array = np.array(include_class).reshape(1, -1)
|
||||||
|
for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
|
||||||
|
if include_class:
|
||||||
|
j = (label[:, 0:1] == include_class_array).any(1)
|
||||||
|
self.labels[i] = label[j]
|
||||||
|
if segment:
|
||||||
|
self.segments[i] = segment[j]
|
||||||
|
if single_cls: # single-class training, merge all classes into 0
|
||||||
|
self.labels[i][:, 0] = 0
|
||||||
|
if segment:
|
||||||
|
self.segments[i][:, 0] = 0
|
||||||
|
|
||||||
# Rectangular Training
|
# Rectangular Training
|
||||||
if self.rect:
|
if self.rect:
|
||||||
# Sort by aspect ratio
|
# Sort by aspect ratio
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue