You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
4.2 KiB
155 lines
4.2 KiB
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
# @Time : 2023/6/22 19:53
|
|
# @Author : old tom
|
|
# @File : session.py
|
|
# @Project : futool-db-lite
|
|
# @Desc : SQL session factory, per-thread session cache and CRUD session wrapper
|
|
import threading
|
|
|
|
import sqlalchemy
|
|
from sqlalchemy import Engine
|
|
from executor.sql_executor import SQLExecutor
|
|
from transaction.connect_transaction import TransactionFactory
|
|
|
|
|
|
class CreateSessionError(Exception):
    """Raised when a SQL session cannot be created or fetched from the cache."""

    def __init__(self, msg):
        # Delegate to Exception so ``args == (msg,)`` and ``str(err) == msg``.
        super().__init__(msg)
|
|
|
|
|
|
class SqlSessionCache(object):
    """
    Per-thread session cache backed by ``threading.local``.

    Each thread has its own slot, so repeated calls from one thread return the
    same cached session until it is removed, while other threads build their
    own on first use.
    """

    def __init__(self, create_session_func):
        """
        :param create_session_func: zero-argument callable invoked on a cache
            miss to build a new session for the calling thread.
        """
        self.create_session_func = create_session_func
        # The attribute ``value`` on this object holds the current thread's session.
        self.cache = threading.local()

    def __call__(self):
        """
        Return the calling thread's session, creating and caching it on a miss.

        :raises CreateSessionError: if ``create_session_func`` fails. A
            CreateSessionError raised by the factory propagates unchanged; any
            other exception is wrapped, with the original kept as ``__cause__``.
        """
        if self.has_session():
            return self.cache.value
        try:
            session = self.create_session_func()
        except CreateSessionError:
            # Already carries a meaningful message; don't wrap it a second time.
            raise
        except Exception as e:
            raise CreateSessionError(msg=f'从cache获取session异常,e={e}') from e
        # Cache only on success, so a failed creation is retried on the next call.
        self.cache.value = session
        return session

    def put_session(self, session):
        """Store ``session`` as the calling thread's cached session (overwrites)."""
        self.cache.value = session

    def has_session(self) -> bool:
        """Return True if the calling thread currently has a cached session."""
        return hasattr(self.cache, "value")

    def remove_session(self):
        """Drop the calling thread's cached session; a no-op if there is none."""
        try:
            del self.cache.value
        except AttributeError:
            # Nothing cached for this thread: removal is idempotent.
            pass
|
|
|
|
|
|
class SqlsessionFactory(object):
    """
    Factory that builds Sqlsession objects on top of a SQLAlchemy Engine.

    Sessions are cached per thread through SqlSessionCache, so repeated
    ``open_session()`` calls on one thread return the same session until it is
    removed from the cache (e.g. by ``Sqlsession.close()``).
    """

    def __init__(self, engine: Engine):
        self._engine = engine
        # create_session is the cache's factory: it runs on a per-thread cache miss.
        self.cache = SqlSessionCache(self.create_session)

    def open_session(self):
        """
        Return the calling thread's cached session, creating one through
        ``create_session()`` if the thread has none yet.
        """
        return self.cache()

    def create_session(self):
        """
        Build a new Sqlsession: open a connection, create a transaction on it
        and wrap that in an SQLExecutor.

        This method does not store the session in the cache itself; the cache
        does that when it calls this method on a miss.

        :raises CreateSessionError: if any step fails. A connection that was
            already opened is closed before raising so it is not leaked.
        """
        conn = None
        try:
            conn = self._engine.connect()
            tx = TransactionFactory(conn).create_transaction()
            return Sqlsession(SQLExecutor(tx), self.cache)
        except Exception as e:
            if conn is not None:
                # Best effort: a failing close() must not hide the real error.
                try:
                    conn.close()
                except Exception:
                    pass
            raise CreateSessionError(msg=f'创建sqlSession异常,e={e}') from e
