You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

155 lines
4.2 KiB

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/6/22 19:53
# @Author : old tom
# @File : session.py
# @Project : futool-db-lite
# @Desc :
import logging
import threading

import sqlalchemy
from sqlalchemy import Engine

from executor.sql_executor import SQLExecutor
from transaction.connect_transaction import TransactionFactory
class CreateSessionError(Exception):
    """Error raised when a SQL session cannot be created or fetched."""

    def __init__(self, msg):
        """
        :param msg: human-readable description of the failure
        """
        # Hand the message to Exception so it is available via args / str().
        super().__init__(msg)
class SqlSessionCache(object):
    """
    Per-thread session cache backed by ``threading.local``.

    Every thread only sees the value it stored itself, so repeated calls from
    one thread return the same session while separate threads never share one.
    """

    def __init__(self, create_session_func):
        """
        :param create_session_func: zero-argument callable invoked to build a
            session when the calling thread has none cached yet
        """
        self.create_session_func = create_session_func
        self.cache = threading.local()

    def __call__(self):
        """
        Return the calling thread's session, creating and caching it on first use.

        :return: the cached session, or the one just created
        :raises CreateSessionError: if the lookup or the factory call fails;
            the original exception is attached as ``__cause__``
        """
        try:
            if not self.has_session():
                self.cache.value = self.create_session_func()
            return self.cache.value
        except Exception as e:
            # Chain explicitly so the traceback reports the root cause as the
            # direct cause rather than as an error raised during handling.
            raise CreateSessionError(msg=f'从cache获取session异常,e={e}') from e

    def put_session(self, session):
        """
        Store ``session`` for the calling thread, replacing any cached one.

        :param session: object to cache
        """
        self.cache.value = session

    def has_session(self) -> bool:
        """
        :return: ``True`` if the calling thread currently has a cached session
        """
        return hasattr(self.cache, "value")

    def remove_session(self):
        """Drop the calling thread's cached session; a no-op when none is cached."""
        try:
            del self.cache.value
        except AttributeError:
            # Nothing was cached for this thread.
            pass
class SqlsessionFactory(object):
    """
    Factory that builds ``Sqlsession`` objects on top of an engine.
    """

    def __init__(self, engine: Engine):
        """
        :param engine: engine used to open a connection for each new session
        """
        self._engine = engine
        # create_session is handed to the cache as its factory callback, so the
        # cache decides when a new session actually has to be built.
        self.cache = SqlSessionCache(self.create_session)

    def open_session(self):
        """
        Fetch the session from the cache; ``create_session`` is invoked
        automatically when the cache holds none.

        :return: the session returned by the cache
        """
        return self.cache()

    def create_session(self):
        """
        Open a connection and wrap it into a new ``Sqlsession``.

        The connection is passed to a ``TransactionFactory``; the transaction it
        creates backs the ``SQLExecutor`` used by the session.

        :return: the newly built ``Sqlsession``
        :raises CreateSessionError: if any construction step fails; the
            original exception is attached as ``__cause__``
        """
        conn = None
        try:
            conn = self._engine.connect()
            tx = TransactionFactory(conn).create_transaction()
            return Sqlsession(SQLExecutor(tx), self.cache)
        except Exception as e:
            if conn is not None:
                # A later step failed after the connection was opened: release
                # it here, since no session exists that could close it later.
                try:
                    conn.close()
                except Exception:
                    # Best-effort cleanup; the construction failure below is
                    # the error reported to the caller.
                    logging.getLogger(__name__).warning(
                        'failed to close connection after session creation error',
                        exc_info=True,
                    )
            raise CreateSessionError(msg=f'创建sqlSession异常,e={e}') from e
class Sqlsession(object):
    """
    Facade over a ``SQLExecutor``: query helpers, write helpers that run inside
    a transaction, and transaction / connection lifecycle control.
    """

    def __init__(self, executor: SQLExecutor, session_cache: SqlSessionCache):
        """
        :param executor: executor every statement is delegated to
        :param session_cache: cache this session is removed from on ``close``
        """
        self.executor = executor
        self.session_cache = session_cache

    def select_one(self, sql):
        """Run ``sql`` through the executor and return ``fetchone()`` of the result."""
        return self.executor.query(sql).fetchone()

    def select_many(self, sql, size):
        """Run ``sql`` through the executor and return ``fetchmany(size)`` of the result."""
        return self.executor.query(sql).fetchmany(size)

    def select_all(self, sql):
        """Run ``sql`` through the executor and return ``fetchall()`` of the result."""
        return self.executor.query(sql).fetchall()

    def insert(self, sql):
        """Execute an insert statement inside a transaction (see ``_execute_with_tx``)."""
        return self._execute_with_tx(sql)

    def insert_batch(self, sql_template, data, batch_size=1000):
        """
        Batch insert.

        :param sql_template: "INSERT INTO table (id, value) VALUES (:id, :value)"
        :param data: [{"id":1, "value":"v1"}, {"id":2, "value":"v2"}] or [(1,'v1'),(2,'v2')]
        :param batch_size: batch size passed on to ``executor.execute_batch``
        :return: result of ``executor.execute_batch``; ``None`` when ``data``
            is empty or not a list (nothing is executed in that case)
        """
        # Type check first so non-list inputs are rejected without evaluating
        # their truthiness.
        if not isinstance(data, list) or not data:
            return None
        # SQLAlchemy Row objects are converted to plain tuples before execution.
        if isinstance(data[0], sqlalchemy.engine.row.Row):
            data = [tuple(row) for row in data]
        return self.executor.execute_batch(sql_template, data, batch_size)

    def update(self, sql):
        """Execute an update statement inside a transaction (see ``_execute_with_tx``)."""
        return self._execute_with_tx(sql)

    def delete(self, sql):
        """Execute a delete statement inside a transaction (see ``_execute_with_tx``)."""
        return self._execute_with_tx(sql)

    def _execute_with_tx(self, sql):
        """
        Run a write statement inside its own begin / commit pair.

        :param sql: statement passed to ``executor.execute_update``
        :return: result of ``execute_update``; ``None`` if beginning, executing
            or committing failed (the transaction is rolled back)
        """
        try:
            self.begin_transaction()
            rt = self.executor.execute_update(sql)
            self.commit()
            return rt
        except Exception:
            # Record the failure with its traceback BEFORE rolling back, so the
            # root cause is not lost if the rollback itself raises.
            logging.getLogger(__name__).exception(
                'SQL statement failed, rolling back transaction'
            )
            self.rollback()
            # NOTE(review): the failure is not re-raised, callers only see None.
            # Kept for backward compatibility with existing callers.
            return None

    def begin_transaction(self):
        """
        Begin a transaction on the executor's transaction object.
        :return:
        """
        self.executor.get_transaction().begin_transaction()

    def commit(self):
        """
        Commit the executor's transaction.
        :return:
        """
        self.executor.get_transaction().commit()

    def rollback(self):
        """
        Roll back the executor's transaction.
        :return:
        """
        self.executor.get_transaction().rollback()

    def close(self):
        """
        Close the executor's connection and remove this session from the cache.
        :return:
        """
        try:
            self.executor.get_connection().close()
        finally:
            # Always drop the cache entry, even when closing raised, so the
            # cache does not keep handing out a session whose close failed.
            self.session_cache.remove_session()