You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

221 lines
6.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/4/4 17:29
# @Author : old tom
# @File : fu_db_api.py
# @Project : futool-db
# @Desc :
from sqlalchemy import Connection, CursorResult, Result, text

from common.futool.core.fu_collection import split_coll
class SqlExecuteError(Exception):
    """Raised by this module's SQL helpers when executing a statement fails."""

    def __init__(self, msg=''):
        # Store the message as the single exception argument.
        super().__init__(msg)
def _select(conn: Connection, sql) -> Result:
    """
    Run a query and always close the connection afterwards (自动关闭连接).

    Row-returning results are fully buffered *before* the connection is
    closed. The caller can therefore still call ``fetchone()`` /
    ``fetchall()`` or iterate over the returned object without reading from
    a cursor whose connection has already been released.

    :param conn: connection to use; it is closed before this function returns
    :param sql: raw SQL text, wrapped with ``sqlalchemy.text``
    :return: a buffered ``Result`` for row-returning statements, or the
        original ``CursorResult`` for statements that return no rows
    """
    try:
        result = conn.execute(text(sql))
        if not result.returns_rows:
            # Nothing to buffer (e.g. DDL); keep the original object.
            return result
        # freeze() consumes every row now. Calling the FrozenResult yields a
        # new Result over those in-memory rows, which outlives the connection.
        return result.freeze()()
    finally:
        conn.close()
def _execute_with_tx(conn: Connection, sql):
    """
    Execute one SQL statement inside a transaction (带事务执行SQL).

    The connection is closed before returning, whether the statement
    succeeds or fails.

    :param conn: connection to use; always closed before returning
    :param sql: raw SQL text, wrapped with ``sqlalchemy.text``
    :return: the statement's ``rowcount``
    :raises SqlExecuteError: if executing or committing fails; the
        transaction is rolled back and the original exception is chained
        as ``__cause__``
    """
    try:
        conn.begin()
        rt = conn.execute(text(sql))
        conn.commit()
        return rt.rowcount
    except Exception as e:
        conn.rollback()
        # Chain the driver/SQLAlchemy error so its traceback is not lost.
        raise SqlExecuteError(msg=f'sql [{sql}] 执行失败,开始回滚,e={e}') from e
    finally:
        conn.close()
def execute_many_sql_with_tx(conn: Connection, sql_list):
    """
    Execute several SQL statements in one transaction.

    Either all statements are committed together or, on the first failure,
    the whole transaction is rolled back. The connection is always closed
    before returning.

    :param conn: connection to use; always closed before returning
    :param sql_list: iterable of raw SQL strings, executed in order
    :raises SqlExecuteError: if any statement (or the commit) fails; the
        original exception is chained as ``__cause__``
    """
    try:
        conn.begin()
        for sql in sql_list:
            conn.execute(text(sql))
        conn.commit()
    except Exception as e:
        conn.rollback()
        # Chain the underlying error so its traceback is preserved.
        raise SqlExecuteError(msg=f'sql [{sql_list}] 执行失败,开始回滚,e={e}') from e
    finally:
        conn.close()
def select_one(conn: Connection, sql):
    """
    Run a query and return its first row (查询一个), then close the connection.

    The row is fetched *before* the connection is closed, so it is never
    read from a cursor whose connection has already been released. Only
    one row is fetched; the rest of the result is not buffered.

    :param conn: connection to use; always closed before returning
    :param sql: raw SQL text, wrapped with ``sqlalchemy.text``
    :return: the first row, or ``None`` if the query produced no rows
    """
    try:
        return conn.execute(text(sql)).fetchone()
    finally:
        conn.close()
def select_all(conn: Connection, sql):
    """
    Run a query and return all of its rows (查询全部), then close the connection.

    All rows are fetched *before* the connection is closed, so they are
    never read from a cursor whose connection has already been released.

    :param conn: connection to use; always closed before returning
    :param sql: raw SQL text, wrapped with ``sqlalchemy.text``
    :return: list of rows (empty if the query produced none)
    """
    try:
        return conn.execute(text(sql)).fetchall()
    finally:
        conn.close()
