You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

208 lines
5.7 KiB

2 years ago
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2023/4/4 17:29
# @Author : old tom
# @File : fu_db_api.py
# @Project : futool-db
# @Desc :
from sqlalchemy import Connection, text, CursorResult, Result

from common.fudb.dbapis.fu_collection import split_coll
class SqlExecuteError(Exception):
    """Raised when running SQL through this module's helpers fails."""

    def __init__(self, msg=''):
        # The message becomes the sole exception argument, so str(err) == msg.
        super().__init__(msg)
def _select(conn: Connection, sql) -> Result:
    """
    Run a query, buffer all of its rows in memory, then close the connection.

    Because the connection is closed before this function returns, the rows
    must be read while it is still open; the returned result therefore holds
    the complete row set in memory (large queries cost memory accordingly).

    :param conn: connection to query on; always closed before returning
    :param sql: raw SQL text, passed through sqlalchemy ``text()``
    :return: a ``Result`` over the buffered rows, supporting ``fetchone()`` /
        ``fetchall()`` after the connection is gone
    """
    try:
        # freeze() fetches every row now, while the connection is still open;
        # calling the FrozenResult yields a fresh Result over that in-memory data.
        # Returning the live CursorResult instead would leave callers reading a
        # cursor whose connection had already been released by close().
        return conn.execute(text(sql)).freeze()()
    finally:
        conn.close()
def _execute_with_tx(conn: Connection, sql):
    """
    Execute one SQL statement inside an explicit transaction, then close the connection.

    :param conn: connection to run on; always closed before returning or raising
    :param sql: raw SQL text, passed through sqlalchemy ``text()``
    :return: the statement's ``rowcount``
    :raises SqlExecuteError: if beginning, executing or committing fails; the
        transaction is rolled back first and the original error is chained as
        the cause
    """
    try:
        conn.begin()
        result = conn.execute(text(sql))
        # Read the count before committing so it does not depend on the result
        # still being readable after the transaction ends.
        affected = result.rowcount
        conn.commit()
        return affected
    except Exception as e:
        conn.rollback()
        # Explicit chaining keeps the driver/SQLAlchemy error as the direct cause.
        raise SqlExecuteError(msg=f'sql [{sql}] 执行失败,开始回滚,e={e}') from e
    finally:
        conn.close()
def select_one(conn: Connection, sql):
    """
    Run a query and return only its first row.

    :param conn: connection to query on; ``_select`` closes it
    :param sql: raw SQL text of the query
    :return: the first result row, or ``None`` when the query yields no rows
    """
    result = _select(conn, sql)
    first_row = result.fetchone()
    return first_row
def select_all(conn: Connection, sql):
    """
    Run a query and return every row it produces.

    :param conn: connection to query on; ``_select`` closes it
    :param sql: raw SQL text of the query
    :return: sequence of all result rows (empty when nothing matches)
    """
    result = _select(conn, sql)
    all_rows = result.fetchall()
    return all_rows
def count(conn: Connection, table):
    """
    Return the number of rows in ``table``.

    ``table`` is interpolated directly into the SQL text (identifiers cannot be
    bound as parameters), so it must never come from untrusted input.

    :param conn: connection to query on; closed by the underlying query helper
    :param table: table name (optionally schema-qualified) to count
    :return: the value of ``count(1)`` for that table
    """
    first_row = select_one(conn, f'select count(1) from {table}')
    return first_row[0]
def execute_update(conn: Connection, sql):
    """
    Execute a data-modifying statement (INSERT / UPDATE / DELETE) in its own transaction.

    :param conn: connection to run on; closed before this returns or raises
    :param sql: raw SQL text of the statement
    :return: the number of affected rows, similar to JDBC ``executeUpdate``;
        a failure raises ``SqlExecuteError`` instead of returning a value
    """
    affected_rows = _execute_with_tx(conn, sql)
    return affected_rows
def batch_insert(conn: Connection, db_type, sql_tpl, data, batch_size=1500):
    """
    Insert many rows using a few literal-valued INSERT statements, all in one transaction.

    The rows are split into chunks and each chunk is rendered by
    ``BatchInsertHandler`` into one statement, e.g. for PostgreSQL::

        insert into t1 ( id, code ) values( 1 , '1' ),( 2 , '2' ),( 3 , '3' )

    NOTE(review): each statement goes through sqlalchemy ``text()``, which
    parses ``:name`` sequences as bind parameters; data values containing such
    text may make execution fail — confirm against real data.

    :param conn: database connection; always closed before this returns or raises
    :param db_type: database type, selects the statement form (see BatchInsertHandler)
    :param sql_tpl: template with a ``%s`` placeholder for the values, e.g.
        ``insert into t1 (f1,f2,f3) values %s`` (Oracle: ``into t1 (f1,f2,f3) values %s``)
    :param data: rows to insert, e.g. ``[(1,'tom',29),(2,'jack',30)]``
    :param batch_size: number of rows per generated statement (chunking is done by ``split_coll``)
    :return: sum of the ``rowcount`` values reported by every executed statement
    :raises BatchInsertHandler.NotSupportError: if ``db_type`` is not supported
    :raises SqlExecuteError: if any statement or the commit fails; the whole
        transaction is rolled back first
    """
    try:
        # Built inside the outer try so the connection is still closed when the
        # handler rejects db_type or rendering the statements fails.
        insert_sqls = BatchInsertHandler(db_type, sql_tpl, data, batch_size).build_insert()
        row_count = 0
        try:
            # All batches share one transaction: either every batch commits or none does.
            conn.begin()
            for insert_sql in insert_sqls:
                row_count += conn.execute(text(insert_sql)).rowcount
            conn.commit()
        except Exception as e:
            conn.rollback()
            raise SqlExecuteError(msg=f"批量插入异常,e={e}") from e
        return row_count
    finally:
        conn.close()
