#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/6/22 19:53
# @Author : old tom
# @File : session.py
# @Project : futool-db-lite
# @Desc :
import threading

from sqlalchemy import Engine

from executor.sql_executor import SQLExecutor
from transaction.connect_transaction import TransactionFactory


class CreateSessionError(Exception):
    """Raised when a session cannot be created or fetched from the cache."""

    def __init__(self, msg):
        super().__init__(msg)


class SqlSessionCache(object):
    """
    Per-thread session cache backed by ``threading.local``.

    Each thread stores its session in its own ``value`` slot, so repeated
    calls from the same thread return the same session object until it is
    removed.
    """

    def __init__(self, create_session_func):
        """
        :param create_session_func: zero-argument callable invoked to build a
            session when the calling thread has none cached.
        """
        self.create_session_func = create_session_func
        self.cache = threading.local()

    def __call__(self):
        """
        Return the calling thread's session, creating and caching one first
        if the thread has none.

        :raises CreateSessionError: if looking up or creating the session
            fails; the original exception is attached as ``__cause__``.
        """
        try:
            if self.has_session():
                return self.cache.value
            session = self.create_session_func()
            self.cache.value = session
            return session
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise CreateSessionError(msg=f'从cache获取session异常,e={e}') from e

    def put_session(self, session) -> None:
        """Store ``session`` as the calling thread's cached session."""
        self.cache.value = session

    def has_session(self) -> bool:
        """Return True if the calling thread already has a cached session."""
        return hasattr(self.cache, "value")

    def remove_session(self) -> None:
        """Drop the calling thread's cached session; no-op if none is cached."""
        try:
            del self.cache.value
        except AttributeError:
            # Nothing cached for this thread.
            pass


class SqlsessionFactory(object):
    """
    Factory that builds ``Sqlsession`` objects on top of an engine.
    """

    def __init__(self, engine: Engine):
        """
        :param engine: engine used to open connections for new sessions.
        """
        self._engine = engine
        self.cache = SqlSessionCache(self.create_session)

    def open_session(self):
        """
        Fetch the session from the cache; if the calling thread has none,
        ``create_session`` is invoked automatically to build one.

        :raises CreateSessionError: if the session cannot be obtained.
        """
        return self.cache()

    def create_session(self) -> 'Sqlsession':
        """
        Open a connection, wrap it in a transaction and an executor, and
        return a new ``Sqlsession`` bound to this factory's cache.

        :raises CreateSessionError: if any step fails; the original exception
            is attached as ``__cause__``.
        """
        conn = None
        try:
            conn = self._engine.connect()
            tx = TransactionFactory(conn).create_transaction()
            executor = SQLExecutor(tx)
            return Sqlsession(executor, self.cache)
        except Exception as e:
            # Do not leak the connection when a later step fails.
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    # Best-effort cleanup; the original error is what matters.
                    pass
            raise CreateSessionError(msg=f'创建sqlSession异常,e={e}') from e


class Sqlsession(object):
    """
    Thin facade over an ``SQLExecutor`` offering query, update and
    transaction-control helpers.
    """

    def __init__(self, executor: SQLExecutor, session_cache: SqlSessionCache):
        """
        :param executor: executor that runs the SQL statements.
        :param session_cache: cache this session is removed from on ``close``.
        """
        self.executor = executor
        self.session_cache = session_cache

    def select_one(self, sql):
        """Run ``sql`` as a query and return the result's ``fetchone()``."""
        return self.executor.query(sql).fetchone()

    def select_many(self, sql, size):
        """Run ``sql`` as a query and return the result's ``fetchmany(size)``."""
        return self.executor.query(sql).fetchmany(size)

    def select_all(self, sql):
        """Run ``sql`` as a query and return the result's ``fetchall()``."""
        return self.executor.query(sql).fetchall()

    def insert(self, sql):
        """Run ``sql`` through ``execute_update`` and return its result."""
        return self.executor.execute_update(sql)

    def insert_batch(self, sql):
        """
        TODO: implement. Currently does nothing and returns None.
        """
        pass

    def update(self, sql):
        """Run ``sql`` through ``execute_update`` and return its result."""
        return self.executor.execute_update(sql)

    def delete(self, sql):
        """Run ``sql`` through ``execute_update`` and return its result."""
        return self.executor.execute_update(sql)

    def begin_transaction(self) -> None:
        """
        Begin a transaction on the executor's transaction object.
        :return:
        """
        self.executor.get_transaction().begin_transaction()

    def commit(self) -> None:
        """
        Commit the executor's transaction.
        :return:
        """
        self.executor.get_transaction().commit()

    def rollback(self) -> None:
        """
        Roll back the executor's transaction.
        :return:
        """
        self.executor.get_transaction().rollback()

    def close(self) -> None:
        """
        Close the underlying connection and remove this thread's session
        from the cache.

        The cache entry is removed even if closing the connection raises, so
        a later ``open_session`` does not hand back a dead session. Any error
        from closing still propagates.
        :return:
        """
        try:
            self.executor.get_connection().close()
        finally:
            self.session_cache.remove_session()