Skip to content

Commit

Permalink
修复前后置sql解析和执行方法 (#33)
Browse files Browse the repository at this point in the history
* 修复前后置sql解析和执行方法

* 修复sql类型错误提示信息
  • Loading branch information
wu-clan authored Sep 6, 2023
1 parent 5e7b29d commit c39c12b
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 52 deletions.
1 change: 1 addition & 0 deletions httpfpt/core/conf.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ port = 3306
user = 'root'
password = '123456'
database = 'test'
charset = 'utf-8'

# redis 数据库
[redis]
Expand Down
1 change: 1 addition & 0 deletions httpfpt/core/get_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MysqlDB_USER = __config['mysql']['user']
MysqlDB_PASSWORD = __config['mysql']['password']
MysqlDB_DATABASE = __config['mysql']['database']
MysqlDB_CHARSET = __config['mysql']['charset']

# redis 数据库
REDIS_HOST = __config['redis']['host']
Expand Down
104 changes: 60 additions & 44 deletions httpfpt/db/mysql_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# _*_ coding:utf-8 _*_
import datetime
import decimal
from typing import Any, Optional
from typing import Optional

import pymysql
from dbutils.pooled_db import PooledDB
Expand All @@ -14,78 +14,97 @@
from httpfpt.common.yaml_handler import write_yaml_vars
from httpfpt.core import get_conf
from httpfpt.core.path_conf import RUN_ENV_PATH
from httpfpt.enums.query_fetch import QueryFetchType
from httpfpt.enums.sql_type import SqlType
from httpfpt.enums.var_type import VarType
from httpfpt.utils.enum_control import get_enum_values


class MysqlDB:
def __init__(self) -> None:
try:
self.conn = PooledDB(
pymysql,
maxconnections=15,
blocking=True, # 防止连接过多报错
host=get_conf.MysqlDB_HOST,
port=get_conf.MysqlDB_PORT,
user=get_conf.MysqlDB_USER,
password=get_conf.MysqlDB_PASSWORD,
database=get_conf.MysqlDB_DATABASE,
).connection()
except BaseException as e:
log.error(f'数据库 mysql 连接失败: {e}')
# 声明游标
self.cursor = self.conn.cursor()
self._pool = PooledDB(
pymysql,
host=get_conf.MysqlDB_HOST,
port=get_conf.MysqlDB_PORT,
user=get_conf.MysqlDB_USER,
password=get_conf.MysqlDB_PASSWORD,
database=get_conf.MysqlDB_DATABASE,
charset=get_conf.MysqlDB_CHARSET,
maxconnections=15,
blocking=True, # 连接池中如果没有可用连接后,是否阻塞等待
autocommit=False, # 是否自动提交
)
self.conn = self._pool.connection()
self.cursor = self.conn.cursor(cursor=pymysql.cursors.DictCursor) # type: ignore

def query(self, sql: str, fetch: str = 'all') -> Any:
def close(self) -> None:
"""
关闭游标和数据库连接
:return:
"""
self.cursor.close()
self.conn.close()

def query(self, sql: str, fetch: str = 'all') -> dict:
"""
数据库查询
:param sql:
:param fetch: 查询条件; all / 任意内容(单条记录)
:param fetch: 查询条件; one: 查询一条数据; all: 查询所有数据
:return:
"""
data = {}
try:
self.cursor.execute(sql)
if fetch == 'all':
if fetch == QueryFetchType.ONE:
query_data = self.cursor.fetchone()
elif fetch == QueryFetchType.ALL:
query_data = self.cursor.fetchall()
else:
query_data = self.cursor.fetchone()
raise ValueError(f'查询条件 {fetch} 错误, 请使用 one / all')
except Exception as e:
log.error(f'执行 {sql} 失败: {e}')
raise e
else:
log.info(f'执行 {sql} 成功')
return query_data
try:
for k, v in query_data.items():
if isinstance(v, decimal.Decimal):
if v % 1 == 0:
data[k] = int(v)
data[k] = float(v)
elif isinstance(v, datetime.datetime):
data[k] = str(v)
else:
data[k] = v
except Exception as e:
log.error(f'序列化 {sql} 查询结果失败: {e}')
raise e
return data
finally:
self.close()

def execute(self, sql: str) -> None:
def execute(self, sql: str) -> int:
"""
执行 sql 操作
:return:
"""
try:
self.cursor.execute(sql)
rowcount = self.cursor.execute(sql)
self.conn.commit()
except Exception as e:
self.conn.rollback()
log.error(f'执行 {sql} 失败: {e}')
raise e
else:
log.info(f'执行 {sql} 成功')
return rowcount
finally:
self.close()

def close(self) -> None:
"""
关闭游标和数据库连接
:return:
"""
self.cursor.close()
self.conn.close()

def exec_case_sql(self, sql: list, env: Optional[str] = None) -> dict:
def exec_case_sql(self, sql: str | list, env: Optional[str] = None) -> None:
"""
执行用例 sql
Expand All @@ -95,21 +114,19 @@ def exec_case_sql(self, sql: list, env: Optional[str] = None) -> dict:
"""
sql_type = get_enum_values(SqlType)
if any(_.upper() in sql for _ in sql_type):
raise ValueError(f'{sql} 中存在不允许的命令类型, 仅支持 DQL 类型 sql 语句')
raise ValueError(f'{sql} 中存在不允许的命令类型, 仅支持 {sql_type} 类型 sql 语句')
else:
data = {}
if isinstance(sql, str):
log.info(f'执行 sql: {sql}')
self.query(sql)
for s in sql:
# 获取返回数据
if isinstance(s, str):
log.info(f'执行 sql: {s}')
query_data = self.query(s)
for k, v in query_data.items():
if isinstance(v, decimal.Decimal):
data[k] = float(v)
elif isinstance(v, datetime.datetime):
data[k] = str(v)
else:
data[k] = v
if SqlType.select in s:
self.query(s)
else:
self.execute(s)
# 设置变量
if isinstance(s, dict):
log.info(f'执行变量提取 sql: {s["sql"]}')
Expand All @@ -135,4 +152,3 @@ def exec_case_sql(self, sql: list, env: Optional[str] = None) -> dict:
raise ValueError(
f'前置 sql 设置变量失败, 用例参数 "type: {set_type}" 值错误, 请使用 cache / env / global' # noqa: E501
)
return data
8 changes: 8 additions & 0 deletions httpfpt/enums/query_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from httpfpt.enums import StrEnum


class QueryFetchType(StrEnum):
ONE = 'one'
ALL = 'all'
1 change: 1 addition & 0 deletions httpfpt/enums/sql_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
@unique
class SqlType(StrEnum):
select = 'SELECT'
insert = 'INSERT'
update = 'UPDATE'
delete = 'DELETE'
4 changes: 2 additions & 2 deletions httpfpt/utils/enum_control.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from enum import Enum
from typing import Any
from typing import Type


def get_enum_values(enum_class: Any) -> list:
def get_enum_values(enum_class: Type[Enum]) -> list:
if issubclass(enum_class, Enum):
return list(map(lambda ec: ec.value, enum_class))
else:
Expand Down
16 changes: 10 additions & 6 deletions httpfpt/utils/request/request_data_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,11 @@ def setup_sql(self) -> Union[list, None]:
for i in sql:
if isinstance(i, dict):
for k, v in i.items():
if not isinstance(v, str):
raise ValueError(f'请求参数解析失败,参数 test_steps:setup:sql:{k} 不是有效的 str 类型') # noqa: E501
if k != 'value':
if not isinstance(v, str):
raise ValueError(
f'请求参数解析失败,参数 test_steps:setup:sql:{k} 不是有效的 str 类型' # noqa: E501
)
else:
if not isinstance(i, str):
raise ValueError(f'请求数据解析失败, 参数 test_steps:setup:sql:{i} 不是有效的 str 类型')
Expand Down Expand Up @@ -556,10 +559,11 @@ def teardown_sql(self) -> Union[list, None]:
for i in sql:
if isinstance(i, dict):
for k, v in i.items():
if not isinstance(v, str):
raise ValueError(
f'请求参数解析失败,参数 test_steps:teardown:sql:{k} 不是有效的 str 类型' # noqa: E501
)
if k != 'value':
if not isinstance(v, str):
raise ValueError(
f'请求参数解析失败,参数 test_steps:teardown:sql:{k} 不是有效的 str 类型' # noqa: E501
)
else:
if not isinstance(i, str):
raise ValueError(f'请求数据解析失败, 参数 test_steps:teardown:sql:{i} 不是有效的 str 类型') # noqa: E501
Expand Down

0 comments on commit c39c12b

Please sign in to comment.