168 lines
8.4 KiB
Python
168 lines
8.4 KiB
Python
# import shutil
|
||
# from collections import OrderedDict # collections模块中的OrderedDict类
|
||
# import glob
|
||
# import os
|
||
# import cv2
|
||
# import numpy as np
|
||
# import torch
|
||
# import sys # 引入某一模块的方法
|
||
# sys.path.append("../") # 为了导入上级目录的d2lzh_pytorch.py,添加一个新路径
|
||
# import time
|
||
#
|
||
#
|
||
# class Saver(object):
|
||
#
|
||
# def __init__(self, args):
|
||
# self.args = args
|
||
# self.directory = os.path.join('runs', args.dataset, args.checkname) # 路径拼接:runs\pascal\deeplab-resnet
|
||
# self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*'))) # 搜索并排序(默认升序)
|
||
# run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0 # split() 通过指定分隔符对字符串进行切片,run_id=?,从0开始加1
|
||
#
|
||
# # self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id))) # runs\pascal\deeplab-resnet\experiment_?
|
||
# self.experiment_dir = args.output_dir # 自行设置的输出目录
|
||
# if not os.path.exists(self.experiment_dir):
|
||
# os.makedirs(self.experiment_dir) # 生成该路径下的目录
|
||
#
|
||
# # def save_checkpoint(self, state, is_best, filename='checkpoint.pth'):
|
||
# # """Saves checkpoint to disk"""
|
||
# # filename = os.path.join(self.experiment_dir, filename) # runs\pascal\deeplab-resnet\experiment_?\checkpoint.pth.tar
|
||
# # torch.save(state, filename) # 生成checkpoint.pth
|
||
# # if is_best:
|
||
# # best_pred = state['best_pred']
|
||
# # epoch = state['epoch']
|
||
# # str_ = ("%15.5g;" * 2) % (epoch, best_pred)
|
||
# # with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'a') as f:
|
||
# # f.write(str_+'\n')
|
||
# # # if self.runs:
|
||
# # # previous_miou = [0.0]
|
||
# # # for run in self.runs:
|
||
# # # run_id = run.split('_')[-1]
|
||
# # # path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt')
|
||
# # # if os.path.exists(path):
|
||
# # # with open(path, 'r') as f:
|
||
# # # miou = float(f.readline())
|
||
# # # previous_miou.append(miou)
|
||
# # # else:
|
||
# # # continue
|
||
# # # max_miou = max(previous_miou)
|
||
# # # if best_pred > max_miou:
|
||
# # # shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth'))#全局最佳模型
|
||
# # # else:
|
||
# # # shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth'))
|
||
#
|
||
# def save_val_result(self, epoch, Acc, Acc_class, mIoU, class_IoU, FWIoU, recall, precision, f1):
|
||
# str_ = ("%15.5g;" * 13) % (epoch, Acc, Acc_class, mIoU, FWIoU, class_IoU[0], class_IoU[1], recall[0], recall[1], precision[0], precision[1], f1[0], f1[1]) # txt保存指标
|
||
# with open(os.path.join(self.experiment_dir, 'val_result.txt'), 'a') as f: # 这句话自带文件关闭功能,所以和那些先open再write再close的方式来说,更加pythontic!
|
||
# f.write(str_ + '\n')
|
||
#
|
||
# # def save_experiment_config(self, num_pictures):
|
||
# # logfile = os.path.join(self.experiment_dir, 'parameters.txt') # runs\pascal\deeplab-resnet\experiment_?\parameters.txt
|
||
# # log_file = open(logfile, 'w')
|
||
# # p = OrderedDict() # 创建实例对象
|
||
# # # 字典能够将信息关联起来,但它们不记录键值对的顺序。OrederedDict实例的行为与字典相同,区别在于记录了添加的键值对的顺序。
|
||
# # p['datset'] = self.args.dataset
|
||
# # p['backbone'] = self.args.backbone
|
||
# # p['out_stride'] = self.args.out_stride
|
||
# # p['lr'] = self.args.lr
|
||
# # p['lr_scheduler'] = self.args.lr_scheduler
|
||
# # p['loss_type'] = self.args.loss_type
|
||
# # p['epoch'] = self.args.epochs
|
||
# # p['base_size'] = self.args.base_size
|
||
# # p['crop_size'] = self.args.crop_size
|
||
# # p['batch_size'] = self.args.batch_size
|
||
# # p['num_pictures'] = num_pictures
|
||
# #
|
||
# # for key, val in p.items():
|
||
# # log_file.write(key + ':' + str(val) + '\n')
|
||
# # log_file.close()
|
||
#
|
||
# # def predict_save_images(self, model, args, epoch, label_info, test_loader, pathName):
|
||
# def predict_save_images(self, model, args, epoch, label_info, test_loader):
|
||
# # print('调用成功了')
|
||
# # if not args.dataset=='potsdam':
|
||
# # csv_path = os.path.join('path', args.dataset, 'class_dict.csv')
|
||
#
|
||
# # else:
|
||
# # csv_path = os.path.join('path/ISPRS', args.dataset, 'class_dict.csv')
|
||
# # label_info = get_label_info(csv_path)
|
||
# # print(test_loader)
|
||
# # Time_model = []
|
||
# # Time_test = []
|
||
# # time00 = time.time()
|
||
# # cnt_list = []
|
||
# # for (sample, name, WH) in test_loader:
|
||
# for (sample, name, WH) in test_loader: #封装在CbySegmentation.__getitem__里,需调整
|
||
# bs = len(name) #name里是图名,WH是高宽
|
||
# # cnt_list.append(bs)
|
||
# # begin1 = time.time()
|
||
# # image = sample[0] #取第一个tensor是原图,第二个是mask
|
||
# image = sample #取第一个tensor是原图,第二个是mask 这里要送进去四维的,batch为1。将mask去掉了
|
||
# # print('sample_shape',sample.shape)
|
||
# # print('image_shape',image.shape)
|
||
# # print('sample',sample)
|
||
# # print('name',name)
|
||
# # print('WH',WH)
|
||
# model.eval()
|
||
# if args.cuda:
|
||
# image = image.cuda()
|
||
# # begin2 = time.time()
|
||
# with torch.no_grad():
|
||
# predict = model(image)
|
||
# # end2 = time.time()
|
||
# # time_model = end2 - begin2
|
||
# # print('batchTime:%.3f ms, each:%.3f , bs:%d ' % (time_model*1000.0, time_model*1000.0/(bs * 1.0), bs))
|
||
#
|
||
# # predict=torch.squeeze(predict)
|
||
#
|
||
# predict = predict.data.cpu().numpy()
|
||
# predict = np.argmax(predict, axis=1)
|
||
#
|
||
# label_values = [label_info[key] for key in label_info]
|
||
# colour_codes = np.array(label_values)
|
||
#
|
||
# predict = colour_codes[predict.astype(int)]
|
||
#
|
||
# # crop_size恢复到原图尺寸
|
||
# for ii in range(bs):
|
||
# # print('line120:',WH)
|
||
# w, h = WH[0][ii], WH[1][ii]
|
||
# # w,h=WH
|
||
# predict_one = cv2.resize(predict[ii], (int(w), int(h)), interpolation=cv2.INTER_NEAREST)
|
||
# # save_path = os.path.join(self.experiment_dir, pathName)
|
||
# save_path = self.experiment_dir
|
||
# if not os.path.exists(save_path):
|
||
# os.makedirs(save_path) # 生成该路径下的目录
|
||
# # save_path = os.path.join(save_path, '%d_' % epoch+name[ii])
|
||
# save_path = os.path.join(save_path, name[ii])
|
||
# # print('save_path',save_path)
|
||
# # print('epoch',epoch)
|
||
# # print('name[ii]',name[ii])
|
||
#
|
||
# cv2.imwrite(save_path, cv2.cvtColor(np.uint8(predict_one), cv2.COLOR_RGB2BGR)) # 保存图片
|
||
# # end1 = time.time()
|
||
# # time_test = end1 - begin1
|
||
#
|
||
# # print('time test: batchTime:%.3f ms, one Time:%.3f, bs:%d '%(time_test*1000.0,time_test*1000.0/(bs*1.0),bs))
|
||
# # Time_model.append(time_model)
|
||
# # Time_test.append(time_test)
|
||
# # time11 = time.time()
|
||
#
|
||
# # Max_model = max(Time_model) # 原始
|
||
# # Min_model = min(Time_model) # 原始
|
||
# # Max_test = max(Time_test) # 原始
|
||
# # Min_test = min(Time_test) # 原始
|
||
# #
|
||
# # cnt_sample = sum(cnt_list)
|
||
# # ave_model = np.mean(Time_model)
|
||
# # ave_test = np.mean(Time_test)
|
||
# # print()
|
||
# # print('each model: ave:%.3f ms bs:%d' % (sum(Time_model)*1000.0/cnt_sample, cnt_list[0]))
|
||
# # print('bacthc inference:max:%.3f ms ,min:%3f ms,ave:%3f ms bs:%d ' % (Max_model*1000.0, Min_model*1000.0, ave_model*1000.0, cnt_list[1])) # 原始
|
||
# # print('All task total time:%.3f s' % (time11-time00))
|
||
# #
|
||
# # print('ave_mo del:max:%.3f ms ,min:%.3f ms,ave:%.3f ms'%(Max_test*1000.0,Min_test*1000.0,ave_test*1000.0)) # 原始
|
||
# return sample,predict_one
|
||
#
|
||
#
|
||
#
|