120 lines
3.8 KiB
Python
120 lines
3.8 KiB
Python
"""
|
||
数据库连接管理
|
||
"""
|
||
import pymysql
|
||
from contextlib import contextmanager
|
||
import os
|
||
import logging
|
||
from pathlib import Path
|
||
|
||
logger = logging.getLogger(__name__)


# Optionally load environment variables from a .env file. python-dotenv is an
# optional dependency: if it is missing, .env loading is skipped entirely.
try:
    from dotenv import find_dotenv, load_dotenv

    # backend_dir is two directory levels above this file; project_root is its
    # parent. (Assumes this module lives in backend/<package>/ — TODO confirm.)
    backend_dir = Path(__file__).parent.parent
    project_root = backend_dir.parent

    # Candidate .env files in priority order; the first one that exists wins.
    env_files = [
        backend_dir / '.env',   # backend/.env (preferred)
        project_root / '.env',  # <project root>/.env
    ]

    loaded = False
    for env_file in env_files:
        if env_file.exists():
            # override=True: values in this file replace variables that are
            # already present in the process environment.
            load_dotenv(env_file, override=True)
            logger.info(f"已加载环境变量文件: {env_file}")
            loaded = True
            break

    if not loaded:
        # Neither candidate exists, so fall back to python-dotenv's own upward
        # search starting at the current working directory. (Loading
        # project_root/.env again here would be a no-op: it was just checked
        # above and does not exist.) find_dotenv returns "" when nothing is
        # found, in which case loading is skipped without error.
        fallback_env = find_dotenv(usecwd=True)
        if fallback_env:
            # override=False: never clobber variables already set in the
            # real environment with values from a file found by searching.
            load_dotenv(fallback_env, override=False)
            logger.info(f"已加载环境变量文件: {fallback_env}")
except ImportError:
    # python-dotenv is not installed; skip .env loading.
    logger.debug("python-dotenv未安装,跳过.env文件加载")
except Exception as e:
    # Any other failure while locating/reading .env is non-fatal: log it and
    # continue with whatever is already in the environment.
    logger.warning(f"加载.env文件失败: {e}")
|
||
|
||
|
||
class Database:
    """MySQL access helper built on PyMySQL.

    Connection settings are read from environment variables when an
    instance is created:

    - ``DB_HOST``     (default ``'localhost'``)
    - ``DB_PORT``     (default ``3306``; an empty value also falls back to 3306)
    - ``DB_USER``     (default ``'root'``)
    - ``DB_PASSWORD`` (default ``''``)
    - ``DB_NAME``     (default ``'auto_trade_sys'``)

    Each ``execute_*`` helper opens a fresh connection via
    :meth:`get_connection`, runs a single statement, commits, and lets the
    context manager close the connection.
    """

    def __init__(self):
        self.host = os.getenv('DB_HOST', 'localhost')
        # An empty ``DB_PORT=`` line (common in .env files) must fall back to
        # the default instead of crashing on int('').
        port_raw = (os.getenv('DB_PORT') or '').strip()
        self.port = int(port_raw) if port_raw else 3306
        self.user = os.getenv('DB_USER', 'root')
        self.password = os.getenv('DB_PASSWORD', '')
        self.database = os.getenv('DB_NAME', 'auto_trade_sys')

        # Log the configuration, deliberately omitting the password.
        logger.debug(f"数据库配置: host={self.host}, port={self.port}, user={self.user}, database={self.database}")

    @contextmanager
    def get_connection(self):
        """Yield a new PyMySQL connection and always close it afterwards.

        The connection uses the ``utf8mb4`` charset, ``DictCursor`` (rows are
        returned as dicts) and ``autocommit=False``, so callers must commit
        their own writes. The session time zone is set to UTC+8 before the
        connection is handed out.

        If an exception is raised while connecting or inside the caller's
        ``with`` body, any open transaction is rolled back, the error is
        logged, and the original exception is re-raised.

        Yields:
            An open ``pymysql`` connection.
        """
        conn = None
        try:
            conn = pymysql.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database,
                charset='utf8mb4',
                cursorclass=pymysql.cursors.DictCursor,
                autocommit=False,
            )
            # Set the session time zone to Beijing time (UTC+8).
            with conn.cursor() as cursor:
                cursor.execute("SET time_zone = '+08:00'")
            conn.commit()
            yield conn
        except Exception as e:
            if conn is not None:
                # rollback() itself raises if the connection is already broken
                # (e.g. the socket was dropped). Guard it so that secondary
                # error neither replaces the original exception nor skips the
                # error log below.
                try:
                    conn.rollback()
                except Exception as rollback_err:
                    logger.warning(f"数据库回滚失败: {rollback_err}")
            # NOTE: this also fires for errors raised by the caller's query,
            # not only for connection failures.
            logger.error(f"数据库连接错误: {e}")
            raise
        finally:
            # close() raises "Already closed" if the caller closed the
            # connection; only close a connection whose socket is still open.
            if conn is not None and conn.open:
                conn.close()

    def execute_query(self, query, params=None):
        """Run a query and return all result rows.

        Args:
            query: SQL text with PyMySQL placeholders.
            params: Values bound to the placeholders, or ``None``.

        Returns:
            The rows from ``cursor.fetchall()`` (dicts, via ``DictCursor``).
        """
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(query, params)
                # Commit to end the transaction; rows are already buffered.
                conn.commit()
                return cursor.fetchall()

    def execute_one(self, query, params=None):
        """Run a query and return only the first result row.

        Args:
            query: SQL text with PyMySQL placeholders.
            params: Values bound to the placeholders, or ``None``.

        Returns:
            The row from ``cursor.fetchone()`` (a dict), or ``None`` if the
            query produced no rows.
        """
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                cursor.execute(query, params)
                conn.commit()
                return cursor.fetchone()

    def execute_update(self, query, params=None):
        """Run a data-modifying statement and commit it.

        Args:
            query: SQL text with PyMySQL placeholders.
            params: Values bound to the placeholders, or ``None``.

        Returns:
            The affected-row count returned by ``cursor.execute``.
        """
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                affected = cursor.execute(query, params)
                conn.commit()
                return affected

    def execute_many(self, query, params_list):
        """Run one statement for every parameter set in ``params_list`` and commit.

        Args:
            query: SQL text with PyMySQL placeholders.
            params_list: Sequence of parameter sets, one per execution.

        Returns:
            Whatever ``cursor.executemany`` reports (the affected-row count;
            presumably ``None`` for an empty ``params_list`` — verify against
            the installed PyMySQL version).
        """
        with self.get_connection() as conn:
            with conn.cursor() as cursor:
                affected = cursor.executemany(query, params_list)
                conn.commit()
                return affected
|
||
|
||
|
||
# Shared, module-wide Database instance. It is built once at import time, so
# the DB_* environment variables are read a single time, when this module
# is first imported.
db: Database = Database()
|