#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2023/6/22 19:53
# @Author  : old tom
# @File    : session.py
# @Project : futool-db-lite
# @Desc    : SQL session objects, a per-thread session cache and a session factory built on a SQLAlchemy Engine.

import logging
import threading

import sqlalchemy
from sqlalchemy import Engine

from executor.sql_executor import SQLExecutor
from transaction.connect_transaction import TransactionFactory

logger = logging.getLogger(__name__)


class CreateSessionError(Exception):
    """Raised when a SQL session cannot be created or fetched from the cache."""

    def __init__(self, msg):
        super().__init__(msg)


class SqlSessionCache(object):
    """
    Per-thread session cache backed by ``threading.local``.

    Every thread gets its own cached session, and repeated calls from the same
    thread return the same session object until it is removed.
    """

    def __init__(self, create_session_func):
        """
        :param create_session_func: zero-argument callable that builds a new
            session; called lazily the first time a thread asks for one.
        """
        self.create_session_func = create_session_func
        self.cache = threading.local()

    def __call__(self):
        """
        Return the calling thread's session, creating and caching it if absent.

        :raises CreateSessionError: if the session factory fails.
        """
        try:
            return self.cache.value
        except AttributeError:
            pass  # no session cached for this thread yet
        try:
            session = self.create_session_func()
        except CreateSessionError:
            # Already descriptive; re-raise as-is instead of nesting messages.
            raise
        except Exception as e:
            raise CreateSessionError(msg=f'从cache获取session异常,e={e}') from e
        self.cache.value = session
        return session

    def put_session(self, session):
        """Store ``session`` as the calling thread's cached session."""
        self.cache.value = session

    def has_session(self) -> bool:
        """Return True if the calling thread currently has a cached session."""
        return hasattr(self.cache, "value")

    def remove_session(self):
        """Drop the calling thread's cached session; no-op if there is none."""
        try:
            del self.cache.value
        except AttributeError:
            pass


class Sqlsession(object):
    """Thin session facade over an ``SQLExecutor`` for reads, writes and transactions."""

    def __init__(self, executor: SQLExecutor, session_cache: SqlSessionCache):
        """
        :param executor: executor that runs statements on one connection
        :param session_cache: cache this session lives in; ``close`` evicts it
        """
        self.executor = executor
        self.session_cache = session_cache

    def select_one(self, sql):
        """Run ``sql`` and return the first row (``fetchone``)."""
        return self.executor.query(sql).fetchone()

    def select_many(self, sql, size):
        """Run ``sql`` and return up to ``size`` rows (``fetchmany``)."""
        return self.executor.query(sql).fetchmany(size)

    def select_all(self, sql):
        """Run ``sql`` and return all rows (``fetchall``)."""
        return self.executor.query(sql).fetchall()

    def select_by_stream(self, sql, stream_size):
        """
        Streaming query: fetch the result set in chunks.

        :param sql: query to run
        :param stream_size: number of rows returned per chunk
        :return: the result of ``executor.query_by_stream``. Example usage:

            for partition in result.partitions():
                # partition is an iterable of at most stream_size items
                for row in partition:
                    print(f"{row}")
        """
        return self.executor.query_by_stream(sql, stream_size)

    def insert(self, sql):
        """Execute an INSERT inside a transaction; returns None on failure."""
        return self._execute_with_tx(sql)

    def insert_batch(self, sql_template, data, batch_size=1000):
        """
        Batch insert.

        :param sql_template: parameterised statement, e.g.
            "INSERT INTO table (id, value) VALUES (:id, :value)"
        :param data: [{"id": 1, "value": "v1"}, {"id": 2, "value": "v2"}]
            or [(1, 'v1'), (2, 'v2')]; a sequence of SQLAlchemy ``Row``
            objects is also accepted and converted to tuples first.
        :param batch_size: number of rows per batch
        :return: whatever ``executor.execute_batch`` returns
        """
        # NOTE(review): unlike insert/update/delete this does not call
        # begin_transaction/commit here; presumably execute_batch handles its
        # own transaction — confirm.
        if data and isinstance(data, (list, tuple)) and isinstance(data[0], sqlalchemy.engine.row.Row):
            # Special-case SQLAlchemy Row objects: convert them to plain tuples.
            data = [tuple(row) for row in data]
        return self.executor.execute_batch(sql_template, data, batch_size)

    def update(self, sql):
        """Execute an UPDATE inside a transaction; returns None on failure."""
        return self._execute_with_tx(sql)

    def delete(self, sql):
        """Execute a DELETE inside a transaction; returns None on failure."""
        return self._execute_with_tx(sql)

    def _execute_with_tx(self, sql):
        """
        Run a write statement between begin_transaction and commit.

        On any error the failure is logged with its traceback and the
        transaction is rolled back.

        :return: the result of ``executor.execute_update`` on success, or
            ``None`` if any step failed (errors are logged, not raised).
        """
        try:
            self.begin_transaction()
            rt = self.executor.execute_update(sql)
            self.commit()
            return rt
        except Exception:
            # Log before rolling back so the original error is recorded even
            # if the rollback itself raises.
            logger.exception('SQL execution failed, rolling back transaction')
            self.rollback()
            return None

    def begin_transaction(self):
        """Begin a transaction on the executor's transaction object."""
        self.executor.get_transaction().begin_transaction()

    def commit(self):
        """Commit the current transaction."""
        self.executor.get_transaction().commit()

    def rollback(self):
        """Roll back the current transaction."""
        self.executor.get_transaction().rollback()

    def close(self):
        """
        Close the underlying connection and evict this session from the
        calling thread's cache.
        """
        try:
            self.executor.get_connection().close()
        finally:
            # Evict even if close() fails so the thread won't be handed a
            # session bound to a broken connection on the next open_session().
            self.session_cache.remove_session()


class SqlsessionFactory(object):
    """Factory that creates ``Sqlsession`` objects and caches one per thread."""

    def __init__(self, engine: Engine):
        """
        :param engine: SQLAlchemy engine used to open connections
        """
        self._engine = engine
        self.cache = SqlSessionCache(self.create_session)

    def open_session(self) -> Sqlsession:
        """
        Return the calling thread's cached session, creating it via
        ``create_session`` if this thread has none yet.

        :raises CreateSessionError: if a new session cannot be created.
        """
        return self.cache()

    def create_session(self) -> Sqlsession:
        """
        Open a new connection and wrap it in a ``Sqlsession``.

        :raises CreateSessionError: if any setup step fails; a connection that
            was already opened is closed before raising.
        """
        conn = None
        try:
            conn = self._engine.connect()
            tx = TransactionFactory(conn).create_transaction()
            executor = SQLExecutor(tx)
            return Sqlsession(executor, self.cache)
        except Exception as e:
            if conn is not None:
                # Don't leak the connection when a later setup step fails.
                try:
                    conn.close()
                except Exception:
                    logger.exception('Failed to close connection after session creation error')
            raise CreateSessionError(msg=f'创建sqlSession异常,e={e}') from e