diff --git a/config/db_config.py b/config/db_config.py index 7e6f71d..09208db 100644 --- a/config/db_config.py +++ b/config/db_config.py @@ -7,7 +7,6 @@ # @Desc : import os from pydantic import BaseModel - import toml # 默认配置文件名 diff --git a/main.py b/main.py index d720931..b003b61 100644 --- a/main.py +++ b/main.py @@ -6,13 +6,17 @@ # @Project : futool-db-lite # @Desc : 数据库测试 from session.engine.dbengine import DatabaseEngine +from session.engine.dbengine import container if __name__ == '__main__': db_engine = DatabaseEngine('testdb') + db_engine2 = DatabaseEngine('testdb') + print(db_engine.engine is db_engine2.engine) session_factory = db_engine.create_session_factory() - sql_session = session_factory.open_session() - # print(sql_session.select_all('select * from metadata_object')) - sql_session.select_one('select * from metadata_object') - sql_session.select_many('select * from metadata_object', 2) - # print(sql_session.delete('delete from metadata_object where meta_id=1657707190871527425')) - sql_session.close() + session_factory2 = db_engine2.create_session_factory() + print(session_factory is session_factory2) + # for i in range(100): + # # sess = session_factory.create_session() + # sess = session_factory.open_session() + # print(id(sess)) + print(container.engines, container.factory) diff --git a/session/engine/dbengine.py b/session/engine/dbengine.py index ed942b0..3431ba4 100644 --- a/session/engine/dbengine.py +++ b/session/engine/dbengine.py @@ -9,6 +9,7 @@ from urllib.parse import quote_plus as urlquote from sqlalchemy import create_engine from config.db_config import DbConfigLoader, DEFAULT_CONF_PATH from session.session import SqlsessionFactory +from util.futool_lang import str_md5 class DatabaseEngineError(Exception): @@ -16,6 +17,32 @@ class DatabaseEngineError(Exception): Exception.__init__(self, msg) +class ObjContainer(object): + # DatabaseEngine容器 + engines = {} + # SqlsessionFactory容器 + factory = {} + + def has_obj(self, attr, k): + return k in getattr(self, attr) + + def get_obj(self, attr, k): + return getattr(self, attr)[k] + + def put_obj(self, attr, k, v): + getattr(self, attr)[k] = v + + def remove_obj(self, attr, k): + del getattr(self, attr)[k] + + def clean(self): + self.engines = {} + self.factory = {} + + +container = ObjContainer() + + class DatabaseEngine(object): """ 数据库连接 @@ -28,6 +55,8 @@ class DatabaseEngine(object): 'oracle': 'oracle+cx_oracle://{0}:{1}@{2}:{3}/?service_name={4}' } + attr = 'engines' + def __init__(self, db_name, conf_path=DEFAULT_CONF_PATH): """ pool_size:连接池大小 @@ -42,12 +71,28 @@ class DatabaseEngine(object): raise DatabaseEngineError(msg='不支持的数据库类型') # urlquote 处理密码中的特殊字符 url = self.DB_URL[conf.dialect].format(conf.user, urlquote(conf.passwd), conf.host, conf.port, conf.database) - self.engine = create_engine(url, pool_size=conf.pool_size, pool_recycle=conf.pool_recycle, pool_pre_ping=True, - echo=conf.show_sql) + # URL生成唯一ID + self.engine_id = str_md5(url) + # 根据数据源ID判断是否新建连接池 + if container.has_obj(self.attr, self.engine_id): + self.engine = container.get_obj(self.attr, self.engine_id) + else: + engine = create_engine(url, pool_size=conf.pool_size, pool_recycle=conf.pool_recycle, pool_pre_ping=True, + echo=conf.show_sql) + container.put_obj(self.attr, self.engine_id, engine) + self.engine = engine def create_session_factory(self): """ 创建session工厂 + 相同数据库URL下全局只有一个SqlsessionFactory,除非有其他数据源 + 根据engine_id判断是否新建对象 :return: """ - return SqlsessionFactory(self.engine) + attr = 'factory' + if container.has_obj(attr, self.engine_id): + return container.get_obj(attr, self.engine_id) + else: + factory = SqlsessionFactory(self.engine) + container.put_obj(attr, self.engine_id, factory) + return factory diff --git a/session/session.py b/session/session.py index c5fb72b..9cf2c3e 100644 --- a/session/session.py +++ b/session/session.py @@ -5,6 +5,7 @@ # @File : session.py # @Project : futool-db-lite # @Desc : +import threading from sqlalchemy import Engine from executor.sql_executor import SQLExecutor from transaction.connect_transaction import TransactionFactory @@ -15,6 +16,38 @@ class CreateSessionError(Exception): Exception.__init__(self, msg) +class SqlSessionCache(object): + """ + 利用thread-local 实现线程隔离及单个线程获取到的session一致 + """ + + def __init__(self, create_session_func): + self.create_session_func = create_session_func + self.cache = threading.local() + + def __call__(self): + try: + if self.has_session(): + return self.cache.value + else: + val = self.cache.value = self.create_session_func() + return val + except Exception as e: + raise CreateSessionError(msg=f'从cache获取session异常,e={e}') + + def put_session(self, session): + self.cache.value = session + + def has_session(self) -> bool: + return hasattr(self.cache, "value") + + def remove_session(self): + try: + del self.cache.value + except AttributeError: + pass + + class SqlsessionFactory(object): """ 工厂模式创建sqlSession @@ -22,21 +55,29 @@ class SqlsessionFactory(object): def __init__(self, engine: Engine): self._engine = engine + self.cache = SqlSessionCache(self.create_session) def open_session(self): + """ + 从缓存获取,如果没有会自动调用create_session创建 + """ + return self.cache() + + def create_session(self): try: conn = self._engine.connect() tx_factory = TransactionFactory(conn) tx = tx_factory.create_transaction() executor = SQLExecutor(tx) - return Sqlsession(executor) + return Sqlsession(executor, self.cache) except Exception as e: raise CreateSessionError(msg=f'创建sqlSession异常,e={e}') class Sqlsession(object): - def __init__(self, executor: SQLExecutor): + def __init__(self, executor: SQLExecutor, session_cache: SqlSessionCache): self.executor = executor + self.session_cache = session_cache def select_one(self, sql): return self.executor.query(sql).fetchone() @@ -83,3 +124,4 @@ class Sqlsession(object): :return: """ self.executor.get_connection().close() + self.session_cache.remove_session() diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..de6bc71 --- /dev/null +++ b/util/__init__.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/6/23 17:29 +# @Author : old tom +# @File : __init__.py.py +# @Project : futool-db-0.1 +# @Desc : diff --git a/util/futool_lang.py b/util/futool_lang.py new file mode 100644 index 0000000..f8e481d --- /dev/null +++ b/util/futool_lang.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Time : 2023/6/23 17:30 +# @Author : old tom +# @File : futool_lang.py +# @Project : futool-db-0.1 +# @Desc : +import hashlib + + +def str_md5(content): + """ + 字符串转MD5 + """ + md5 = hashlib.md5() + md5.update(content.encode(encoding='utf-8')) + return md5.hexdigest()