class BatchInsertHandler(object):
    """
    Render batches of rows into literal-valued INSERT statements for one database type.

    oracle (multi-table ``INSERT ALL``)::

        insert all
        into oracle_table ( id, code ) values( 1 , '1' )
        into oracle_table ( id, code ) values( 2 , '2' )
        into oracle_table ( id, code ) values( 3 , '3' )
        into oracle_table ( id, code ) values( 4 , '4' )
        select 1 from dual ;

    postgresql and mysql (single multi-row ``VALUES`` list)::

        insert into some_table ( id, code ) values( 1 , '1' ),( 2 , '2' ),( 3 , '3' )

    Values are inlined as SQL literals rather than bound parameters, so the
    generated SQL is only as safe as the escaping done in ``_sql_literal``.
    """

    # db_type -> name of the method that renders one batch of rows.
    BUILD_INSERT = {
        'oracle': 'build_oracle_insert',
        'postgresql': 'build_pg_insert',
        'mysql': 'build_mysql_insert'
    }
    # db_types accepted by __init__. 'mysql' has a builder registered above but
    # is still rejected. NOTE(review): before enabling it, note that MySQL treats
    # backslash as an escape character in string literals by default, which
    # _sql_literal does not escape.
    SUPPORTED_DB_TYPES = ('oracle', 'postgresql')

    class NotSupportError(Exception):
        """Raised when the handler is created for an unsupported db_type."""

        def __init__(self, msg=''):
            Exception.__init__(self, msg)

    def __init__(self, db_type, sql_tpl, data, batch_size):
        """
        :param db_type: database type; must be one of SUPPORTED_DB_TYPES
        :param sql_tpl: template whose every ``%s`` is replaced by rendered values.
            postgresql / mysql: ``insert into t1 (f1,f2,f3) values %s``
            oracle: ``into t1 (f1,f2,f3) values %s`` (``insert all`` is added by the builder)
        :param data: rows to insert, e.g. ``[(1,'tom',29),(2,'jack',30)]``
        :param batch_size: chunk size passed to ``split_coll``; one statement is built per chunk
        :raises NotSupportError: if ``db_type`` is not supported
        """
        if db_type not in self.SUPPORTED_DB_TYPES:
            raise self.NotSupportError(f'不支持的数据库类型,db_type={db_type}')
        self.db_type = db_type
        self.sql_tpl = sql_tpl
        self.data = data
        self.batch_size = batch_size

    def _split_data(self):
        """Split ``self.data`` into chunks of ``self.batch_size`` via ``split_coll``."""
        return split_coll(self.data, self.batch_size)

    def build_insert(self):
        """Return one INSERT statement (str) per chunk of the data."""
        builder = getattr(self, self.BUILD_INSERT[self.db_type])
        return [builder(part) for part in self._split_data()]

    def build_oracle_insert(self, data_set):
        """Render ``data_set`` as one Oracle ``INSERT ALL ... select 1 from dual`` statement."""
        # ' \r ' separates the INTO clauses, matching the historical output byte-for-byte.
        into_clauses = ''.join(
            self.sql_tpl.replace('%s', self._row_literal(row)) + ' \r '
            for row in data_set
        )
        return 'insert all \r ' + into_clauses + 'select 1 from dual'

    def build_pg_insert(self, data_set):
        """Render ``data_set`` as a single multi-row ``VALUES (...),(...)`` statement."""
        values = ','.join(self._row_literal(row) for row in data_set)
        return self.sql_tpl.replace('%s', values)

    def build_mysql_insert(self, data_set):
        """MySQL accepts the same multi-row ``VALUES`` form as PostgreSQL."""
        return self.build_pg_insert(data_set)

    def _row_literal(self, row):
        """Render one row as a parenthesised, comma-separated list of SQL literals."""
        return '(' + ','.join(self._sql_literal(value) for value in row) + ')'

    @staticmethod
    def _sql_literal(field_val):
        """
        Render one Python value as an SQL literal.

        - ``None`` becomes the empty string literal ``''`` (Oracle treats a
          zero-length string as NULL; NOTE(review): PostgreSQL stores an empty
          string instead — confirm this is intended).
        - ``str`` values are single-quoted with embedded quotes doubled, so
          ``O'Brien`` becomes ``'O''Brien'`` and is stored unchanged.
        - Anything else is rendered with ``str()`` and left unquoted, which suits
          numbers. NOTE(review): date/datetime values would be emitted unquoted
          and likely produce invalid SQL.
        """
        if field_val is None:
            field_val = ''
        if isinstance(field_val, str):
            # TODO invisible-character handling: when the generated SQL is read back,
            # '\r' gets removed and the SQL layout is garbled (candidate fix: keep
            # only characters for which str.isprintable() is true).
            # Standard SQL escaping: a quote inside a literal is written twice.
            # Doubling (rather than removing) the quote keeps the stored value intact.
            return "'" + field_val.replace("'", "''") + "'"
        return str(field_val)

    @staticmethod
    def _field_value_convert(field_val):
        """
        Render one value as an SQL literal followed by a trailing comma.

        Kept for backward compatibility; see ``_sql_literal`` for the conversion rules.
        """
        return BatchInsertHandler._sql_literal(field_val) + ','