def count(conn: Connection, table):
    """
    Count the rows of ``table`` (统计数据量).

    NOTE(review): ``table`` is interpolated directly into the SQL text, so
    it must be a trusted table name and never user-supplied input.

    :param conn: connection to use (closed via ``select_one``)
    :param table: name of the table to count
    :return: the value of ``count(1)`` for that table
    """
    first_row = select_one(conn, f'select count(1) from {table}')
    return first_row[0]
def execute_update(conn: Connection, sql):
    """
    Run an INSERT / UPDATE / DELETE statement in its own transaction (带事务执行).

    Similar in spirit to java-jdbc's execute_update: the returned
    affected-row count can be used to judge whether the statement took
    effect.

    :param conn: connection to use; closed by ``_execute_with_tx``
    :param sql: raw SQL text
    :return: number of affected rows
    """
    affected_rows = _execute_with_tx(conn, sql)
    return affected_rows
def batch_insert(conn: Connection, db_type, sql_tpl, data, batch_size=1500):
    """
    Batch-insert ``data`` in a single transaction (批量插入).

    ``data`` is split into chunks of ``batch_size`` rows. Each chunk is
    rendered by ``BatchInsertHandler`` into one multi-row statement, e.g.
    ``insert into oracle_table ( id, code ) values( 1 , '1' ),( 2 , '2' ),( 3 , '3' )``.

    NOTE(review): values are inlined into the SQL text and passed through
    ``sqlalchemy.text``. A string value containing something like ``:name``
    may therefore be parsed as a bind parameter. TODO: confirm and escape
    if such data is possible.

    :param conn: connection to use; always closed before returning
    :param db_type: database type; must be accepted by ``BatchInsertHandler``
    :param sql_tpl: statement template with a single ``%s`` placeholder,
        e.g. ``insert into t1 (f1,f2,f3) values %s``
        (for oracle: ``into t1 (f1,f2,f3) values %s``)
    :param data: row tuples, e.g. ``[(1,'tom',29),(2,'jack',30)]``
    :param batch_size: number of rows per generated statement
    :return: sum of ``rowcount`` over all executed statements
    :raises BatchInsertHandler.NotSupportError: if ``db_type`` is rejected
        (raised unwrapped before any SQL runs; the connection is still closed)
    :raises SqlExecuteError: if any statement fails; all chunks are rolled back
    """
    try:
        # Build the statements inside the outer try/finally so the connection
        # is closed even when the handler rejects db_type.
        handler = BatchInsertHandler(db_type, sql_tpl, data, batch_size)
        insert_sqls = handler.build_insert()
        row_count = 0
        # All chunks are inserted in one transaction (插入都在一个事务内).
        try:
            conn.begin()
            for sql_set in insert_sqls:
                row_count += conn.execute(text(sql_set)).rowcount
            conn.commit()
        except Exception as e:
            conn.rollback()
            raise SqlExecuteError(msg=f"批量插入异常,e={e}") from e
        return row_count
    finally:
        conn.close()
