fix: some issue

This commit is contained in:
meilin.huang
2025-03-11 12:42:20 +08:00
parent c7c3fd7f7e
commit bc21ba7c1e
23 changed files with 280 additions and 242 deletions

View File

@@ -333,7 +333,6 @@ func (d *dbSqlExecAppImpl) saveSqlExecLog(dbSqlExecRecord *entity.DbSqlExec, res
func (d *dbSqlExecAppImpl) doSelect(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
maxCount := config.GetDbms().MaxResultSet
selectStmt := sqlExecParam.Stmt
selectSql := sqlExecParam.Sql
sqlExecParam.SqlExecRecord.Type = entity.DbSqlExecTypeQuery
@@ -343,49 +342,7 @@ func (d *dbSqlExecAppImpl) doSelect(ctx context.Context, sqlExecParam *sqlExecPa
}
}
if selectStmt != nil {
needCheckLimit := false
var limit *sqlstmt.Limit
switch stmt := selectStmt.(type) {
case *sqlstmt.SimpleSelectStmt:
qs := stmt.QuerySpecification
limit = qs.Limit
if qs.SelectElements != nil && (qs.SelectElements.Star != "" || len(qs.SelectElements.Elements) > 1) {
needCheckLimit = true
}
case *sqlstmt.UnionSelectStmt:
limit = stmt.Limit
selectSql = selectStmt.GetText()
needCheckLimit = true
}
// 如果配置为0则不校验分页参数
if needCheckLimit && maxCount != 0 {
if limit == nil {
return nil, errorx.NewBizI(ctx, imsg.ErrNoLimitStmt)
}
if limit.RowCount > maxCount {
return nil, errorx.NewBizI(ctx, imsg.ErrLimitInvalid, "count", maxCount)
}
}
} else {
if maxCount != 0 {
if !strings.Contains(selectSql, "limit") &&
// 兼容oracle rownum分页
!strings.Contains(selectSql, "rownum") &&
// 兼容mssql offset分页
!strings.Contains(selectSql, "offset") &&
// 兼容mssql top 分页 with result as ({query sql}) select top 100 * from result
!strings.Contains(selectSql, " top ") {
// 判断是不是count语句
if !strings.Contains(selectSql, "count(") {
return nil, errorx.NewBizI(ctx, imsg.ErrNoLimitStmt)
}
}
}
}
return d.doQuery(ctx, sqlExecParam.DbConn, selectSql)
return d.doQuery(ctx, sqlExecParam.DbConn, selectSql, maxCount)
}
func (d *dbSqlExecAppImpl) doOtherRead(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
@@ -398,7 +355,7 @@ func (d *dbSqlExecAppImpl) doOtherRead(ctx context.Context, sqlExecParam *sqlExe
}
}
return d.doQuery(ctx, sqlExecParam.DbConn, selectSql)
return d.doQuery(ctx, sqlExecParam.DbConn, selectSql, 0)
}
func (d *dbSqlExecAppImpl) doExecDDL(ctx context.Context, sqlExecParam *sqlExecParam) (*dto.DbSqlExecRes, error) {
@@ -588,11 +545,23 @@ func (d *dbSqlExecAppImpl) doInsert(ctx context.Context, sqlExecParam *sqlExecPa
return d.doExec(ctx, sqlExecParam.DbConn, sqlExecParam.Sql)
}
func (d *dbSqlExecAppImpl) doQuery(ctx context.Context, dbConn *dbi.DbConn, sql string) (*dto.DbSqlExecRes, error) {
cols, res, err := dbConn.QueryContext(ctx, sql)
func (d *dbSqlExecAppImpl) doQuery(ctx context.Context, dbConn *dbi.DbConn, sql string, maxRows int) (*dto.DbSqlExecRes, error) {
res := make([]map[string]any, 0, 16)
nowRows := 0
cols, err := dbConn.WalkQueryRows(ctx, sql, func(row map[string]any, columns []*dbi.QueryColumn) error {
nowRows++
// 超过指定的最大查询记录数,则停止查询
if maxRows != 0 && nowRows > maxRows {
return dbi.NewStopWalkQueryError(fmt.Sprintf("exceed the maximum number of query records %d", maxRows))
}
res = append(res, row)
return nil
})
if err != nil {
return nil, err
}
return &dto.DbSqlExecRes{
Sql: sql,
Columns: cols,