|
|
#!/usr/bin/env python
|
|
|
# -*- coding: utf-8 -*-
|
|
|
# @Time : 2023/4/4 17:29
|
|
|
# @Author : old tom
|
|
|
# @File : fu_db_api.py
|
|
|
# @Project : futool-db
|
|
|
# @Desc :
|
|
|
|
|
|
from sqlalchemy import Connection, text, CursorResult, Result

from common.futool.core.fu_collection import split_coll
|
|
|
|
|
|
|
|
|
class SqlExecuteError(Exception):
    """Raised by the SQL helpers in this module when executing SQL fails."""

    def __init__(self, msg=''):
        # the message becomes the single positional arg, exactly as before
        super().__init__(msg)
|
|
|
|
|
|
|
|
|
def _select(conn: Connection, sql) -> Result:
    """
    Execute a query and close the connection afterwards.

    All rows are read into memory *before* the connection is closed, so the
    returned result can still be consumed by the caller. Previously the live
    cursor result was returned after ``conn.close()``, which left callers
    fetching from an already released connection.

    :param conn: SQLAlchemy connection; it is always closed before returning
    :param sql: raw SQL query text
    :return: a detached, in-memory ``Result`` supporting ``fetchone`` / ``fetchall``
    """
    try:
        # freeze() buffers every row; calling the frozen result produces a
        # new Result that no longer needs the connection.
        return conn.execute(text(sql)).freeze()()
    finally:
        conn.close()
|
|
|
|
|
|
|
|
|
def _execute_with_tx(conn: Connection, sql):
    """
    Execute a single SQL statement inside its own transaction.

    Commits on success. On any failure it rolls back and raises
    ``SqlExecuteError``. The connection is closed in every case.

    :param conn: SQLAlchemy connection; always closed before returning
    :param sql: raw SQL statement text
    :return: the ``rowcount`` reported for the statement
    :raises SqlExecuteError: if beginning, executing or committing fails
    """
    try:
        conn.begin()
        rt = conn.execute(text(sql))
        conn.commit()
        return rt.rowcount
    except Exception as e:
        try:
            conn.rollback()
        finally:
            # Raising from `finally` guarantees callers still get
            # SqlExecuteError even if the rollback itself fails; a rollback
            # error is kept as __context__, the original error as __cause__.
            raise SqlExecuteError(msg=f'sql [{sql}] 执行失败,开始回滚,e={e}') from e
    finally:
        conn.close()
|
|
|
|
|
|
|
|
|
def execute_many_sql_with_tx(conn: Connection, sql_list):
    """
    Execute several SQL statements in one transaction (all or nothing).

    Commits once after every statement succeeds. If any statement fails, the
    whole transaction is rolled back and ``SqlExecuteError`` is raised. The
    connection is closed in every case.

    :param conn: SQLAlchemy connection; always closed before returning
    :param sql_list: iterable of raw SQL statement texts, executed in order
    :return: None
    :raises SqlExecuteError: if beginning, any execution, or committing fails
    """
    try:
        conn.begin()
        for sql in sql_list:
            conn.execute(text(sql))
        conn.commit()
    except Exception as e:
        try:
            conn.rollback()
        finally:
            # a failing rollback must not replace the error callers expect
            raise SqlExecuteError(msg=f'sql [{sql_list}] 执行失败,开始回滚,e={e}') from e
    finally:
        conn.close()
|
|
|
|
|
|
|
|
|
def select_one(conn: Connection, sql):
    """
    Run a query and return its first row; the connection is closed afterwards.

    :param conn: SQLAlchemy connection; always closed before returning
    :param sql: raw SQL query text
    :return: the first ``Row``, or ``None`` if the query returned no rows
    """
    try:
        # first() reads the row and closes the cursor while the connection is
        # still open, instead of fetching from an already-closed connection.
        return conn.execute(text(sql)).first()
    finally:
        conn.close()
