100 lines
3.4 KiB
Python
Executable File
100 lines
3.4 KiB
Python
Executable File
import os
|
|
import random
|
|
import yaml
|
|
from shutil import copyfile
|
|
|
|
def split_dataset(data_dir, train_ratio=0.8, seed=None):
    """Randomly split the labelled images in *data_dir* into train/val lists.

    An image (``*.jpg``) is kept only if a label file with the same stem and
    a ``.txt`` extension exists next to it in *data_dir*.  Only the top level
    of *data_dir* is scanned; subdirectories are ignored.

    Args:
        data_dir: Directory containing the ``*.jpg`` images and ``*.txt`` labels.
        train_ratio: Fraction of the valid images assigned to the training
            set, in ``[0, 1]``.  The split index is ``int(len * train_ratio)``
            (truncated), the remainder goes to validation.
        seed: Optional seed for a private ``random.Random`` so the split is
            reproducible.  When ``None`` (default) the global ``random``
            module state is used, as in the original implementation.

    Returns:
        ``(train_files, val_files)``: two lists of image file names (not paths).

    Raises:
        ValueError: If *train_ratio* is outside ``[0, 1]``.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")

    # Collect all jpg files.  Sorting removes the dependence on the arbitrary
    # os.listdir() order, so a given seed always yields the same split.
    jpg_files = sorted(
        name for name in os.listdir(data_dir)
        if name.endswith('.jpg') and os.path.isfile(os.path.join(data_dir, name))
    )

    # Keep only images whose label file exists.  The label name is derived
    # from the stem: str.replace('.jpg', '.txt') would also rewrite '.jpg'
    # appearing elsewhere in the name (e.g. 'x.jpg.jpg' -> 'x.txt.txt').
    valid_files = [
        jpg for jpg in jpg_files
        if os.path.exists(os.path.join(data_dir, os.path.splitext(jpg)[0] + '.txt'))
    ]

    # Shuffle so train/val are random samples rather than name-ordered.
    if seed is None:
        random.shuffle(valid_files)
    else:
        random.Random(seed).shuffle(valid_files)

    # Split into training and validation sets.
    split_idx = int(len(valid_files) * train_ratio)
    return valid_files[:split_idx], valid_files[split_idx:]
|
|
|
|
def write_file_list(file_list, output_file, base_dir):
    """Write a list file with one ``base_dir/<name>`` path per line.

    Args:
        file_list: Iterable of file names (e.g. as returned by ``split_dataset``).
        output_file: Path of the list file to create; it is overwritten.
        base_dir: Directory joined in front of every file name.
    """
    # Explicit UTF-8: the default text encoding is locale-dependent, and the
    # paths written here may contain non-ASCII characters.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(os.path.join(base_dir, name) + '\n' for name in file_list)
|
|
|
|
def copy_files_to_folders(files, src_dir, dest_dir):
    """Copy each image and its same-stem ``.txt`` label from *src_dir* to *dest_dir*.

    *dest_dir* (and any missing parents) is created first.  Existing files
    with the same names in *dest_dir* are overwritten, per ``shutil.copyfile``.

    Args:
        files: Image file names relative to *src_dir*, e.g. ``'0001.jpg'``.
        src_dir: Directory holding the images and their label files.
        dest_dir: Destination directory.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) if an image or its label is missing.
    """
    os.makedirs(dest_dir, exist_ok=True)

    for name in files:
        # Derive the label name from the stem only; str.replace('.jpg', '.txt')
        # would also rewrite '.jpg' occurring elsewhere in the file name.
        label = os.path.splitext(name)[0] + '.txt'
        # Copy the image, then its label.
        for fname in (name, label):
            copyfile(os.path.join(src_dir, fname), os.path.join(dest_dir, fname))
|
|
|
|
def update_yaml(yaml_file, train_txt, val_txt, nc, names):
    """Overwrite the dataset entries of a data YAML file in place.

    Loads *yaml_file*, sets its ``train``, ``val``, ``nc`` and ``names``
    keys, keeps every other key, and writes the result back to the same file.

    Args:
        yaml_file: Path of an existing YAML file whose top level is a mapping.
        train_txt: Value stored under ``train`` (path of the training list).
        val_txt: Value stored under ``val`` (path of the validation list).
        nc: Value stored under ``nc`` (class count).
        names: Value stored under ``names`` (class names).
    """
    with open(yaml_file, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty file; start from an empty mapping.
        data = yaml.safe_load(f) or {}

    data.update({'train': train_txt, 'val': val_txt, 'nc': nc, 'names': names})

    # Serialize before opening for writing, so a dump error cannot leave the
    # file truncated.  allow_unicode keeps non-ASCII class names readable
    # instead of \u-escaping them; sort_keys=False keeps the existing key order.
    text = yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False)
    with open(yaml_file, 'w', encoding='utf-8') as f:
        f.write(text)
|
|
|
|
# def main():
|
|
# # 配置参数
|
|
# data_dir = "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets/027123"
|
|
# output_dir = "/home/th/jcq/AI_AutoPlat/DATA/TrainDatasets/027123_temp"
|
|
|
|
# # yaml_file = "/home/thsw/WJ/jcq/yolov5-th/data/road.yaml" # 替换为你的yaml文件路径
|
|
# yaml_file = "/home/thsw/WJ/jcq/yolov5-th/data/cl.yaml" # 替换为你的yaml文件路径
|
|
|
|
# # 确保输出目录存在
|
|
# os.makedirs(output_dir, exist_ok=True)
|
|
|
|
# # 创建 train 和 val 文件夹
|
|
# train_dir = os.path.join(data_dir, "train")
|
|
# val_dir = os.path.join(data_dir, "val")
|
|
# os.makedirs(train_dir, exist_ok=True)
|
|
# os.makedirs(val_dir, exist_ok=True)
|
|
|
|
# # 划分数据集
|
|
# train_files, val_files = split_dataset(data_dir)
|
|
|
|
# # 复制文件到 train/ 和 val/ 文件夹
|
|
# copy_files_to_folders(train_files, data_dir, train_dir)
|
|
# copy_files_to_folders(val_files, data_dir, val_dir)
|
|
|
|
# # 写入训练集和验证集txt文件
|
|
# train_txt = os.path.join(data_dir, "train.txt")
|
|
# val_txt = os.path.join(data_dir, "val.txt")
|
|
|
|
# write_file_list(train_files, train_txt, train_dir)
|
|
# write_file_list(val_files, val_txt, val_dir)
|
|
|
|
# # 更新yaml文件
|
|
# nc = 5
|
|
# names = ['flag', 'buoy', 'shipname', 'ship', 'uncover']
|
|
# update_yaml(yaml_file, train_txt, val_txt, nc, names)
|
|
|
|
# print(f"数据集划分完成,训练集: {len(train_files)} 张,验证集: {len(val_files)} 张")
|
|
# print(f"训练集列表已写入: {train_txt}")
|
|
# print(f"验证集列表已写入: {val_txt}")
|
|
# print(f"YAML文件已更新: {yaml_file}")
|
|
|
|
# if __name__ == "__main__":
|
|
# main() |