tuoheng_AIPlatform/AI_web_dsj/functions/traindata.py

62 lines
2.0 KiB
Python
Executable File

import os
import random
import yaml
from shutil import copyfile
def split_dataset(data_dir, train_ratio=0.8, seed=None):
    """Randomly split the labelled images in ``data_dir`` into train/val lists.

    An image is kept only if its name ends with ``.jpg`` or ``.png`` and a
    label file with the same stem and a ``.txt`` extension exists in
    ``data_dir``.

    Args:
        data_dir: Directory holding the images and their ``.txt`` labels.
        train_ratio: Fraction of the valid images assigned to the training
            list; the resulting count is truncated with ``int()``.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            module-level ``random`` generator is used.

    Returns:
        A ``(train_files, val_files)`` tuple of image file names
        (names only, not joined with ``data_dir``).
    """
    valid_files = []
    # Sorted so that a seeded shuffle does not depend on os.listdir() order.
    for img_file in sorted(os.listdir(data_dir)):
        if not img_file.endswith(('.jpg', '.png')):
            continue
        # Swap only the trailing extension: str.replace() would also rewrite
        # a ".jpg"/".png" that appears earlier in the file name.
        txt = os.path.splitext(img_file)[0] + '.txt'
        if os.path.exists(os.path.join(data_dir, txt)):
            valid_files.append(img_file)

    # Shuffle before slicing so the split is random.
    if seed is None:
        random.shuffle(valid_files)
    else:
        random.Random(seed).shuffle(valid_files)

    # First part goes to training, the remainder to validation.
    split_idx = int(len(valid_files) * train_ratio)
    train_files = valid_files[:split_idx]
    val_files = valid_files[split_idx:]
    return train_files, val_files
def write_file_list(file_list, output_file, base_dir):
    """Write one path per line to ``output_file``.

    Each entry of ``file_list`` is joined onto ``base_dir`` with
    ``os.path.join`` and terminated with a newline. ``output_file`` is
    opened in ``'w'`` mode, so any previous content is replaced.
    """
    with open(output_file, 'w') as out:
        out.writelines(
            os.path.join(base_dir, name) + '\n' for name in file_list
        )
def copy_files_to_folders(files, src_dir, dest_dir):
    """Copy each image and its ``.txt`` label from ``src_dir`` to ``dest_dir``.

    Args:
        files: Image file names (relative to ``src_dir``).
        src_dir: Directory containing the images and label files.
        dest_dir: Target directory; created if it does not exist.

    Raises:
        OSError: Propagated from ``shutil.copyfile`` when an image or its
            label cannot be copied (e.g. ``FileNotFoundError`` if missing).
    """
    os.makedirs(dest_dir, exist_ok=True)
    for file in files:
        # Copy the image itself.
        copyfile(os.path.join(src_dir, file), os.path.join(dest_dir, file))

        # Derive the label name by swapping only the trailing extension;
        # chained str.replace() would also rewrite ".jpg"/".png" occurring
        # inside the stem and point at a non-existent label file.
        label = os.path.splitext(file)[0] + '.txt'
        copyfile(os.path.join(src_dir, label), os.path.join(dest_dir, label))
def update_yaml(yaml_file, train_txt, val_txt, nc, names):
    """Update the dataset settings stored in ``yaml_file`` in place.

    The file is loaded with ``yaml.safe_load``, the ``train``, ``val``,
    ``nc`` and ``names`` keys are set to the given values, and the whole
    mapping is written back with ``yaml.dump``. Other keys already present
    in the file are kept.

    Args:
        yaml_file: Path of the YAML file to read and overwrite.
        train_txt: Value stored under ``train``.
        val_txt: Value stored under ``val``.
        nc: Value stored under ``nc``.
        names: Value stored under ``names``.
    """
    with open(yaml_file, 'r') as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty document; start from an empty
    # mapping instead of failing on the item assignments below.
    if data is None:
        data = {}

    data['train'] = train_txt
    data['val'] = val_txt
    data['nc'] = nc
    data['names'] = names

    with open(yaml_file, 'w') as f:
        yaml.dump(data, f, default_flow_style=False)