# BBAVectors command-line entry point: parses options and runs the train, test, or eval phase.
import argparse
|
|
import train
|
|
import test
|
|
import eval
|
|
from datasets.dataset_dota import DOTA
|
|
from datasets.dataset_hrsc import HRSC
|
|
from models import ctrbox_net
|
|
import decoder
|
|
import os
|
|
import time
|
|
|
|
|
|
def parse_args(argv=None):
    """Build the command-line parser and parse the run options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which matches
            calling this function with no arguments.

    Returns:
        argparse.Namespace: one attribute per option declared below
        (``num_epoch``, ``batch_size``, ``num_workers``, ``init_lr``,
        ``input_h``, ``input_w``, ``K``, ``conf_thresh``, ``ngpus``,
        ``resume_train``, ``resume``, ``dataset``, ``data_dir``, ``phase``,
        ``wh_channels``).
    """
    parser = argparse.ArgumentParser(description='BBAVectors Implementation')

    # Optimisation / data-loading settings.
    parser.add_argument('--num_epoch', type=int, default=300, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Number of batch size')
    # NOTE: the original comment here only said "originally was" with no
    # value, so the previous default is not recorded.
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers')
    parser.add_argument('--init_lr', type=float, default=1.25e-4, help='Initial learning rate')

    # Network input size and detection settings.
    parser.add_argument('--input_h', type=int, default=608, help='Resized image height')
    parser.add_argument('--input_w', type=int, default=608, help='Resized image width')
    parser.add_argument('--K', type=int, default=500, help='Maximum of objects')
    parser.add_argument('--conf_thresh', type=float, default=0.18,
                        help='Confidence threshold, 0.1 for general evaluation')
    parser.add_argument('--ngpus', type=int, default=1, help='Number of gpus, ngpus>1 for multigpu')

    # Checkpoints.
    parser.add_argument('--resume_train', type=str, default='', help='Weights resumed in training')
    # Weight path. Defaults that were previously used here (kept for reference):
    #   weights_dota/model_last_resnet50_20230406_10Kdataset.pth
    #   weights_dota/model_last_resnet34_20230326_10Kdataset.pth
    #   weights_dota/model_last_resnet101_20230315.pth
    #   model_resnet101_20230315.pth
    #   model_last.pth
    parser.add_argument('--resume', type=str,
                        default='model_last_resnet18_20230421_10K495dataset.pth',
                        help='Weights resumed in testing and evaluation')

    # Dataset selection. A previous default for --data_dir was '../Datasets/dota'.
    parser.add_argument('--dataset', type=str, default='dota', help='Name of dataset')
    parser.add_argument('--data_dir', type=str, default='./dataPath', help='Data directory')

    # Phase to run. The default has been switched between 'train', 'test'
    # and 'eval' in the past; it is currently 'eval'.
    parser.add_argument('--phase', type=str, default='eval',
                        help='Phase choice= {train, test, eval}')

    # Originally 8 (per the author's note).
    parser.add_argument('--wh_channels', type=int, default=8,
                        help='Number of channels for the vectors (4x2)')

    return parser.parse_args(argv)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the options, build the model and decoder,
    # run the requested phase, then print how long the phase took.
    args = parse_args()

    # Dataset name (from --dataset) -> dataset class.
    dataset = {'dota': DOTA, 'hrsc': HRSC}
    # Both datasets are configured as single-class here; an earlier
    # configuration used 15 classes for 'dota'.
    num_classes = {'dota': 1, 'hrsc': 1}

    # Output heads passed to the network, as head name -> channel count.
    # NOTE(review): 'wh' is hard-coded to 10 while --wh_channels defaults to 8
    # and is not read in this block -- confirm which value is intended.
    heads = {
        'hm': num_classes[args.dataset],
        'wh': 10,
        'reg': 2,
        'cls_theta': 1,
    }
    down_ratio = 4
    model = ctrbox_net.CTRBOX(heads=heads,
                              pretrained=True,
                              down_ratio=down_ratio,
                              final_kernel=1,
                              head_conv=256)

    # Use a distinct name so the imported `decoder` module is not rebound
    # to the decoder instance.
    box_decoder = decoder.DecDecoder(K=args.K,
                                     conf_thresh=args.conf_thresh,
                                     num_classes=num_classes[args.dataset])

    # The timer is started in every branch. Previously it was only set in the
    # 'test' branch, so the elapsed-time report raised NameError for the
    # 'train' and (default) 'eval' phases.
    if args.phase == 'train':
        ctrbox_obj = train.TrainModule(dataset=dataset,
                                       num_classes=num_classes,
                                       model=model,
                                       decoder=box_decoder,
                                       down_ratio=down_ratio)
        start_time = time.time()
        ctrbox_obj.train_network(args)
    elif args.phase == 'test':
        ctrbox_obj = test.TestModule(dataset=dataset,
                                     num_classes=num_classes,
                                     model=model,
                                     decoder=box_decoder)
        start_time = time.time()
        ctrbox_obj.test(args, down_ratio=down_ratio)
    else:
        # Any phase other than 'train' or 'test' falls through to evaluation.
        ctrbox_obj = eval.EvalModule(dataset=dataset,
                                     num_classes=num_classes,
                                     model=model,
                                     decoder=box_decoder)
        start_time = time.time()
        ctrbox_obj.evaluation(args, down_ratio=down_ratio)

    # Report the wall-clock duration of the phase in milliseconds
    # (the message reads "total program run time: ... ms").
    elapsed_ms = (time.time() - start_time) * 1000
    print('程序总运行时间:%s毫秒' % elapsed_ms)