import argparse
import os
import time

import decoder
import eval
import test
import train
from datasets.dataset_dota import DOTA
from datasets.dataset_hrsc import HRSC
from models import ctrbox_net


def parse_args():
    """Parse the command-line options of the BBAVectors entry point.

    Returns:
        argparse.Namespace: the parsed options (training hyper-parameters,
        input size, decoding thresholds, checkpoint paths, dataset selection
        and the phase to run).
    """
    parser = argparse.ArgumentParser(description='BBAVectors Implementation')
    parser.add_argument('--num_epoch', type=int, default=300, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Number of batch size')
    # The original author left an unfinished "originally was ..." note here
    # without recording the previous value.
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers')
    parser.add_argument('--init_lr', type=float, default=1.25e-4, help='Initial learning rate')
    parser.add_argument('--input_h', type=int, default=608, help='Resized image height')
    parser.add_argument('--input_w', type=int, default=608, help='Resized image width')
    parser.add_argument('--K', type=int, default=500, help='Maximum of objects')
    parser.add_argument('--conf_thresh', type=float, default=0.18,
                        help='Confidence threshold, 0.1 for general evaluation')
    parser.add_argument('--ngpus', type=int, default=1, help='Number of gpus, ngpus>1 for multigpu')
    parser.add_argument('--resume_train', type=str, default='', help='Weights resumed in training')
    # Checkpoint used for testing/evaluation. Checkpoints previously used as
    # the default here (pass one via --resume to switch back):
    #   weights_dota/model_last_resnet18_20230409_10Kdataset.pth
    #   weights_dota/model_last_resnet50_20230406_10Kdataset.pth
    #   weights_dota/model_last_resnet34_20230326_10Kdataset.pth
    #   weights_dota/model_last_resnet101_20230315.pth
    #   model_last_resnet18_20230421_10K495dataset.pth
    #   model_last.pth
    parser.add_argument('--resume', type=str,
                        default='weights_dota/model_last_resnet18_20230421_10K495dataset.pth',
                        help='Weights resumed in testing and evaluation')
    parser.add_argument('--dataset', type=str, default='dota', help='Name of dataset')
    # Previous default data directory: '../Datasets/dota'.
    parser.add_argument('--data_dir', type=str, default='./dataPath', help='Data directory')
    # Any value other than 'train' or 'test' falls through to evaluation
    # in the dispatch below.
    parser.add_argument('--phase', type=str, default='test',
                        help='Phase choice= {train, test, eval}')
    # Original author's note: "originally 8".
    parser.add_argument('--wh_channels', type=int, default=8,
                        help='Number of channels for the vectors (4x2)')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    dataset = {'dota': DOTA, 'hrsc': HRSC}
    # A previous configuration used {'dota': 15, 'hrsc': 1}.
    num_classes = {'dota': 1, 'hrsc': 1}
    # NOTE(review): 'wh' is hard-coded to 10 and --wh_channels is not read
    # anywhere in this file -- confirm which value is intended.
    heads = {'hm': num_classes[args.dataset],
             'wh': 10,
             'reg': 2,
             'cls_theta': 1
             }
    down_ratio = 4
    model = ctrbox_net.CTRBOX(heads=heads,
                              pretrained=True,
                              down_ratio=down_ratio,
                              final_kernel=1,
                              head_conv=256)

    # Bound to a distinct name so the imported ``decoder`` module is not
    # rebound to an instance.
    box_decoder = decoder.DecDecoder(K=args.K,
                                     conf_thresh=args.conf_thresh,
                                     num_classes=num_classes[args.dataset])

    # The timer is started in every branch, right before the phase's work.
    # Previously T1 was only assigned in the 'test' branch, so the elapsed-time
    # report below raised NameError after a train or eval run.
    if args.phase == 'train':
        ctrbox_obj = train.TrainModule(dataset=dataset,
                                       num_classes=num_classes,
                                       model=model,
                                       decoder=box_decoder,
                                       down_ratio=down_ratio)
        T1 = time.time()
        ctrbox_obj.train_network(args)
    elif args.phase == 'test':
        ctrbox_obj = test.TestModule(dataset=dataset,
                                     num_classes=num_classes,
                                     model=model,
                                     decoder=box_decoder)
        T1 = time.time()
        ctrbox_obj.test(args, down_ratio=down_ratio)
    else:
        ctrbox_obj = eval.EvalModule(dataset=dataset,
                                     num_classes=num_classes,
                                     model=model,
                                     decoder=box_decoder)
        T1 = time.time()
        ctrbox_obj.evaluation(args, down_ratio=down_ratio)

    T2 = time.time()
    # Message reads "Total program run time: %s milliseconds".
    print('程序总运行时间:%s毫秒' % ((T2 - T1) * 1000))