tuoheng_AIPlatform/AI_web_dsj/test/yolov5_test.py

84 lines
3.1 KiB
Python
Raw Normal View History

2025-07-15 10:01:04 +08:00
import torch
import os
from pathlib import Path
# 配置参数
model_dir = Path('/home/th/jcq/AI_AutoPlat/AI_web_dsj/ultralytics') # 使用Path对象更安全
model_path = Path('/home/th/jcq/AI_AutoPlat/yolov5-th/yolov5/yolov5s.pt') # 自定义模型权重路径
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' # 自动选择设备
def load_model_offline():
"""离线环境专用模型加载函数"""
try:
# 1. 验证本地文件是否存在
if not model_dir.exists():
raise FileNotFoundError(f"YOLOv5目录不存在: {model_dir}")
if not model_path.exists():
raise FileNotFoundError(f"模型权重文件不存在: {model_path}")
# 2. 确保本地仓库是完整可用的
required_files = ['models', 'utils', 'hubconf.py']
for f in required_files:
if not (model_dir / f).exists():
raise FileNotFoundError(f"YOLOv5仓库不完整缺失: {f}")
# 3. 强制使用本地加载(禁用任何网络尝试)
os.environ['GITHUB_ASSETS'] = 'off' # 禁用github资源下载
torch.hub.set_dir(str(model_dir.parent)) # 设置hub缓存目录为本地
# 4. 加载模型(完全离线模式)
model = torch.hub.load(
repo_or_dir=str(model_dir),
model='custom',
path=str(model_path),
source='local',
force_reload=False,
skip_validation=True,
device=device,
_verbose=False # 禁用hub的详细输出
)
# 5. 验证模型加载成功
if not hasattr(model, 'names'):
raise RuntimeError("模型加载异常:缺少关键属性")
print(f"✅ 离线模型加载成功!设备: {device}")
print(f"模型类别: {model.names}")
return model
except Exception as e:
print(f"❌ 加载失败: {type(e).__name__}: {e}")
# 详细错误诊断
if isinstance(e, ModuleNotFoundError):
print("\n⚠️ 可能缺少依赖包,请在联网环境执行:")
print(f"pip install -r {model_dir/'requirements.txt'}")
elif isinstance(e, RuntimeError) and "CUDA" in str(e):
print("\n⚠️ CUDA不可用正在自动切换到CPU模式...")
return load_model_offline_cpu()
return None
def load_model_offline_cpu():
"""强制CPU模式重试"""
try:
model = torch.hub.load(
str(model_dir),
'custom',
path=str(model_path),
source='local',
device='cpu'
)
print("✅ 回退到CPU模式加载成功")
return model
except Exception as e:
print(f"❌ CPU模式也加载失败: {e}")
return None
if __name__ == "__main__":
# 执行加载
model = load_model_offline()
# 使用示例
if model:
img = torch.zeros((1, 3, 640, 640)) # 测试张量
results = model(img)
print(f"推理测试完成!检测结果: {results}")