# Ship_Tilt_Detection/main_for_val.py
#
# (File-viewer metadata: 83 lines, 4.5 KiB, Python.)

import argparse
import train
import test
import eval
from datasets.dataset_dota import DOTA
from datasets.dataset_hrsc import HRSC
from models import ctrbox_net
import decoder
import os
import time
def parse_args(argv=None):
    """Parse command-line options for BBAVectors training, testing and evaluation.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: raised by argparse for invalid arguments, including a
            ``--phase`` outside {train, test, eval} or an unknown ``--dataset``.
    """
    parser = argparse.ArgumentParser(description='BBAVectors Implementation')
    parser.add_argument('--num_epoch', type=int, default=300, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=8, help='Number of batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers')
    parser.add_argument('--init_lr', type=float, default=1.25e-4, help='Initial learning rate')
    parser.add_argument('--input_h', type=int, default=608, help='Resized image height')
    parser.add_argument('--input_w', type=int, default=608, help='Resized image width')
    parser.add_argument('--K', type=int, default=500, help='Maximum of objects')
    parser.add_argument('--conf_thresh', type=float, default=0.18, help='Confidence threshold, 0.1 for general evaluation')
    parser.add_argument('--ngpus', type=int, default=1, help='Number of gpus, ngpus>1 for multigpu')
    parser.add_argument('--resume_train', type=str, default='', help='Weights resumed in training')
    # Checkpoint loaded for the test/eval phases. Other checkpoints that were
    # used as the default at some point:
    #   weights_dota/model_last_resnet50_20230406_10Kdataset.pth
    #   weights_dota/model_last_resnet34_20230326_10Kdataset.pth
    #   weights_dota/model_last_resnet101_20230315.pth
    #   model_resnet101_20230315.pth
    #   model_last.pth
    parser.add_argument('--resume', type=str, default='model_last_resnet18_20230421_10K495dataset.pth', help='Weights resumed in testing and evaluation')
    # Restricted to the dataset names this script maps to classes, so a typo
    # is reported up front instead of as a later KeyError.
    parser.add_argument('--dataset', type=str, default='dota', choices=['dota', 'hrsc'], help='Name of dataset')
    # Alternative data directory used previously: ../Datasets/dota
    parser.add_argument('--data_dir', type=str, default='./dataPath', help='Data directory')
    # Restricted so a mistyped phase is rejected rather than silently running
    # the evaluation path.
    parser.add_argument('--phase', type=str, default='eval', choices=['train', 'test', 'eval'], help='Phase choice= {train, test, eval}')
    # NOTE(review): not read anywhere in this script; presumably consumed by
    # the phase modules through `args` -- confirm.
    parser.add_argument('--wh_channels', type=int, default=8, help='Number of channels for the vectors (4x2)')
    return parser.parse_args(argv)
if __name__ == '__main__':
    args = parse_args()

    # Dataset name -> dataset class, forwarded to the phase modules.
    dataset = {'dota': DOTA, 'hrsc': HRSC}
    # Number of categories per dataset. The commented-out mapping used 15
    # DOTA categories; this project configures DOTA as a single class.
    # num_classes = {'dota': 15, 'hrsc': 1}
    num_classes = {'dota': 1, 'hrsc': 1}

    # Output channel count of each CTRBOX head; 'hm' gets one channel per class.
    # NOTE(review): 'wh' is hard-coded to 10 while --wh_channels (default 8)
    # is not read here -- confirm which value the model should follow.
    heads = {'hm': num_classes[args.dataset],
             'wh': 10,
             'reg': 2,
             'cls_theta': 1,
             }
    down_ratio = 4
    model = ctrbox_net.CTRBOX(heads=heads,
                              pretrained=True,
                              down_ratio=down_ratio,
                              final_kernel=1,
                              head_conv=256)
    # Use a distinct name so the imported `decoder` module is not shadowed
    # by the decoder instance.
    box_decoder = decoder.DecDecoder(K=args.K,
                                     conf_thresh=args.conf_thresh,
                                     num_classes=num_classes[args.dataset])

    # Build the module for the selected phase and choose the call to run.
    # Any phase other than 'train'/'test' falls through to evaluation.
    if args.phase == 'train':
        ctrbox_obj = train.TrainModule(dataset=dataset,
                                       num_classes=num_classes,
                                       model=model,
                                       decoder=box_decoder,
                                       down_ratio=down_ratio)
        run_phase = ctrbox_obj.train_network
        phase_kwargs = {}
    elif args.phase == 'test':
        ctrbox_obj = test.TestModule(dataset=dataset, num_classes=num_classes,
                                     model=model, decoder=box_decoder)
        run_phase = ctrbox_obj.test
        phase_kwargs = {'down_ratio': down_ratio}
    else:
        ctrbox_obj = eval.EvalModule(dataset=dataset, num_classes=num_classes,
                                     model=model, decoder=box_decoder)
        run_phase = ctrbox_obj.evaluation
        phase_kwargs = {'down_ratio': down_ratio}

    # Start the clock for every phase (not only 'test'), so the elapsed-time
    # report below never hits an undefined start time. Only the phase call is
    # timed; module construction is excluded. perf_counter is monotonic, which
    # makes it the right clock for measuring durations.
    T1 = time.perf_counter()
    run_phase(args, **phase_kwargs)
    T2 = time.perf_counter()
    # Report the elapsed run time in milliseconds.
    print('程序总运行时间:%s毫秒' % ((T2 - T1) * 1000))