""" 数据库连接管理 """ import pymysql from contextlib import contextmanager import os import logging from pathlib import Path logger = logging.getLogger(__name__) # 尝试加载.env文件 try: from dotenv import load_dotenv # 从backend目录或项目根目录查找.env文件 backend_dir = Path(__file__).parent.parent project_root = backend_dir.parent # 按优先级查找.env文件 env_files = [ backend_dir / '.env', # backend/.env (优先) project_root / '.env', # 项目根目录/.env ] loaded = False for env_file in env_files: if env_file.exists(): load_dotenv(env_file, override=True) logger.info(f"已加载环境变量文件: {env_file}") loaded = True break if not loaded: # 如果都不存在,尝试自动查找(不报错) load_dotenv(project_root / '.env', override=False) except ImportError: # 如果没有安装python-dotenv,跳过 logger.debug("python-dotenv未安装,跳过.env文件加载") except Exception as e: logger.warning(f"加载.env文件失败: {e}") class Database: """数据库连接类""" def __init__(self): self.host = os.getenv('DB_HOST', 'localhost') self.port = int(os.getenv('DB_PORT', 3306)) self.user = os.getenv('DB_USER', 'root') self.password = os.getenv('DB_PASSWORD', '') self.database = os.getenv('DB_NAME', 'auto_trade_sys') # 记录配置信息(不显示密码) logger.debug(f"数据库配置: host={self.host}, port={self.port}, user={self.user}, database={self.database}") @contextmanager def get_connection(self): """获取数据库连接(上下文管理器)""" conn = None try: conn = pymysql.connect( host=self.host, port=self.port, user=self.user, password=self.password, database=self.database, charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor, autocommit=False ) # 设置时区为北京时间(UTC+8) with conn.cursor() as cursor: cursor.execute("SET time_zone = '+08:00'") conn.commit() yield conn except Exception as e: if conn: conn.rollback() logger.error(f"数据库连接错误: {e}") raise finally: if conn: conn.close() def execute_query(self, query, params=None): """执行查询,返回所有结果""" with self.get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, params) conn.commit() return cursor.fetchall() def execute_one(self, query, params=None): """执行查询,返回单条结果""" with self.get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, params) conn.commit() return cursor.fetchone() def execute_update(self, query, params=None): """执行更新,返回影响行数""" with self.get_connection() as conn: with conn.cursor() as cursor: affected = cursor.execute(query, params) conn.commit() return affected def execute_many(self, query, params_list): """批量执行""" with self.get_connection() as conn: with conn.cursor() as cursor: affected = cursor.executemany(query, params_list) conn.commit() return affected # 全局数据库实例 db = Database()