100 lines
2.8 KiB
Python
100 lines
2.8 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""init db"""
|
|||
|
|
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
|
|||
|
|
def find_project_root():
    """Locate the project root by probing a few likely directories.

    Candidates are tried in this order: the current working directory, the
    directory containing this script, that directory's parent, and its
    grandparent. The first candidate that contains ``backend/th_agenter``
    is selected.

    Returns:
        A ``(project_root, backend_dir)`` tuple of path strings.

    Raises:
        FileNotFoundError: If no candidate contains ``backend/th_agenter``.
    """
    cwd = os.path.abspath(os.getcwd())
    here = os.path.dirname(os.path.abspath(__file__))
    parent = os.path.dirname(here)
    candidates = (
        cwd,                      # current working directory
        here,                     # directory containing this script
        parent,                   # parent of the script directory
        os.path.dirname(parent),  # grandparent of the script directory
    )

    for candidate in candidates:
        backend = os.path.join(candidate, 'backend')
        if not os.path.exists(backend):
            continue
        if os.path.exists(os.path.join(backend, 'th_agenter')):
            return candidate, backend

    raise FileNotFoundError("无法找到项目根目录和backend目录")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Locate the project root and its backend directory.
project_root, backend_dir = find_project_root()

# Remember the launch directory so it can be restored when the script exits.
original_cwd = os.getcwd()

# Put the backend directory at the front of the Python import path.
sys.path.insert(0, backend_dir)

# Switch the working directory to backend so the .env file can be found.
os.chdir(backend_dir)

# NOTE: these imports must stay below the sys.path / chdir setup above;
# th_agenter is resolved through the backend directory inserted into sys.path.
from th_agenter.db.database import get_db, init_db
from th_agenter.services.user import UserService
from th_agenter.utils.schemas import UserCreate
import asyncio
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def create_database_tables():
    """Create all database tables from the SQLAlchemy models via ``init_db``.

    Any exception is caught and reported on stdout instead of propagating.

    Returns:
        ``True`` if ``init_db()`` completed and the success message was
        printed, ``False`` if an exception was raised along the way.
    """
    succeeded = False
    try:
        await init_db()
        print('Database tables created successfully using SQLAlchemy models')
        succeeded = True
    except Exception as e:
        print(f'Error creating database tables: {e}')
    return succeeded
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def create_test_user():
    """Ensure the tables exist, then create the default test account.

    The tables are created first through ``create_database_tables()``. A
    session is then taken from ``get_db()`` and always closed before
    returning.

    Returns:
        The user object returned by ``UserService``: the already existing
        user with the same email, or the newly created one. ``None`` is
        returned when table creation fails or any exception is raised while
        looking up or creating the user.
    """
    # Nothing can be seeded without the schema in place.
    tables_ready = await create_database_tables()
    if not tables_ready:
        print('Failed to create database tables')
        return None

    session = next(get_db())
    try:
        service = UserService(session)

        # Fixed credentials for the test account.
        payload = UserCreate(
            username='test',
            email='test@example.com',
            password='123456',
            full_name='Test User 1',
        )

        # Reuse the account if one is already registered under this email.
        existing_user = service.get_user_by_email(payload.email)
        if existing_user:
            print(f'User already exists: {existing_user.username} ({existing_user.email})')
            return existing_user

        user = service.create_user(payload)
        print(f'Created user: {user.username} ({user.email})')
        return user
    except Exception as e:
        print(f'Error creating user: {e}')
        return None
    finally:
        session.close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
try:
|
|||
|
|
asyncio.run(create_test_user())
|
|||
|
|
finally:
|
|||
|
|
# 恢复原始工作目录
|
|||
|
|
os.chdir(original_cwd)
|