|
|
|
|
|
|
|
|
|
def select_all(conn: Connection, sql):
    """
    Run a query and return all rows; the connection is closed afterwards.

    :param conn: SQLAlchemy connection; always closed before returning
    :param sql: raw SQL query text
    :return: list of ``Row`` objects (empty list when there are no rows)
    """
    try:
        # fetch eagerly while the connection is still open
        return conn.execute(text(sql)).fetchall()
    finally:
        conn.close()
|
|
|
|
|
|
|
|
|
def count(conn: Connection, table):
    """
    Return the number of rows in ``table``.

    :param conn: database connection, handed to ``select_one``
    :param table: table name inserted into ``select count(1) from <table>``
    :return: the first column of the single result row
    """
    # NOTE(review): `table` is interpolated verbatim into the SQL text; only
    # pass trusted identifiers here.
    first_row = select_one(conn, f'select count(1) from {table}')
    return first_row[0]
|
|
|
|
|
|
|
|
|
def execute_update(conn: Connection, sql):
    """
    Execute an INSERT / UPDATE / DELETE statement in its own transaction.

    :param conn: database connection, handed to ``_execute_with_tx``
    :param sql: raw SQL statement text
    :return: number of affected rows (``rowcount``). Like JDBC's executeUpdate,
        it can be used to tell whether the statement had any effect.
    """
    affected_rows = _execute_with_tx(conn, sql)
    return affected_rows
|
|
|
|
|
|
|
|
|
def batch_insert(conn: Connection, db_type, sql_tpl, data, batch_size=1500):
    """
    Bulk insert ``data`` using multi-row INSERT statements.

    ``BatchInsertHandler`` splits ``data`` into chunks of ``batch_size`` rows
    and renders one statement per chunk, for example
    ``insert into t1 (id, code) values (1,'1'),(2,'2'),(3,'3')``, or an
    ``insert all ... select 1 from dual`` statement for Oracle. All generated
    statements are executed in a single transaction.

    :param conn: database connection; always closed before returning
    :param db_type: database type understood by ``BatchInsertHandler``
    :param sql_tpl: statement template with one ``%s`` placeholder, e.g.
        ``insert into t1 (f1,f2,f3) values %s``
    :param data: rows to insert, e.g. ``[(1,'tom',29),(2,'jack',30)]``
    :param batch_size: number of rows per generated statement
    :return: sum of the ``rowcount`` of every generated statement
    :raises BatchInsertHandler.NotSupportError: if ``db_type`` is not supported
    :raises SqlExecuteError: if executing or committing fails (rolled back)
    """
    try:
        # Building the SQL can raise (e.g. unsupported db_type). It sits inside
        # the outer try so the connection is closed in that case too;
        # previously it leaked.
        handler = BatchInsertHandler(db_type, sql_tpl, data, batch_size)
        insert_sqls = handler.build_insert()
        row_count = 0
        # all inserts run inside one transaction
        try:
            conn.begin()
            for sql_set in insert_sqls:
                row_count += conn.execute(text(sql_set)).rowcount
            conn.commit()
        except Exception as e:
            try:
                conn.rollback()
            finally:
                # a failing rollback must not replace the error callers expect
                raise SqlExecuteError(msg=f"批量插入异常,e={e}") from e
        return row_count
    finally:
        conn.close()
