auto_trade_sys/backend/database/connection.py
薇薇安 fccc9b2717 a
2026-01-17 10:07:31 +08:00

125 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库连接管理
"""
import pymysql
from contextlib import contextmanager
import os
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
# 尝试加载.env文件
try:
from dotenv import load_dotenv
# 从backend目录或项目根目录查找.env文件
backend_dir = Path(__file__).parent.parent
project_root = backend_dir.parent
# 按优先级查找.env文件
env_files = [
backend_dir / '.env', # backend/.env (优先)
project_root / '.env', # 项目根目录/.env
]
loaded = False
for env_file in env_files:
if env_file.exists():
load_dotenv(env_file, override=True)
logger.info(f"已加载环境变量文件: {env_file}")
loaded = True
break
if not loaded:
# 如果都不存在,尝试自动查找(不报错)
load_dotenv(project_root / '.env', override=False)
except ImportError:
# 如果没有安装python-dotenv跳过
logger.debug("python-dotenv未安装跳过.env文件加载")
except Exception as e:
logger.warning(f"加载.env文件失败: {e}")
class Database:
"""数据库连接类"""
def __init__(self):
self.host = os.getenv('DB_HOST', 'localhost')
self.port = int(os.getenv('DB_PORT', 3306))
self.user = os.getenv('DB_USER', 'root')
self.password = os.getenv('DB_PASSWORD', '')
self.database = os.getenv('DB_NAME', 'auto_trade_sys')
# 记录配置信息(不显示密码)
logger.debug(f"数据库配置: host={self.host}, port={self.port}, user={self.user}, database={self.database}")
@contextmanager
def get_connection(self):
"""获取数据库连接(上下文管理器)"""
conn = None
try:
conn = pymysql.connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
database=self.database,
charset='utf8mb4',
cursorclass=pymysql.cursors.DictCursor,
autocommit=False
)
# 设置时区为北京时间UTC+8
with conn.cursor() as cursor:
cursor.execute("SET time_zone = '+08:00'")
conn.commit()
yield conn
except Exception as e:
if conn:
conn.rollback()
logger.error(f"数据库连接错误: {e}")
raise
finally:
if conn:
conn.close()
def execute_query(self, query, params=None):
"""执行查询,返回所有结果"""
with self.get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(query, params)
conn.commit()
return cursor.fetchall()
def execute_one(self, query, params=None):
"""执行查询,返回单条结果"""
with self.get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(query, params)
conn.commit()
return cursor.fetchone()
def execute_update(self, query, params=None):
"""执行更新,返回影响行数"""
try:
with self.get_connection() as conn:
with conn.cursor() as cursor:
affected = cursor.execute(query, params)
conn.commit()
return affected
except Exception as e:
# 重新抛出异常,让调用者处理(如 update_exit 中的异常处理)
# 不要在这里记录为"数据库连接错误",因为这可能是业务逻辑错误(如唯一约束冲突)
raise
def execute_many(self, query, params_list):
"""批量执行"""
with self.get_connection() as conn:
with conn.cursor() as cursor:
affected = cursor.executemany(query, params_list)
conn.commit()
return affected
# 全局数据库实例
db = Database()