84 lines
3.1 KiB
Python
84 lines
3.1 KiB
Python
|
|
import torch
|
|||
|
|
import os
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
# 配置参数
|
|||
|
|
model_dir = Path('/home/th/jcq/AI_AutoPlat/AI_web_dsj/ultralytics') # 使用Path对象更安全
|
|||
|
|
model_path = Path('/home/th/jcq/AI_AutoPlat/yolov5-th/yolov5/yolov5s.pt') # 自定义模型权重路径
|
|||
|
|
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' # 自动选择设备
|
|||
|
|
|
|||
|
|
def load_model_offline():
|
|||
|
|
"""离线环境专用模型加载函数"""
|
|||
|
|
try:
|
|||
|
|
# 1. 验证本地文件是否存在
|
|||
|
|
if not model_dir.exists():
|
|||
|
|
raise FileNotFoundError(f"YOLOv5目录不存在: {model_dir}")
|
|||
|
|
if not model_path.exists():
|
|||
|
|
raise FileNotFoundError(f"模型权重文件不存在: {model_path}")
|
|||
|
|
|
|||
|
|
# 2. 确保本地仓库是完整可用的
|
|||
|
|
required_files = ['models', 'utils', 'hubconf.py']
|
|||
|
|
for f in required_files:
|
|||
|
|
if not (model_dir / f).exists():
|
|||
|
|
raise FileNotFoundError(f"YOLOv5仓库不完整,缺失: {f}")
|
|||
|
|
|
|||
|
|
# 3. 强制使用本地加载(禁用任何网络尝试)
|
|||
|
|
os.environ['GITHUB_ASSETS'] = 'off' # 禁用github资源下载
|
|||
|
|
torch.hub.set_dir(str(model_dir.parent)) # 设置hub缓存目录为本地
|
|||
|
|
|
|||
|
|
# 4. 加载模型(完全离线模式)
|
|||
|
|
model = torch.hub.load(
|
|||
|
|
repo_or_dir=str(model_dir),
|
|||
|
|
model='custom',
|
|||
|
|
path=str(model_path),
|
|||
|
|
source='local',
|
|||
|
|
force_reload=False,
|
|||
|
|
skip_validation=True,
|
|||
|
|
device=device,
|
|||
|
|
_verbose=False # 禁用hub的详细输出
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 5. 验证模型加载成功
|
|||
|
|
if not hasattr(model, 'names'):
|
|||
|
|
raise RuntimeError("模型加载异常:缺少关键属性")
|
|||
|
|
|
|||
|
|
print(f"✅ 离线模型加载成功!设备: {device}")
|
|||
|
|
print(f"模型类别: {model.names}")
|
|||
|
|
return model
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"❌ 加载失败: {type(e).__name__}: {e}")
|
|||
|
|
# 详细错误诊断
|
|||
|
|
if isinstance(e, ModuleNotFoundError):
|
|||
|
|
print("\n⚠️ 可能缺少依赖包,请在联网环境执行:")
|
|||
|
|
print(f"pip install -r {model_dir/'requirements.txt'}")
|
|||
|
|
elif isinstance(e, RuntimeError) and "CUDA" in str(e):
|
|||
|
|
print("\n⚠️ CUDA不可用,正在自动切换到CPU模式...")
|
|||
|
|
return load_model_offline_cpu()
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
def load_model_offline_cpu():
|
|||
|
|
"""强制CPU模式重试"""
|
|||
|
|
try:
|
|||
|
|
model = torch.hub.load(
|
|||
|
|
str(model_dir),
|
|||
|
|
'custom',
|
|||
|
|
path=str(model_path),
|
|||
|
|
source='local',
|
|||
|
|
device='cpu'
|
|||
|
|
)
|
|||
|
|
print("✅ 回退到CPU模式加载成功")
|
|||
|
|
return model
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"❌ CPU模式也加载失败: {e}")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
# 执行加载
|
|||
|
|
model = load_model_offline()
|
|||
|
|
|
|||
|
|
# 使用示例
|
|||
|
|
if model:
|
|||
|
|
img = torch.zeros((1, 3, 640, 640)) # 测试张量
|
|||
|
|
results = model(img)
|
|||
|
|
print(f"推理测试完成!检测结果: {results}")
|