|
|
|
|
|
|
|
|
|
class BatchInsertHandler(object):
    """
    Batch insert SQL builder.

    Splits ``data`` into chunks of ``batch_size`` rows and renders each chunk
    as one multi-row INSERT statement.

    oracle (template ``into t1 (id, code) values %s``)::

        insert all
        into oracle_table ( id, code ) values( 1 , '1' )
        into oracle_table ( id, code ) values( 2 , '2' )
        select 1 from dual

    postgresql and mysql (template ``insert into t1 (id, code) values %s``)::

        insert into oracle_table ( id, code ) values( 1 , '1' ),( 2 , '2' ),( 3 , '3' )

    Values are embedded as SQL literals (see ``_field_value_convert``), not
    passed as bound parameters.
    """

    # db_type -> name of the method that renders one chunk of rows
    BUILD_INSERT = {
        'oracle': 'build_oracle_insert',
        'postgresql': 'build_pg_insert',
        'mysql': 'build_mysql_insert'
    }

    class NotSupportError(Exception):
        """Raised when ``db_type`` has no builder in ``BUILD_INSERT``."""

        def __init__(self, msg=''):
            Exception.__init__(self, msg)

    def __init__(self, db_type, sql_tpl, data, batch_size):
        """
        :param db_type: database type; one of the keys of ``BUILD_INSERT``
        :param sql_tpl: pg and mysql: ``insert into t1 (f1,f2,f3) values %s``;
            oracle: ``into t1 (f1,f2,f3) values %s``
        :param data: rows to insert, e.g. ``[(1,'tom',29),(2,'jack',30)]``
        :param batch_size: number of rows per generated statement
        :raises NotSupportError: if ``db_type`` is not supported
        """
        # Validate against the dispatch table itself, so every type that has
        # a builder is accepted (mysql used to be rejected despite having one).
        if db_type not in self.BUILD_INSERT:
            raise self.NotSupportError(f'不支持的数据库类型: {db_type}')
        self.db_type = db_type
        self.sql_tpl = sql_tpl
        self.data = data
        self.batch_size = batch_size

    def _split_data(self):
        """Split ``self.data`` into chunks of ``self.batch_size`` rows."""
        return split_coll(self.data, self.batch_size)

    def build_insert(self):
        """Return one SQL statement per non-empty chunk of ``self.data``."""
        builder = getattr(self, self.BUILD_INSERT[self.db_type])
        # an empty chunk would render an invalid statement, so skip it
        return [builder(part) for part in self._split_data() if part]

    def _build_row(self, row):
        """Render one row as a parenthesised value list, e.g. ``(1,'tom',29)``."""
        # every converted value ends with ',', so drop the final one
        return '(' + ''.join(self._field_value_convert(ele) for ele in row)[:-1] + ')'

    def build_oracle_insert(self, data_set):
        """Render one chunk as an Oracle ``insert all ... select 1 from dual`` statement."""
        parts = ['insert all \r ']
        for row in data_set:
            parts.append(self.sql_tpl.replace('%s', self._build_row(row)) + ' \r ')
        parts.append('select 1 from dual')
        return ''.join(parts)

    def build_pg_insert(self, data_set):
        """Render one chunk as ``<template with %s -> (..),(..),...>``."""
        values = ','.join(self._build_row(row) for row in data_set)
        return self.sql_tpl.replace('%s', values)

    def build_mysql_insert(self, data_set):
        """MySQL accepts the same multi-row VALUES syntax as PostgreSQL."""
        return self.build_pg_insert(data_set)

    @staticmethod
    def _field_value_convert(field_val):
        """
        Render one Python value as a SQL literal followed by a trailing comma.

        - ``None`` -> ``NULL``
        - ``str`` -> single-quoted, with embedded quotes doubled
          (``O'Brien`` -> ``'O''Brien'``)
        - anything else -> ``str(value)``, unquoted

        :param field_val: field value
        :return: the literal text, ending with ','
        """
        if field_val is None:
            # NULL rather than '' so non-text columns (e.g. integers on
            # PostgreSQL) accept missing values
            return 'NULL,'
        if isinstance(field_val, str):
            # Standard SQL escaping: double the quote instead of deleting it,
            # which silently corrupted values such as "O'Brien".
            # NOTE(review): MySQL without NO_BACKSLASH_ESCAPES also treats '\'
            # as an escape character; backslashes are not escaped here.
            return "'" + field_val.replace("'", "''") + "',"
        # TODO: handle non-printable characters. The original author noted
        # that stripping them removed '\r' and garbled the SQL, so it is
        # left disabled.
        return str(field_val) + ','
|