Update README

This commit is contained in:
Teoge 2019-12-27 10:57:50 +08:00
parent 4b6a1c2218
commit 86623eaa9a
4 changed files with 45 additions and 18 deletions

View File

@ -4,13 +4,10 @@ This is the implementation of DMPR-PS using PyTorch.
## Requirements ## Requirements
* CUDA
* PyTorch * PyTorch
* OpenCV * CUDA (optional)
* NumPy * Other requirements
* Pillow `pip install -r requirements.txt`
* Visdom (optional)
* Matplotlib (optional)
## Pre-trained weights ## Pre-trained weights
@ -30,8 +27,8 @@ The [pre-trained weights](https://drive.google.com/open?id=1OuyF8bGttA11-CKJ4Mj3
python inference.py --mode video --detector_weights $DETECTOR_WEIGHTS --video $VIDEO python inference.py --mode video --detector_weights $DETECTOR_WEIGHTS --video $VIDEO
``` ```
`DETECTOR_WEIGHTS` is the trained weights of detector. Argument `DETECTOR_WEIGHTS` is the trained weights of detector.
`VIDEO` is path to the video. Argument `VIDEO` is path to the video.
View `config.py` for more argument details. View `config.py` for more argument details.
## Prepare data ## Prepare data
@ -45,9 +42,9 @@ The [pre-trained weights](https://drive.google.com/open?id=1OuyF8bGttA11-CKJ4Mj3
python prepare_dataset.py --dataset test --label_directory $LABEL_DIRECTORY --image_directory $IMAGE_DIRECTORY --output_directory $OUTPUT_DIRECTORY python prepare_dataset.py --dataset test --label_directory $LABEL_DIRECTORY --image_directory $IMAGE_DIRECTORY --output_directory $OUTPUT_DIRECTORY
``` ```
`LABEL_DIRECTORY` is the directory containing json labels. Argument `LABEL_DIRECTORY` is the directory containing json labels.
`IMAGE_DIRECTORY` is the directory containing jpg images. Argument `IMAGE_DIRECTORY` is the directory containing jpg images.
`OUTPUT_DIRECTORY` is the directory where output images and labels are. Argument `OUTPUT_DIRECTORY` is the directory where output images and labels are.
View `prepare_dataset.py` for more argument details. View `prepare_dataset.py` for more argument details.
## Train ## Train
@ -56,7 +53,7 @@ The [pre-trained weights](https://drive.google.com/open?id=1OuyF8bGttA11-CKJ4Mj3
python train.py --dataset_directory $TRAIN_DIRECTORY python train.py --dataset_directory $TRAIN_DIRECTORY
``` ```
`TRAIN_DIRECTORY` is the train directory generated in data preparation. Argument `TRAIN_DIRECTORY` is the train directory generated in data preparation.
View `config.py` for more argument details (batch size, learning rate, etc). View `config.py` for more argument details (batch size, learning rate, etc).
## Evaluate ## Evaluate
@ -67,8 +64,8 @@ View `config.py` for more argument details (batch size, learning rate, etc).
python evaluate.py --dataset_directory $TEST_DIRECTORY --detector_weights $DETECTOR_WEIGHTS python evaluate.py --dataset_directory $TEST_DIRECTORY --detector_weights $DETECTOR_WEIGHTS
``` ```
`TEST_DIRECTORY` is the test directory generated in data preparation. Argument `TEST_DIRECTORY` is the test directory generated in data preparation.
`DETECTOR_WEIGHTS` is the trained weights of detector. Argument `DETECTOR_WEIGHTS` is the trained weights of detector.
View `config.py` for more argument details (batch size, learning rate, etc). View `config.py` for more argument details (batch size, learning rate, etc).
* Evaluate parking-slot detection * Evaluate parking-slot detection
@ -77,7 +74,22 @@ View `config.py` for more argument details (batch size, learning rate, etc).
python ps_evaluate.py --label_directory $LABEL_DIRECTORY --image_directory $IMAGE_DIRECTORY --detector_weights $DETECTOR_WEIGHTS python ps_evaluate.py --label_directory $LABEL_DIRECTORY --image_directory $IMAGE_DIRECTORY --detector_weights $DETECTOR_WEIGHTS
``` ```
`LABEL_DIRECTORY` is the directory containing testing json labels. Argument `LABEL_DIRECTORY` is the directory containing testing json labels.
`IMAGE_DIRECTORY` is the directory containing testing jpg images. Argument `IMAGE_DIRECTORY` is the directory containing testing jpg images.
`DETECTOR_WEIGHTS` is the trained weights of detector. Argument `DETECTOR_WEIGHTS` is the trained weights of detector.
View `config.py` for more argument details. View `config.py` for more argument details.
## Citing DMPR-PS
If you find DMPR-PS useful in your research, please consider citing:
```()
@inproceedings{DMPR-PS,
Author = {Junhao Huang and Lin Zhang and Ying Shen and Huijuan Zhang and Shengjie Zhao and Yukai Yang},
Booktitle = {2019 IEEE International Conference on Multimedia and Expo (ICME)},
Title = {{DMPR-PS}: A novel approach for parking-slot detection using directional marking-point regression},
Month = {Jul.},
Year = {2019},
Pages = {212-217}
}
```

View File

@ -2,6 +2,7 @@
import torch import torch
import config import config
import util import util
from thop import profile
from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle from data import get_predicted_points, match_marking_points, calc_point_squre_dist, calc_point_direction_angle
from data import ParkingSlotDataset from data import ParkingSlotDataset
from model import DirectionalPointDetector from model import DirectionalPointDetector
@ -85,8 +86,14 @@ def evaluate_detector(args):
average_precision = util.calc_average_precision(precisions, recalls) average_precision = util.calc_average_precision(precisions, recalls)
if args.enable_visdom: if args.enable_visdom:
logger.plot_curve(precisions, recalls) logger.plot_curve(precisions, recalls)
sample = torch.randn(1, 3, config.INPUT_IMAGE_SIZE,
config.INPUT_IMAGE_SIZE)
flops, params = profile(dp_detector, inputs=(sample.to(device), ))
logger.log(average_loss=total_loss / len(psdataset), logger.log(average_loss=total_loss / len(psdataset),
average_precision=average_precision) average_precision=average_precision,
flops=flops,
params=params)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -53,6 +53,7 @@ def psevaluate_detector(args):
image = cv.imread(os.path.join(args.image_directory, name + '.jpg')) image = cv.imread(os.path.join(args.image_directory, name + '.jpg'))
pred_points = detect_marking_points( pred_points = detect_marking_points(
dp_detector, image, config.CONFID_THRESH_FOR_POINT, device) dp_detector, image, config.CONFID_THRESH_FOR_POINT, device)
slots = []
if pred_points: if pred_points:
marking_points = list(list(zip(*pred_points))[1]) marking_points = list(list(zip(*pred_points))[1])
slots = inference_slots(marking_points) slots = inference_slots(marking_points)

7
requirements.txt Normal file
View File

@ -0,0 +1,7 @@
torchvision
opencv-python
numpy
Pillow
visdom
matplotlib
thop