class BatchInsertHandler(object):
    """
    Batch-insert statement builder (批量插入处理器).

    Splits ``data`` into chunks and renders one multi-row INSERT statement
    per chunk, with the values inlined as SQL literals.

    oracle (``sql_tpl`` like ``into t1 (f1,f2) values %s``)::

        insert all
        into oracle_table ( id, code ) values( 1 , '1' )
        into oracle_table ( id, code ) values( 2 , '2' )
        select 1 from dual

    postgresql and mysql (``sql_tpl`` like ``insert into t1 (f1,f2) values %s``)::

        insert into oracle_table ( id, code ) values( 1 , '1' ),( 2 , '2' ),( 3 , '3' )

    NOTE(review): values are inlined rather than bound as parameters, and
    strings are only quote-escaped (see ``_field_value_convert``). Do not
    feed untrusted input through this class.
    """
    # db_type -> name of the method that renders one chunk of rows.
    # The keys are also the set of supported db types.
    BUILD_INSERT = {
        'oracle': 'build_oracle_insert',
        'postgresql': 'build_pg_insert',
        'mysql': 'build_mysql_insert',
    }

    class NotSupportError(Exception):
        """Raised when ``db_type`` has no registered builder in ``BUILD_INSERT``."""

        def __init__(self, msg=''):
            Exception.__init__(self, msg)

    def __init__(self, db_type, sql_tpl, data, batch_size):
        """
        :param db_type: database type, one of the keys of ``BUILD_INSERT``
        :param sql_tpl: pg and mysql: ``insert into t1 (f1,f2,f3) values %s``;
            oracle: ``into t1 (f1,f2,f3) values %s``
        :param data: row tuples, e.g. ``[(1,'tom',29),(2,'jack',30)]``
        :param batch_size: number of rows per generated statement
        :raises NotSupportError: if ``db_type`` is not a key of ``BUILD_INSERT``
        """
        # Validate against BUILD_INSERT itself, so every type that has a
        # builder (including 'mysql') is accepted and the two cannot drift apart.
        if db_type not in self.BUILD_INSERT:
            raise self.NotSupportError(f'unsupported db_type: {db_type}')
        self.db_type = db_type
        self.sql_tpl = sql_tpl
        self.data = data
        self.batch_size = batch_size

    def _split_data(self):
        """Split ``self.data`` into chunks via ``split_coll`` (presumably of ``batch_size`` rows each)."""
        return split_coll(self.data, self.batch_size)

    def build_insert(self):
        """Return one rendered INSERT statement per chunk of ``self.data``."""
        builder = getattr(self, self.BUILD_INSERT[self.db_type])
        return [builder(part) for part in self._split_data()]

    def _render_row(self, row):
        """Render one row tuple as a parenthesised literal list, e.g. ``(1,'tom',29)``."""
        # _field_value_convert appends ',' after every value; drop the final one.
        literals = ''.join(self._field_value_convert(ele, self.db_type) for ele in row)
        return '(' + literals[:-1] + ')'

    def build_oracle_insert(self, data_set):
        """Render one chunk as an Oracle ``insert all ... select 1 from dual`` statement."""
        into_clauses = ''.join(
            self.sql_tpl.replace('%s', self._render_row(ds)) + ' \r ' for ds in data_set
        )
        return 'insert all \r ' + into_clauses + 'select 1 from dual'

    def build_pg_insert(self, data_set):
        """Render one chunk as a single ``values (...),(...)`` multi-row statement."""
        values = ','.join(self._render_row(ds) for ds in data_set)
        return self.sql_tpl.replace('%s', values)

    def build_mysql_insert(self, data_set):
        """MySQL uses the same multi-row ``values`` syntax as PostgreSQL."""
        return self.build_pg_insert(data_set)

    @staticmethod
    def _field_value_convert(field_val, db_type=None):
        """
        Render one Python value as a SQL literal followed by a trailing comma.

        :param field_val: value to render
        :param db_type: when ``'mysql'``, backslashes in strings are also
            escaped, because MySQL treats ``\\`` as an escape character in
            string literals under its default SQL mode
        :return: e.g. ``"'tom',"`` for strings, ``"29,"`` for other values
        """
        # None is rendered as an empty string literal ''.
        # NOTE(review): this is not NULL; an empty string may be rejected by
        # non-text columns on PostgreSQL/MySQL. Confirm this is intended.
        if field_val is None:
            field_val = ''
        # TODO: handle non-printable characters (e.g. filter with
        # str.isprintable()); an earlier attempt removed '\r' while building
        # the SQL and broke the statement layout.
        if isinstance(field_val, str):
            if db_type == 'mysql':
                field_val = field_val.replace('\\', '\\\\')
            # Standard SQL escaping: a quote inside a literal is written as two
            # quotes. This keeps the data intact instead of deleting the quote.
            return "'" + field_val.replace("'", "''") + "',"
        return str(field_val) + ','