|
|
|
|
|
|
class Sqlsession(object):
    """
    Session wrapper around an SQLExecutor.

    Reads are delegated straight to the executor. Single-statement writes
    (insert/update/delete) run inside their own begin/commit, with a rollback
    on failure.
    """

    def __init__(self, executor: SQLExecutor, session_cache: SqlSessionCache):
        # Executor providing query/update execution plus the transaction and
        # connection used by this session.
        self.executor = executor
        # Per-thread cache holding this session; close() evicts the entry.
        self.session_cache = session_cache

    def select_one(self, sql):
        """Run ``sql`` and return the result's ``fetchone()`` (first row)."""
        return self.executor.query(sql).fetchone()

    def select_many(self, sql, size):
        """Run ``sql`` and return the result's ``fetchmany(size)``."""
        return self.executor.query(sql).fetchmany(size)

    def select_all(self, sql):
        """Run ``sql`` and return the result's ``fetchall()``."""
        return self.executor.query(sql).fetchall()

    def insert(self, sql):
        """Execute an INSERT in its own transaction; returns None on failure."""
        return self._execute_with_tx(sql)

    def insert_batch(self, sql_template, data, batch_size=1000):
        """
        Batch insert.

        :param sql_template: e.g. "INSERT INTO table (id, value) VALUES (:id, :value)"
        :param data: [{"id":1, "value":"v1"}, {"id":2, "value":"v2"}] or
            [(1,'v1'),(2,'v2')]; a list of SQLAlchemy Row objects is converted
            to plain tuples first.
        :param batch_size: batch size forwarded to ``executor.execute_batch``
        :return: whatever ``executor.execute_batch`` returns
        """
        # NOTE(review): no begin/commit here, unlike insert(); presumably
        # execute_batch manages its own transaction — confirm.
        if isinstance(data, list) and data and isinstance(data[0], sqlalchemy.engine.row.Row):
            # Special-case SQLAlchemy Row objects: convert them to tuples.
            data = [tuple(row) for row in data]
        return self.executor.execute_batch(sql_template, data, batch_size)

    def update(self, sql):
        """Execute an UPDATE in its own transaction; returns None on failure."""
        return self._execute_with_tx(sql)

    def delete(self, sql):
        """Execute a DELETE in its own transaction; returns None on failure."""
        return self._execute_with_tx(sql)

    def _execute_with_tx(self, sql):
        """
        Run a write statement between begin_transaction() and commit().

        On any failure the error is logged with its traceback and the
        transaction is rolled back.

        :return: the result of ``executor.execute_update(sql)``, or None if the
            begin, the statement, or the commit failed. Failures are not
            re-raised, so callers must check for None.
        """
        try:
            self.begin_transaction()
            rt = self.executor.execute_update(sql)
            self.commit()
            return rt
        except Exception:
            import logging
            # Log before rolling back so the original failure is recorded even
            # if the rollback itself raises.
            logging.getLogger(__name__).exception(
                'SQL execution failed, rolling back transaction')
            self.rollback()
            # NOTE(review): failures are swallowed and reported as None (as
            # before); consider re-raising so a failed write cannot go unnoticed.
            return None

    def begin_transaction(self):
        """
        Begin a transaction.

        :return: None
        """
        self.executor.get_transaction().begin_transaction()

    def commit(self):
        """
        Commit the current transaction.

        :return: None
        """
        self.executor.get_transaction().commit()

    def rollback(self):
        """
        Roll back the current transaction.

        :return: None
        """
        self.executor.get_transaction().rollback()

    def close(self):
        """
        Close the connection and evict the calling thread's cached session.

        The cache entry is removed even if closing the connection raises, so
        the thread does not keep getting a session whose connection is in an
        unknown state.

        :return: None
        """
        # NOTE(review): remove_session() clears the *calling thread's* cache
        # slot; closing a session from a different thread evicts that thread's
        # entry instead — confirm sessions are always closed on their own thread.
        try:
            self.executor.get_connection().close()
        finally:
            self.session_cache.remove_session()
|