落水人员检测 (Man-Overboard Detection)

import glob
import os
import shutil
import sys
from collections import OrderedDict  # dict subclass that records insertion order

import cv2
import numpy as np
import torch

sys.path.append("../")  # make the parent directory (d2lzh_pytorch.py) importable


class Saver(object):
    """Collects checkpoint, metric and prediction-image saving for one run."""

    def __init__(self, args):
        self.args = args
        # e.g. runs/pascal/deeplab-resnet
        self.directory = os.path.join('runs', args.dataset, args.checkname)
        # previous runs, sorted by their numeric suffix (a plain sorted() would
        # be lexicographic and put experiment_10 before experiment_9)
        self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')),
                           key=lambda r: int(r.split('_')[-1]))
        # next run id: one past the largest existing experiment_* suffix, or 0
        run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0

        # default layout would be runs/<dataset>/<checkname>/experiment_<run_id>:
        # self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(run_id))
        self.experiment_dir = args.output_dir  # user-supplied output directory
        if not os.path.exists(self.experiment_dir):
            os.makedirs(self.experiment_dir)
    def save_checkpoint(self, state, is_best, filename='checkpoint.pth'):
        """Saves a checkpoint to disk; logs epoch and score when it improves."""
        filename = os.path.join(self.experiment_dir, filename)
        torch.save(state, filename)
        if is_best:
            best_pred = state['best_pred']
            epoch = state['epoch']
            str_ = ("%15.5g;" * 2) % (epoch, best_pred)
            with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'a') as f:
                f.write(str_ + '\n')
            # Alternative (disabled): compare against the best mIoU recorded by
            # all previous runs and keep a global model_best.pth:
            # if self.runs:
            #     previous_miou = [0.0]
            #     for run in self.runs:
            #         run_id = run.split('_')[-1]
            #         path = os.path.join(self.directory,
            #                             'experiment_{}'.format(run_id), 'best_pred.txt')
            #         if os.path.exists(path):
            #             with open(path, 'r') as f:
            #                 miou = float(f.readline())
            #             previous_miou.append(miou)
            #     max_miou = max(previous_miou)
            #     if best_pred > max_miou:
            #         shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth'))  # global best model
            # else:
            #     shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth'))
    def save_val_result(self, epoch, Acc, Acc_class, mIoU, class_IoU, FWIoU,
                        recall, precision, f1):
        """Appends one line of (two-class) validation metrics to val_result.txt."""
        str_ = ("%15.5g;" * 13) % (epoch, Acc, Acc_class, mIoU, FWIoU,
                                   class_IoU[0], class_IoU[1],
                                   recall[0], recall[1],
                                   precision[0], precision[1],
                                   f1[0], f1[1])
        # `with` closes the file automatically -- more Pythonic than an
        # explicit open/write/close sequence
        with open(os.path.join(self.experiment_dir, 'val_result.txt'), 'a') as f:
            f.write(str_ + '\n')
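    # A note on the "%15.5g;" format above: each metric is rendered
    # right-aligned in a 15-character field with 5 significant digits, followed
    # by ';', so val_result.txt is a fixed-width, semicolon-separated table
    # (e.g. np.genfromtxt(path, delimiter=';') reads it back, with one empty
    # trailing column produced by the final ';').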
    def save_experiment_config(self, num_pictures):
        """Dumps the run's hyper-parameters to parameters.txt."""
        logfile = os.path.join(self.experiment_dir, 'parameters.txt')
        # Plain dicts (in older Python versions) do not record the order of
        # key/value pairs; OrderedDict behaves like a dict but remembers the
        # order keys were added, so the parameter file is written in a stable order.
        p = OrderedDict()
        p['dataset'] = self.args.dataset
        p['backbone'] = self.args.backbone
        p['out_stride'] = self.args.out_stride
        p['lr'] = self.args.lr
        p['lr_scheduler'] = self.args.lr_scheduler
        p['loss_type'] = self.args.loss_type
        p['epoch'] = self.args.epochs
        p['base_size'] = self.args.base_size
        p['crop_size'] = self.args.crop_size
        p['batch_size'] = self.args.batch_size
        p['num_pictures'] = num_pictures

        with open(logfile, 'w') as log_file:
            for key, val in p.items():
                log_file.write(key + ':' + str(val) + '\n')
    # def predict_save_images(self, model, args, epoch, label_info, test_loader, pathName):
    def predict_save_images(self, model, args, epoch, label_info, test_loader):
        """Runs inference over test_loader, colour-codes the predicted classes
        and saves them as images restored to their original sizes."""
        # label_info was originally loaded here from path/<dataset>/class_dict.csv
        # (path/ISPRS/<dataset>/class_dict.csv for potsdam) via get_label_info();
        # it is now passed in by the caller.
        #
        # Batches come from CbySegmentation.__getitem__ (may need adjusting):
        # `sample` is a 4-D image batch with batch dimension first (the mask has
        # been dropped), `name` holds the image file names, `WH` the original
        # (width, height) pairs.
        for (sample, name, WH) in test_loader:
            bs = len(name)
            image = sample

            model.eval()
            if args.cuda:
                image = image.cuda()
            # (disabled) time.time() was sampled around the forward pass and the
            # whole batch; per-batch durations were accumulated in Time_model /
            # Time_test for the statistics printed after the loop
            with torch.no_grad():
                predict = model(image)

            predict = predict.data.cpu().numpy()
            # per-pixel class index: argmax over the channel dimension
            predict = np.argmax(predict, axis=1)

            # map class indices to the RGB colours listed in label_info
            label_values = [label_info[key] for key in label_info]
            colour_codes = np.array(label_values)
            predict = colour_codes[predict.astype(int)]

            # resize each prediction from crop_size back to the original size
            for ii in range(bs):
                w, h = WH[0][ii], WH[1][ii]
                predict_one = cv2.resize(predict[ii], (int(w), int(h)),
                                         interpolation=cv2.INTER_NEAREST)
                save_path = self.experiment_dir
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                # save_path = os.path.join(save_path, '%d_' % epoch + name[ii])
                save_path = os.path.join(save_path, name[ii])
                # OpenCV writes BGR, so convert from RGB before saving
                cv2.imwrite(save_path,
                            cv2.cvtColor(np.uint8(predict_one), cv2.COLOR_RGB2BGR))

        # (disabled) max/min/mean of Time_model and Time_test and the total
        # task time were printed here
        # returns the last batch and its last colour-coded prediction
        return sample, predict_one
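The colour-mapping step in predict_save_images is worth seeing in isolation: indexing a (num_classes, 3) colour table with an integer class map turns per-pixel class indices into an RGB image in a single NumPy fancy-indexing operation. A minimal sketch, with a made-up two-class label_info (the real project loads it from class_dict.csv):

import numpy as np

# hypothetical class-name -> RGB mapping, just for illustration
label_info = {'background': [0, 0, 0], 'person': [255, 0, 0]}
colour_codes = np.array([label_info[key] for key in label_info])  # shape (2, 3)

# a fake 2x3 map of per-pixel class indices, as np.argmax would produce
class_map = np.array([[0, 1, 1],
                      [1, 0, 0]])

rgb = colour_codes[class_map]  # shape (2, 3, 3): every index became a colour
print(rgb[0, 1])               # -> [255   0   0]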
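For context, a minimal usage sketch, assuming an argparse-style args object with the fields Saver actually reads (dataset, checkname, output_dir, cuda); all values here are hypothetical:

from argparse import Namespace

# hypothetical settings; the real project builds args with argparse
args = Namespace(dataset='pascal', checkname='deeplab-resnet',
                 output_dir='runs/demo', cuda=False)

saver = Saver(args)

# append one row of dummy two-class validation metrics to val_result.txt
saver.save_val_result(epoch=1, Acc=0.95, Acc_class=0.90, mIoU=0.82,
                      class_IoU=[0.97, 0.67], FWIoU=0.93,
                      recall=[0.98, 0.70], precision=[0.99, 0.75],
                      f1=[0.985, 0.72])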