You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

134 lines
3.3 KiB

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/6/22 19:53
# @Author : old tom
# @File : session.py
# @Project : futool-db-lite
# @Desc :
import threading
from sqlalchemy import Engine
from db.executor.sql_executor import SQLExecutor
from db.transaction.connect_transaction import TransactionFactory
class CreateSessionError(Exception):
    """
    Raised when a SQL session cannot be created or fetched from the cache.

    :param msg: human-readable description of the failure; becomes the
        exception's single positional argument.
    """

    def __init__(self, msg):
        # Cooperative call so the class keeps working if the hierarchy changes.
        super().__init__(msg)
class SqlSessionCache(object):
    """
    Per-thread session cache built on ``threading.local``.

    The cached value lives in a thread-local slot, so each thread only sees
    the session it stored or lazily created itself, and repeated lookups from
    the same thread return the same session until it is removed.
    """

    def __init__(self, create_session_func):
        """
        :param create_session_func: zero-argument callable used to build a
            session when the calling thread has none cached.
        """
        self.create_session_func = create_session_func
        self.cache = threading.local()

    def __call__(self):
        """
        Return the calling thread's session, creating and caching it on first use.

        :return: the session cached for the current thread.
        :raises CreateSessionError: if the lookup or the creation fails; the
            underlying exception is attached as ``__cause__``.
        """
        try:
            if self.has_session():
                return self.cache.value
            session = self.create_session_func()
            self.cache.value = session
            return session
        except Exception as e:
            # Chain explicitly so the root cause is not lost in the traceback.
            raise CreateSessionError(msg=f'从cache获取session异常,e={e}') from e

    def put_session(self, session):
        """
        Store ``session`` as the current thread's cached session,
        replacing any previous one.
        """
        self.cache.value = session

    def has_session(self) -> bool:
        """Return True if the current thread has a cached session."""
        return hasattr(self.cache, "value")

    def remove_session(self):
        """
        Drop the current thread's cached session.

        Does nothing when no session is cached.
        """
        try:
            del self.cache.value
        except AttributeError:
            # Nothing cached for this thread; removal is a no-op.
            pass
class SqlsessionFactory(object):
    """
    Factory that creates ``Sqlsession`` objects on top of an engine.

    Sessions are handed out through a ``SqlSessionCache`` that uses
    ``create_session`` as its creation callback.
    """

    def __init__(self, engine: Engine):
        """
        :param engine: engine whose ``connect()`` supplies the connection for
            each new session.
        """
        self._engine = engine
        self.cache = SqlSessionCache(self.create_session)

    def open_session(self):
        """
        Fetch a session from the cache; when none is cached the cache calls
        ``create_session`` to build one.

        :return: whatever the session cache returns for the caller.
        """
        return self.cache()

    def create_session(self):
        """
        Build a new ``Sqlsession`` on a fresh connection.

        :return: a ``Sqlsession`` wrapping a new executor and this factory's cache.
        :raises CreateSessionError: if any construction step fails; the
            underlying exception is attached as ``__cause__``.
        """
        conn = None
        try:
            conn = self._engine.connect()
            tx_factory = TransactionFactory(conn)
            tx = tx_factory.create_transaction()
            executor = SQLExecutor(tx)
            return Sqlsession(executor, self.cache)
        except Exception as e:
            if conn is not None:
                # A later step failed after the connection was obtained:
                # release it so it is not leaked.
                try:
                    conn.close()
                except Exception:
                    # Best-effort cleanup; the original failure is the one
                    # worth reporting.
                    pass
            raise CreateSessionError(msg=f'创建sqlSession异常,e={e}') from e
class Sqlsession(object):
    """
    Thin facade over a ``SQLExecutor``: query helpers, transaction control
    and connection shutdown.
    """

    def __init__(self, executor: SQLExecutor, session_cache: SqlSessionCache):
        """
        :param executor: executor that runs the SQL and exposes the
            transaction and connection.
        :param session_cache: cache to clear when this session is closed.
        """
        self.executor = executor
        self.session_cache = session_cache

    def select_one(self, sql):
        """Run ``sql`` via ``executor.query`` and return ``fetchone()`` of the result."""
        return self.executor.query(sql).fetchone()

    def select_many(self, sql, size):
        """Run ``sql`` via ``executor.query`` and return ``fetchmany(size)`` of the result."""
        return self.executor.query(sql).fetchmany(size)

    def select_all(self, sql):
        """Run ``sql`` via ``executor.query`` and return ``fetchall()`` of the result."""
        return self.executor.query(sql).fetchall()

    def insert(self, sql):
        """Pass ``sql`` to ``executor.execute_update`` and return its result."""
        return self.executor.execute_update(sql)

    def insert_batch(self, sql):
        """
        Batch insert. TODO: not implemented yet.

        Currently a no-op that returns ``None``; nothing is executed.
        """
        pass

    def update(self, sql):
        """Pass ``sql`` to ``executor.execute_update`` and return its result."""
        return self.executor.execute_update(sql)

    def delete(self, sql):
        """Pass ``sql`` to ``executor.execute_update`` and return its result."""
        return self.executor.execute_update(sql)

    def begin_transaction(self):
        """
        Begin a transaction on the executor's transaction object.

        :return: None
        """
        self.executor.get_transaction().begin_transaction()

    def commit(self):
        """
        Commit the executor's transaction.

        :return: None
        """
        self.executor.get_transaction().commit()

    def rollback(self):
        """
        Roll back the executor's transaction.

        :return: None
        """
        self.executor.get_transaction().rollback()

    def close(self):
        """
        Close the underlying connection and drop this session from the cache.

        The cache entry is removed even when closing the connection raises,
        so a broken session is never handed out again; the error itself
        still propagates.

        :return: None
        """
        try:
            self.executor.get_connection().close()
        finally:
            self.session_cache.remove_session()