diff --git a/frontend/package.json b/frontend/package.json index 322f79f9..61b412a1 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -13,7 +13,7 @@ "@element-plus/icons-vue": "^2.3.2", "@logicflow/core": "^2.1.7", "@logicflow/extension": "^2.1.9", - "@vueuse/core": "^14.1.0", + "@vueuse/core": "^14.2.0", "@xterm/addon-fit": "^0.11.0", "@xterm/addon-search": "^0.16.0", "@xterm/addon-web-links": "^0.12.0", @@ -31,15 +31,15 @@ "monaco-sql-languages": "^0.15.1", "nprogress": "^0.2.0", "pinia": "^3.0.4", - "qrcode.vue": "^3.6.0", + "qrcode.vue": "^3.8.0", "screenfull": "^6.0.2", "sortablejs": "^1.15.6", - "sql-formatter": "^15.6.12", + "sql-formatter": "^15.7.0", "trzsz": "^1.1.5", "uuid": "^13.0.0", "vue": "^v3.6.0-beta.2", "vue-i18n": "^11.2.8", - "vue-router": "^4.6.4", + "vue-router": "^5.0.2", "vuedraggable": "^4.1.0", "xlsx": "^0.18.5" }, diff --git a/frontend/src/views/flow/components/flowdesign/node/aitask/PropSetting.vue b/frontend/src/views/flow/components/flowdesign/node/aitask/PropSetting.vue index 90ccad15..7141107e 100644 --- a/frontend/src/views/flow/components/flowdesign/node/aitask/PropSetting.vue +++ b/frontend/src/views/flow/components/flowdesign/node/aitask/PropSetting.vue @@ -27,16 +27,14 @@ - + - + @@ -45,6 +43,7 @@ import { notEmpty } from '@/common/assert'; import { formatDate } from '@/common/utils/format'; import EnumTag from '@/components/enumtag/EnumTag.vue'; +import MonacoEditor from '@/components/monaco/MonacoEditor.vue'; import { useI18nPleaseInput } from '@/hooks/useI18n'; import { ProcinstTaskStatus } from '@/views/flow/enums'; import { computed } from 'vue'; diff --git a/frontend/src/views/ops/db/component/sqleditor/DbSqlEditor.vue b/frontend/src/views/ops/db/component/sqleditor/DbSqlEditor.vue index 0eac5192..0e0bc492 100644 --- a/frontend/src/views/ops/db/component/sqleditor/DbSqlEditor.vue +++ b/frontend/src/views/ops/db/component/sqleditor/DbSqlEditor.vue @@ -400,7 +400,7 @@ const runNonQuerySqls = async (sqls: string[], newTab: boolean) => { const result: any = (data.value as any)[0]; results.push({ sql: result.sql, - rowsAffected: result.res?.[0]?.rowsAffected, + rowsAffected: result.res?.[0].rowsAffected, error: result.errorMsg || '-', }); } catch (error: any) { @@ -413,9 +413,9 @@ const runNonQuerySqls = async (sqls: string[], newTab: boolean) => { // 设置表格列 state.execResTabs[i].tableColumn = [ - { columnName: 'sql', columnType: 'string', show: true }, - { columnName: 'rowsAffected', columnType: 'number', show: true }, - { columnName: 'error', columnType: 'string', show: true }, + { columnName: 'SQL', key: 'sql', columnType: 'string', show: true }, + { columnName: 'RowsAffected', key: 'rowsAffected', columnType: 'number', show: true }, + { columnName: 'Error', key: 'error', columnType: 'string', show: true }, ]; state.execResTabs[i].data = results; diff --git a/server/go.mod b/server/go.mod index db3b1b6b..53807e65 100644 --- a/server/go.mod +++ b/server/go.mod @@ -6,8 +6,8 @@ require ( gitee.com/chunanyong/dm v1.8.22 gitee.com/liuzongyang/libpq v1.10.11 github.com/antlr4-go/antlr/v4 v4.13.1 - github.com/cloudwego/eino v0.7.13 - github.com/cloudwego/eino-ext/components/model/openai v0.1.6 + github.com/cloudwego/eino v0.7.32 + github.com/cloudwego/eino-ext/components/model/openai v0.1.8 github.com/docker/docker v28.5.0+incompatible github.com/docker/go-connections v0.6.0 github.com/gin-gonic/gin v1.11.0 @@ -59,7 +59,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/clipperhouse/uax29/v2 v2.2.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect - github.com/cloudwego/eino-ext/libs/acl/openai v0.1.10 // indirect + github.com/cloudwego/eino-ext/libs/acl/openai v0.1.13 // indirect github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect github.com/containerd/log v0.1.0 // indirect diff --git a/server/internal/ai/agent/agent.go b/server/internal/ai/agent/agent.go index 25ac58ff..a876291b 100644 --- a/server/internal/ai/agent/agent.go +++ b/server/internal/ai/agent/agent.go @@ -2,48 +2,42 @@ package agent import ( "context" - "errors" - "io" "mayfly-go/internal/ai/config" aimodel "mayfly-go/internal/ai/model" - "mayfly-go/pkg/gox" "mayfly-go/pkg/logx" + "github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/components/tool" - "github.com/cloudwego/eino/compose" - "github.com/cloudwego/eino/flow/agent" - "github.com/cloudwego/eino/flow/agent/react" "github.com/cloudwego/eino/schema" ) -// GetAiAgent 获取AI Agent -func GetAiAgent(ctx context.Context, aiConfig *config.AIModelConfig, tools ...tool.BaseTool) (*react.Agent, error) { +// GetAgent 获取AI Agent +func GetAgent(ctx context.Context, aiConfig *config.AIModelConfig, tools ...tool.BaseTool) (adk.Agent, error) { toolableChatModel, err := aimodel.GetChatModel(ctx, aiConfig) if err != nil { return nil, err } // 初始化所需的 tools - toolsConf := compose.ToolsNodeConfig{ - Tools: tools, - } - // 创建 agent - return react.NewAgent(ctx, &react.AgentConfig{ - ToolCallingModel: toolableChatModel, - ToolsConfig: toolsConf, - MaxStep: len(toolsConf.Tools)*1 + 3, - MessageModifier: func(ctx context.Context, input []*schema.Message) []*schema.Message { - return input - }, + toolsConfig := adk.ToolsConfig{} + toolsConfig.Tools = tools + + chatAgent, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ + Name: "ops_expert", + Description: "一位拥有20多年系统管理、数据库管理和基础设施优化经验的专业DevOps专家。", + Instruction: `你现在是一位专业的数据库管理员、Redis管理员和安全审核专家,请根据用户的问题给出最合适的答案。`, + Model: toolableChatModel, + ToolsConfig: toolsConfig, }) + if err != nil { + return nil, err + } + + return chatAgent, nil } -type AiAgent struct { - *react.Agent -} - -// NewAiAgent 创建AI Agent,并注册指定类型的工具 -func NewAiAgent(ctx context.Context, toolTypes ...ToolType) (*AiAgent, error) { +// GetOpsExpertAgent 获取运维专家agent +func GetOpsExpertAgent(ctx context.Context, toolTypes ...ToolType) (*AiAgent, error) { tools := make([]tool.BaseTool, 0) for _, toolType := range toolTypes { if t, exists := GetTools(toolType); exists { @@ -51,7 +45,7 @@ func NewAiAgent(ctx context.Context, toolTypes ...ToolType) (*AiAgent, error) { } } - agent, err := GetAiAgent(ctx, config.GetAiModel(), tools...) + agent, err := GetAgent(ctx, config.GetAiModel(), tools...) if err != nil { return nil, err } @@ -60,88 +54,52 @@ func NewAiAgent(ctx context.Context, toolTypes ...ToolType) (*AiAgent, error) { }, nil } -// Chat 聊天,返回消息流通道 -func (aiAgent *AiAgent) Chat(ctx context.Context, sysPrompt string, question string) (<-chan *schema.Message, <-chan error) { - ch := make(chan *schema.Message, 512) - errCh := make(chan error, 1) +type AiAgent struct { + adk.Agent +} +// Run 运行,并返回最终结果 +func (aiAgent *AiAgent) Run(ctx context.Context, sysPrompt string, question string) (string, error) { if sysPrompt == "" { sysPrompt = "你现在是一位拥有20年实战经验的顶级系统运维专家,精通Linux操作系统、数据库管理(如MySQL、PostgreSQL)、NoSQL数据库(如Redis、MongoDB)以及搜索引擎(如Elasticsearch)。" } - agentOption := []agent.AgentOption{} + runner := adk.NewRunner(ctx, adk.RunnerConfig{ + EnableStreaming: false, + Agent: aiAgent.Agent, + CheckPointStore: NewInMemoryStore(), + }) - go func() { - defer close(ch) - defer close(errCh) - defer gox.Recover(func(err error) { - errCh <- err - }) + iter := runner.Run(ctx, []adk.Message{ + { + Role: schema.System, + Content: sysPrompt, + }, + { + Role: schema.User, + Content: question, + }, + }) - sr, err := aiAgent.Stream(ctx, []*schema.Message{ - { - Role: schema.System, - Content: sysPrompt, - }, - { - Role: schema.User, - Content: question, - }, - }, agentOption...) - if err != nil { - errCh <- err // 将错误发送到错误通道 - return - } - defer sr.Close() - - for { - msg, err := sr.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - break - } - logx.Errorf("failed to recv response: %v", err) - break - } - // logx.Debugf("stream: %s", msg.String()) - ch <- msg - } - }() - - return ch, errCh -} - -// GetChatMsg 获取完整的聊天回复内容 -func (aiAgent *AiAgent) GetChatMsg(ctx context.Context, sysPrompt string, question string) (string, error) { - msgChan, errChan := aiAgent.Chat(ctx, sysPrompt, question) res := "" - - // 使用 select 同时监听消息通道和错误通道 for { - select { - case msg, ok := <-msgChan: - if !ok { - // 消息通道已关闭,说明正常结束 - // 检查错误通道是否有错误 - select { - case err := <-errChan: - if err != nil { - return "", err - } - default: - return res, nil - } - return res, nil - } - res += msg.Content - case err := <-errChan: - // 优先检查错误通道 - if err != nil { - return "", err - } - case <-ctx.Done(): - // 上下文被取消 - return "", ctx.Err() + event, ok := iter.Next() + if !ok { + break + } + + err := event.Err + if err != nil { + logx.Error(err.Error()) + return res, err + } + + LogEvent(event) + msg := event.Output.MessageOutput.Message + if msg != nil { + res = msg.Content } } + + return res, nil } diff --git a/server/internal/ai/agent/print.go b/server/internal/ai/agent/print.go new file mode 100644 index 00000000..2073f272 --- /dev/null +++ b/server/internal/ai/agent/print.go @@ -0,0 +1,121 @@ +package agent + +import ( + "fmt" + "io" + "log" + "mayfly-go/pkg/logx" + "strings" + + "github.com/cloudwego/eino/adk" + "github.com/cloudwego/eino/schema" +) + +func LogEvent(event *adk.AgentEvent) { + logx.Debugf("agent name: %s, path: %s", event.AgentName, event.RunPath) + if event.Output != nil && event.Output.MessageOutput != nil { + if m := event.Output.MessageOutput.Message; m != nil { + if len(m.Content) > 0 { + if m.Role == schema.Tool { + logx.Debugf("agent tool response: %s", m.Content) + } else { + logx.Debugf("agent answer: %s", m.Content) + } + } + if len(m.ToolCalls) > 0 { + for _, tc := range m.ToolCalls { + logx.Debugf("agent tool name: %s", tc.Function.Name) + logx.Debugf("agent tool arguments: %s", tc.Function.Arguments) + } + } + } else if s := event.Output.MessageOutput.MessageStream; s != nil { + toolMap := map[int][]*schema.Message{} + var contentStart bool + charNumOfOneRow := 0 + maxCharNumOfOneRow := 120 + for { + chunk, err := s.Recv() + if err != nil { + if err == io.EOF { + break + } + logx.Debugf("agent error: %v", err) + return + } + if chunk.Content != "" { + if !contentStart { + contentStart = true + if chunk.Role == schema.Tool { + logx.Debugf("agent tool response: ") + } else { + logx.Debugf("agent answer: ") + } + } + + charNumOfOneRow += len(chunk.Content) + if strings.Contains(chunk.Content, "\n") { + charNumOfOneRow = 0 + } else if charNumOfOneRow >= maxCharNumOfOneRow { + logx.Debugf("\n") + charNumOfOneRow = 0 + } + logx.Debugf("%v", chunk.Content) + } + + if len(chunk.ToolCalls) > 0 { + for _, tc := range chunk.ToolCalls { + index := tc.Index + if index == nil { + logx.Error("index is nil") + } + toolMap[*index] = append(toolMap[*index], &schema.Message{ + Role: chunk.Role, + ToolCalls: []schema.ToolCall{ + { + ID: tc.ID, + Type: tc.Type, + Index: tc.Index, + Function: schema.FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }, + }, + }) + } + } + } + + for _, msgs := range toolMap { + m, err := schema.ConcatMessages(msgs) + if err != nil { + log.Fatalf("ConcatMessage failed: %v", err) + return + } + logx.Debugf("agent tool name: %s", m.ToolCalls[0].Function.Name) + logx.Debugf("agent tool arguments: %s", m.ToolCalls[0].Function.Arguments) + } + } + } + if event.Action != nil { + if event.Action.TransferToAgent != nil { + logx.Debugf("agent action: transfer to %v", event.Action.TransferToAgent.DestAgentName) + } + if event.Action.Interrupted != nil { + for _, ic := range event.Action.Interrupted.InterruptContexts { + str, ok := ic.Info.(fmt.Stringer) + if ok { + logx.Debugf("\n%s", str.String()) + } else { + logx.Debugf("\n%v", ic.Info) + } + } + } + if event.Action.Exit { + logx.Debugf("agent action: exit") + } + } + if event.Err != nil { + logx.Debugf("agent error: %v", event.Err) + } +} diff --git a/server/internal/ai/agent/store.go b/server/internal/ai/agent/store.go new file mode 100644 index 00000000..cb0c9adb --- /dev/null +++ b/server/internal/ai/agent/store.go @@ -0,0 +1,27 @@ +package agent + +import ( + "context" + + "github.com/cloudwego/eino/compose" +) + +func NewInMemoryStore() compose.CheckPointStore { + return &inMemoryStore{ + mem: map[string][]byte{}, + } +} + +type inMemoryStore struct { + mem map[string][]byte +} + +func (i *inMemoryStore) Set(ctx context.Context, key string, value []byte) error { + i.mem[key] = value + return nil +} + +func (i *inMemoryStore) Get(ctx context.Context, key string) ([]byte, bool, error) { + v, ok := i.mem[key] + return v, ok, nil +} diff --git a/server/internal/ai/agent/utils.go b/server/internal/ai/agent/utils.go new file mode 100644 index 00000000..9dc74934 --- /dev/null +++ b/server/internal/ai/agent/utils.go @@ -0,0 +1,101 @@ +package agent + +import ( + "bytes" + "errors" + "mayfly-go/pkg/utils/collx" + "mayfly-go/pkg/utils/jsonx" + "regexp" + "strings" +) + +// ParseLLMJSON 尝试从大模型输出中解析 JSON +func ParseLLMJSON[T any](raw string) (*T, error) { + candidates := extractJSONCandidates(raw) + + var lastErr error + for _, c := range candidates { + if v, err := jsonx.To[T](c); err == nil { + return v, nil + } else { + lastErr = err + } + } + + if lastErr == nil { + lastErr = errors.New("no json candidate found") + } + return nil, lastErr +} + +// ParseLLMJSON2Map 解析 LLM 返回的JSON为map +func ParseLLMJSON2Map(raw string) (collx.M, error) { + if res, err := ParseLLMJSON[collx.M](raw); err != nil { + return nil, err + } else { + return *res, nil + } +} + + +func extractJSONCandidates(raw string) []string { + var results []string + text := strings.TrimSpace(raw) + + // 1. 优先提取 code block 中的 JSON(对象 or 数组) + codeBlockRe := regexp.MustCompile( + "(?s)```(?:json)?\\s*([\\[{].*?[\\]}])\\s*```", + ) + matches := codeBlockRe.FindAllStringSubmatch(text, -1) + for _, m := range matches { + results = append(results, strings.TrimSpace(m[1])) + } + + // 2. 如果没找到 code block,尝试从全文裁剪 JSON + if len(results) == 0 { + if clipped := clipJSONValue(text); clipped != "" { + results = append(results, clipped) + } + } + + return results +} + +func clipJSONValue(s string) string { + objIdx := strings.Index(s, "{") + arrIdx := strings.Index(s, "[") + + start := -1 + var open, close byte + + switch { + case objIdx != -1 && (arrIdx == -1 || objIdx < arrIdx): + start = objIdx + open, close = '{', '}' + case arrIdx != -1: + start = arrIdx + open, close = '[', ']' + default: + return "" + } + + var buf bytes.Buffer + depth := 0 + + for i := start; i < len(s); i++ { + ch := s[i] + buf.WriteByte(ch) + + switch ch { + case open: + depth++ + case close: + depth-- + if depth == 0 { + return buf.String() + } + } + } + + return "" +} \ No newline at end of file diff --git a/server/internal/ai/agent/utils_test.go b/server/internal/ai/agent/utils_test.go new file mode 100644 index 00000000..fe19085a --- /dev/null +++ b/server/internal/ai/agent/utils_test.go @@ -0,0 +1,108 @@ +package agent + +import ( + "testing" +) + +// TestParseLLMJSON 测试 ParseLLMJSON 函数 +func TestParseLLMJSON(t *testing.T) { + // 定义测试用例结构体 + tests := []struct { + name string + input string + expected any + hasError bool + }{ + { + name: "Valid JSON Object", + input: "```json\n{\n \"name\": \"Alice\",\n \"age\": \"30\"\n}\n```", + expected: map[string]any{ + "name": "Alice", + "age": "30", + }, + hasError: false, + }, + { + name: "Valid JSON Object", + input: "```\n{\n \"name\": \"Alice\",\n \"age\": \"40\"\n}\n```", + expected: map[string]any{ + "name": "Alice", + "age": "40", + }, + hasError: false, + }, + { + name: "Valid JSON Object", + input: "aaabbbccc```\n{\n \"name\": \"Alice\",\n \"age\": \"50\"\n}\n```dddd", + expected: map[string]any{ + "name": "Alice", + "age": "50", + }, + hasError: false, + }, + { + name: "Valid JSON Array", + input: "```json\n[\n {\"id\": \"1\", \"value\": \"foo\"},\n {\"id\": \"2\", \"value\": \"bar\"}\n]\n```", + expected: []map[string]any{ + {"id": "1", "value": "foo"}, + {"id": "2", "value": "bar"}, + }, + hasError: false, + }, + { + name: "Valid JSON Array", + input: "aaaa```json\n[\n {\"id\": \"11\", \"value\": \"foo\"},\n {\"id\": \"22\", \"value\": \"bar\"}\n]\n```", + expected: []map[string]any{ + {"id": "11", "value": "foo"}, + {"id": "22", "value": "bar"}, + }, + hasError: false, + }, + { + name: "Invalid JSON Format", + input: "This is not a valid JSON", + expected: nil, + hasError: true, + }, + { + name: "Empty Input", + input: "", + expected: nil, + hasError: true, + }, + } + + // 执行测试用例 + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var result any + var err error + + // 根据 expected 类型调用不同的 ParseLLMJSON 方法 + switch tt.expected.(type) { + case map[string]any: + result, err = ParseLLMJSON[map[string]any](tt.input) + case []map[string]any: + result, err = ParseLLMJSON[[]map[string]any](tt.input) + default: + result, err = ParseLLMJSON[any](tt.input) + } + + // 验证错误情况 + if tt.hasError { + if err == nil { + t.Errorf("expected an error but got none") + } + return + } + + // 验证无错误情况下的结果 + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + t.Logf("%v", result) + }) + } +} \ No newline at end of file diff --git a/server/internal/ai/api/ai_db.go b/server/internal/ai/api/ai_db.go index 4e1e2045..e11168fb 100644 --- a/server/internal/ai/api/ai_db.go +++ b/server/internal/ai/api/ai_db.go @@ -2,7 +2,6 @@ package api import ( "fmt" - "mayfly-go/internal/ai/prompt" "mayfly-go/pkg/biz" "mayfly-go/pkg/logx" "mayfly-go/pkg/req" @@ -116,16 +115,5 @@ func (a *AiDB) GenerateSql(rc *req.Ctx) { func generateSqlPrompt(dbType, text string, tables []string) string { // 使用prompt包中的GetPrompt函数获取提示词模板 // 如果没有找到模板,则使用默认模板 - tableStr := strings.Join(tables, ", ") - promptTemplate := prompt.GetPrompt("SQL_GENERATE", dbType, tableStr) - if promptTemplate == "" { - promptTemplate = "你是一位专业的SQL开发工程师,请根据用户的自然语言描述,生成符合%s语法的SQL语句。\n" - if len(tables) > 0 { - promptTemplate += "相关表名:" + tableStr + "\n" - } - promptTemplate += "请确保生成的SQL语句语法正确,仅返回SQL语句,不要包含其他解释内容。" - promptTemplate = fmt.Sprintf(promptTemplate, dbType) - } - - return promptTemplate + return "" } diff --git a/server/internal/ai/model/openai.go b/server/internal/ai/model/openai.go index eff36dd9..f9727140 100644 --- a/server/internal/ai/model/openai.go +++ b/server/internal/ai/model/openai.go @@ -24,8 +24,5 @@ func (o *Openai) NewChatModel(ctx context.Context, aiConfig *config.AIModelConfi Timeout: time.Duration(aiConfig.TimeOut) * time.Second, MaxTokens: &aiConfig.MaxTokens, Temperature: &aiConfig.Temperature, - ResponseFormat: &openai.ChatCompletionResponseFormat{ - Type: openai.ChatCompletionResponseFormatTypeJSONObject, - }, }) } diff --git a/server/internal/ai/prompt/prompt.go b/server/internal/ai/prompt/prompt.go index 0093b808..83888815 100644 --- a/server/internal/ai/prompt/prompt.go +++ b/server/internal/ai/prompt/prompt.go @@ -2,51 +2,20 @@ package prompt import ( "embed" - "fmt" - "mayfly-go/pkg/logx" "mayfly-go/pkg/utils/stringx" - "strings" ) -const ( - FLOW_BIZ_AUDIT = "FLOW_BIZ_AUDIT" - SQL_GENERATE = "SQL_GENERATE" -) - -//go:embed prompts.txt +//go:embed prompts/*.md var prompts embed.FS -// prompt缓存 key: XXX_YYY value: 内容 -var promptCache = make(map[string]string, 20) +// GetPrompt 获取本地prompts文件内容,并进行模板解析 +func GetPrompt(filename string, values any) (string, error) { + // 自动添加 prompts/ 前缀 + fullPath := "prompts/" + filename -// 获取本地文件的prompt内容,并进行解析,获取对应key的prompt内容 -func GetPrompt(key string, formatValues ...any) string { - prompt := promptCache[key] - if prompt != "" { - return fmt.Sprintf(prompt, formatValues...) - } - - bytes, err := prompts.ReadFile("prompts.txt") + bytes, err := prompts.ReadFile(fullPath) if err != nil { - logx.Error("failed to read prompt file: prompts.txt, err: %v", err) - return "" + return "", err } - allPrompts := string(bytes) - - propmts := strings.Split(allPrompts, "---------------------------------------") - var res string - for _, keyAndPrompt := range propmts { - keyAndPrompt = stringx.TrimSpaceAndBr(keyAndPrompt) - // 获取第一行的Key信息如:--XXX_YYY - info := strings.SplitN(keyAndPrompt, "\n", 2) - // prompt,即去除第一行的key与备注信息 - prompt := info[1] - // 获取key;如:XXX_YYY - promptKey := strings.Split(strings.Split(info[0], " ")[0], "--")[1] - if key == promptKey { - res = prompt - } - promptCache[promptKey] = prompt - } - return fmt.Sprintf(res, formatValues...) + return stringx.TemplateParse(string(bytes), values) } diff --git a/server/internal/ai/prompt/prompts.txt b/server/internal/ai/prompt/prompts.txt deleted file mode 100644 index c3d2d7ee..00000000 --- a/server/internal/ai/prompt/prompts.txt +++ /dev/null @@ -1,38 +0,0 @@ ---FLOW_BIZ_AUDIT 流程业务审核 -你现在是一位专业的数据库管理员、Redis管理员和安全审核专家。请根据以下审核规则分析用户提供的内容,并以严格的JSON格式返回分析结果。 -- 当用户询问表结构时,禁止凭经验回答,可以使用 QueryTableInfo 工具获取真实表DDL数据进行核验字段等。 - -审核规则: -%s - -待审核内容为结构化的业务操作请求,包含以下要素: -1. 操作指令:可能包含单条或多条SQL语句和/或Redis命令 -2. 数据库上下文:每条指令关联的目标数据库信息,包括: - - 数据库唯一标识(ID) - - 数据库实例名称 - - 数据库类型(如MySQL、PostgreSQL、Redis等) - -请根据指令类型和目标数据库类型,应用相应的安全审核规则进行逐条验证。若存在任何不符合安全规范的指令,整体审核结果应判定为不通过。 - -请严格遵循以下要求: -1. 仅输出有效的JSON对象,不要包含任何解释性文字 -2. 禁止包含任何Markdown格式(包括但不限于```json、```等代码引用符号) -3. JSON格式必须严格包含以下字段且无额外内容: -{ - "allowExecute": boolean, // 是否允许执行操作,true或false - "suggestion": string // 具体的建议内容,如"通过"或"拒绝原因"等。如果是多条命令审核,请详细说明哪条命令存在问题 -} ---------------------------------------- ---SQL_GENERATE 生成SQL -你是一位专业的SQL开发工程师,请根据用户的自然语言描述,生成符合%s语法的SQL语句。 - -相关表名:%s - -请确保生成的SQL语句: -1. 语法正确,符合指定数据库类型的标准 -2. 逻辑清晰,准确表达用户的需求 -3. 仅返回SQL语句,不要包含任何解释或说明 -4. 避免使用可能导致性能问题的写法 -5. 确保SQL语句的安全性,防止SQL注入等安全问题 - -如果用户的需求不明确或无法完全实现,请说明原因。 \ No newline at end of file diff --git a/server/internal/ai/prompt/prompts/flow_biz_audit.md b/server/internal/ai/prompt/prompts/flow_biz_audit.md new file mode 100644 index 00000000..8572c370 --- /dev/null +++ b/server/internal/ai/prompt/prompts/flow_biz_audit.md @@ -0,0 +1,141 @@ +# 系统角色 + +你是一位 **专业的数据库管理员(DBA)**、**Redis 管理员** 和 **安全审核专家**。 + +你的唯一职责是: +✅ **对用户提交的结构化业务操作请求进行安全审核,并输出是否允许执行的最终裁决。** + +⚠️ **重要限制**: +- 不执行任何实际操作 +- 不提供 SQL 改写或优化建议 +- 不基于经验或推测回答任何数据库事实 + +--- + +# 审核规则 + +{{.rule}} + +--- + +# 核心强制原则(必须严格遵守) + +## 1. 禁止猜测与推测 + +- ❌ 严禁基于经验、习惯或主观推测回答以下内容: + - 表结构 + - 字段信息(名称、类型等) + - 索引信息(主键、唯一索引、普通索引) +- ✅ 仅允许基于 **真实、已验证的数据库元数据** 进行审核判断 + +--- + +## 2. 工具调用强制性 + +### 触发条件 + +当审核规则中出现以下任一要求时,**必须调用对应工具获取真实数据**: +- 校验表是否存在 +- 校验字段是否存在或字段类型 +- 校验主键、唯一索引、普通索引 +- 判断 SQL 是否依赖索引或命中索引安全规则 + +### 严格限制 + +- ❌ 禁止在未调用工具的情况下直接给出审核结论 +- ❌ 禁止基于经验或假设进行判断 + +--- + +## 3. 信息不足即不通过 + +- 若审核所需的事实信息缺失、无法确认,或未通过工具校验 +- ✅ **必须直接判定为不通过(allowExecute = false)** + +--- + +# 工具说明 + +## 工具:QueryTableInfo + +### 功能说明 + +用于查询指定数据库表的真实结构信息,包括: +- 字段名称 +- 字段类型 +- 主键字段 +- 索引信息(唯一索引 / 普通索引) + +### 调用规则 + +- 当审核规则要求校验 **表、字段或索引** 时: + - ✅ **必须调用该工具** + - ❌ **不得跳过、替代或基于经验判断** + +### 参数约束 + +- 工具参数必须 **完全来源于 SQL 解析结果** +- ❌ 禁止编造表名、字段名或数据库信息 + +--- + +# 输入说明 + +每条业务操作请求包含以下内容: + +1. SQL 或 Redis 指令(单条或多条) +2. 数据库上下文信息: + - 数据库唯一标识(ID) + - 数据库实例名称 + - 数据库类型(MySQL / PostgreSQL / Redis 等) + +--- + +# 输出说明(极其重要) + +## 输出格式强制要求 + +- ✅ **最终只能输出一个 JSON 对象** +- ❌ **不得输出任何额外文字、解释、Markdown、日志或调试信息** +- ✅ **输出结果必须可被 JSON.parse 成功解析** +- ✅ **即使审核失败,也必须输出 JSON** + +--- + +## 输出结构(字段不可增删) + +```json +{ + "allowExecute": boolean, + "suggestion": string +} +``` + +--- + +## 字段说明 + +### allowExecute + +- true:所有指令均符合安全规范,且所需信息已完整校验 +- false:存在任意不符合安全规范的指令,或审核信息不足无法判断 + +### suggestion + +- **通过**:填写“通过”或简要确认说明 +- **不通过**: + - 必须明确指出 **具体不通过原因** + - 若包含多条指令,需明确指出 **哪一条指令** 存在问题 + - 表述需简洁、明确,不得模糊或推测 + +--- + +# 最终裁决规则(不可违背) + +- 只要存在 **任意一条** 不符合安全规范的指令 + → **整体审核结果必须为不通过** +- ❌ 禁止输出以下非确定性结论: + - “部分通过” + - “建议执行” + - “可能有风险” + - “视情况而定” diff --git a/server/internal/auth/api/account_login.go b/server/internal/auth/api/account_login.go index 9614ea14..9a1d379d 100644 --- a/server/internal/auth/api/account_login.go +++ b/server/internal/auth/api/account_login.go @@ -47,7 +47,7 @@ func (a *AccountLogin) ReqConfs() *req.Confs { // @router /auth/accounts/login [post] func (a *AccountLogin) Login(rc *req.Ctx) { - loginForm := req.BindJson[*form.LoginForm](rc) + loginForm := req.BindJson[form.LoginForm](rc) ctx := rc.MetaCtx accountLoginSecurity := config.GetAccountLoginSecurity() @@ -96,7 +96,7 @@ type OtpVerifyInfo struct { // OTP双因素校验 func (a *AccountLogin) OtpVerify(rc *req.Ctx) { - otpVerify := req.BindJson[*form.OtpVerfiy](rc) + otpVerify := req.BindJson[form.OtpVerfiy](rc) ctx := rc.MetaCtx tokenKey := fmt.Sprintf("otp:token:%s", otpVerify.OtpToken) diff --git a/server/internal/auth/api/ldap_login.go b/server/internal/auth/api/ldap_login.go index 0504d468..e990cc67 100644 --- a/server/internal/auth/api/ldap_login.go +++ b/server/internal/auth/api/ldap_login.go @@ -47,7 +47,7 @@ func (a *LdapLogin) GetLdapEnabled(rc *req.Ctx) { // @router /auth/ldap/login [post] func (a *LdapLogin) Login(rc *req.Ctx) { - loginForm := req.BindJson[*form.LoginForm](rc) + loginForm := req.BindJson[form.LoginForm](rc) ctx := rc.MetaCtx accountLoginSecurity := config.GetAccountLoginSecurity() // 判断是否有开启登录验证码校验 diff --git a/server/internal/db/ai/tools/query_table_info.go b/server/internal/db/ai/tools/query_table_info.go index 89979761..9bbbddb1 100644 --- a/server/internal/db/ai/tools/query_table_info.go +++ b/server/internal/db/ai/tools/query_table_info.go @@ -4,59 +4,36 @@ import ( "context" "mayfly-go/internal/db/application" - "mayfly-go/pkg/logx" - "mayfly-go/pkg/utils/jsonx" "github.com/cloudwego/eino/components/tool" - "github.com/cloudwego/eino/schema" + "github.com/cloudwego/eino/components/tool/utils" ) -func GetQueryTableInfo() tool.InvokableTool { - return &QueryTableInfo{} +type QueryTableInfoParam struct { + DbId uint64 `json:"dbId" jsonschema_description:"数据库ID"` + DbName string `json:"dbName" jsonschema_description:"数据库名称"` + TableName string `json:"tableName" jsonschema_description:"表名"` } -type QueryTableInfo struct { +type QueryTableInfoOutput struct { + DDL string `json:"ddl" jsonschema_description:"表DDL"` } -var _ tool.InvokableTool = (*QueryTableInfo)(nil) - -func (q QueryTableInfo) Info(ctx context.Context) (*schema.ToolInfo, error) { - return &schema.ToolInfo{ - Name: "QueryTableInfo", - Desc: "查询数据库表的详细信息,包括表结构、字段定义、索引等。当用户需要了解某个表的结构时使用此工具。", - ParamsOneOf: schema.NewParamsOneOfByParams(map[string]*schema.ParameterInfo{ - "dbId": { - Type: schema.Number, - Desc: "数据库ID", - Required: true, - }, - "dbName": { - Type: schema.String, - Desc: "数据库名称", - Required: true, - }, - "tableName": { - Type: schema.String, - Desc: "表名", - Required: true, - }, - }), - }, nil -} - -func (q QueryTableInfo) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { - logx.Debugf("开始查询数据库表信息: %s", argumentsInJSON) - m, err := jsonx.ToMap(argumentsInJSON) - if err != nil { - return "arguments json invalid", err - } - - tableName := m.GetStr("tableName") - conn, err := application.GetDbApp().GetDbConn(ctx, uint64(m.GetInt64("dbId")), m.GetStr("dbName")) - if err != nil { - return "获取数据库连接失败", err - } - - return conn.GetMetadata().GetTableDDL(tableName, false) +func GetQueryTableInfo() (tool.InvokableTool, error) { + return utils.InferTool("QueryTableInfo", + "当需要了解某个表结构时,请调用此工具。使用它来查询数据库表的DDL信息,包括表结构、字段定义、索引等。", + func(ctx context.Context, param *QueryTableInfoParam) (*QueryTableInfoOutput, error) { + conn, err := application.GetDbApp().GetDbConn(ctx, param.DbId, param.DbName) + if err != nil { + return nil, err + } + ddl, err := conn.GetMetadata().GetTableDDL(param.TableName, false) + if err != nil { + return nil, err + } + output := &QueryTableInfoOutput{DDL: ddl} + return output, nil + }, + ) } diff --git a/server/internal/db/ai/tools/tools.go b/server/internal/db/ai/tools/tools.go index aa80d820..032991ce 100644 --- a/server/internal/db/ai/tools/tools.go +++ b/server/internal/db/ai/tools/tools.go @@ -1,7 +1,14 @@ package tools -import "mayfly-go/internal/ai/agent" +import ( + "mayfly-go/internal/ai/agent" + "mayfly-go/pkg/logx" +) func Init() { - agent.RegisterTool(agent.ToolTypeDb, GetQueryTableInfo()) + if queryTableTool, err := GetQueryTableInfo(); err != nil { + logx.Errorf("agent tool - 获取QueryTableInfo工具失败: %v", err) + } else { + agent.RegisterTool(agent.ToolTypeDb, queryTableTool) + } } diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index 2b054bf1..cc614756 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -74,7 +74,7 @@ func (d *Db) ReqConfs() *req.Confs { // @router /api/dbs [get] func (d *Db) Dbs(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.DbQuery](rc) + queryCond := req.BindQuery[entity.DbQuery](rc) // 不存在可访问标签id,即没有可操作数据 tags := d.tagApp.GetAccountTags(rc.GetLoginAccount().Id, &tagentity.TagTreeQuery{ @@ -112,7 +112,7 @@ func (d *Db) Dbs(rc *req.Ctx) { } func (d *Db) Save(rc *req.Ctx) { - form, db := req.BindJsonAndCopyTo[*form.DbForm, *entity.Db](rc) + form, db := req.BindJsonAndCopyTo[form.DbForm, entity.Db](rc) rc.ReqParam = form biz.ErrIsNil(d.dbApp.SaveDb(rc.MetaCtx, db)) @@ -132,7 +132,7 @@ func (d *Db) DeleteDb(rc *req.Ctx) { /** 数据库操作相关、执行sql等 ***/ func (d *Db) ExecSql(rc *req.Ctx) { - form := req.BindJson[*form.DbSqlExecForm](rc) + form := req.BindJson[form.DbSqlExecForm](rc) ctx, cancel := context.WithTimeout(rc.MetaCtx, time.Duration(config.GetDbms().SqlExecTl)*time.Second) defer cancel() @@ -339,7 +339,7 @@ func (d *Db) GetSchemas(rc *req.Ctx) { } func (d *Db) CopyTable(rc *req.Ctx) { - form, copy := req.BindJsonAndCopyTo[*form.DbCopyTableForm, *dbi.DbCopyTable](rc) + form, copy := req.BindJsonAndCopyTo[form.DbCopyTableForm, dbi.DbCopyTable](rc) conn, err := d.dbApp.GetDbConn(rc.MetaCtx, form.Id, form.Db) biz.ErrIsNilAppendErr(err, "copy table error: %s") diff --git a/server/internal/db/api/db_data_sync.go b/server/internal/db/api/db_data_sync.go index bb11a043..7b0e0460 100644 --- a/server/internal/db/api/db_data_sync.go +++ b/server/internal/db/api/db_data_sync.go @@ -50,21 +50,21 @@ func (d *DataSyncTask) ReqConfs() *req.Confs { } func (d *DataSyncTask) Tasks(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.DataSyncTaskQuery](rc) + queryCond := req.BindQuery[entity.DataSyncTaskQuery](rc) res, err := d.dataSyncTaskApp.GetPageList(queryCond) biz.ErrIsNil(err) rc.ResData = model.PageResultConv[*entity.DataSyncTask, *vo.DataSyncTaskListVO](res) } func (d *DataSyncTask) Logs(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.DataSyncLogQuery](rc) + queryCond := req.BindQuery[entity.DataSyncLogQuery](rc) res, err := d.dataSyncTaskApp.GetTaskLogList(queryCond) biz.ErrIsNil(err) rc.ResData = model.PageResultConv[*entity.DataSyncLog, *vo.DataSyncLogListVO](res) } func (d *DataSyncTask) SaveTask(rc *req.Ctx) { - form, task := req.BindJsonAndCopyTo[*form.DataSyncTaskForm, *entity.DataSyncTask](rc) + form, task := req.BindJsonAndCopyTo[form.DataSyncTaskForm, entity.DataSyncTask](rc) // 解码base64 sql sqlStr, err := utils.AesDecryptByLa(task.DataSql, rc.GetLoginAccount()) @@ -87,7 +87,7 @@ func (d *DataSyncTask) DeleteTask(rc *req.Ctx) { } func (d *DataSyncTask) ChangeStatus(rc *req.Ctx) { - form := req.BindJson[*form.DataSyncTaskStatusForm](rc) + form := req.BindJson[form.DataSyncTaskStatusForm](rc) rc.ReqParam = form task, err := d.dataSyncTaskApp.GetById(form.Id) diff --git a/server/internal/db/api/db_instance.go b/server/internal/db/api/db_instance.go index 9eed3c94..5addc27f 100644 --- a/server/internal/db/api/db_instance.go +++ b/server/internal/db/api/db_instance.go @@ -55,7 +55,7 @@ func (d *Instance) ReqConfs() *req.Confs { // Instances 获取数据库实例信息 // @router /api/instances [get] func (d *Instance) Instances(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.InstanceQuery](rc) + queryCond := req.BindQuery[entity.InstanceQuery](rc) tags := d.tagApp.GetAccountTags(rc.GetLoginAccount().Id, &tagentity.TagTreeQuery{ TypePaths: collx.AsArray(tagentity.NewTypePaths(tagentity.TagTypeDbInstance, tagentity.TagTypeAuthCert)), @@ -90,14 +90,14 @@ func (d *Instance) Instances(rc *req.Ctx) { } func (d *Instance) TestConn(rc *req.Ctx) { - form, instance := req.BindJsonAndCopyTo[*form.InstanceForm, *entity.DbInstance](rc) + form, instance := req.BindJsonAndCopyTo[form.InstanceForm, entity.DbInstance](rc) biz.ErrIsNil(d.instanceApp.TestConn(rc.MetaCtx, instance, form.AuthCerts[0])) } // SaveInstance 保存数据库实例信息 // @router /api/instances [post] func (d *Instance) SaveInstance(rc *req.Ctx) { - form, instance := req.BindJsonAndCopyTo[*form.InstanceForm, *entity.DbInstance](rc) + form, instance := req.BindJsonAndCopyTo[form.InstanceForm, entity.DbInstance](rc) rc.ReqParam = form id, err := d.instanceApp.SaveDbInstance(rc.MetaCtx, &dto.SaveDbInstance{ @@ -132,7 +132,7 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) { // 获取数据库实例的所有数据库名 func (d *Instance) GetDatabaseNames(rc *req.Ctx) { - form, instance := req.BindJsonAndCopyTo[*form.InstanceDbNamesForm, *entity.DbInstance](rc) + form, instance := req.BindJsonAndCopyTo[form.InstanceDbNamesForm, entity.DbInstance](rc) res, err := d.instanceApp.GetDatabases(rc.MetaCtx, instance, form.AuthCert) biz.ErrIsNil(err) rc.ResData = res diff --git a/server/internal/db/api/db_sql.go b/server/internal/db/api/db_sql.go index 91a5f4a8..8e671e8f 100644 --- a/server/internal/db/api/db_sql.go +++ b/server/internal/db/api/db_sql.go @@ -30,7 +30,7 @@ func (d *DbSql) ReqConfs() *req.Confs { // @router /api/db/:dbId/sql [post] func (d *DbSql) SaveSql(rc *req.Ctx) { - dbSqlForm := req.BindJson[*form.DbSqlSaveForm](rc) + dbSqlForm := req.BindJson[form.DbSqlSaveForm](rc) rc.ReqParam = dbSqlForm dbId := getDbId(rc) diff --git a/server/internal/db/api/db_sql_exec.go b/server/internal/db/api/db_sql_exec.go index a0ba3aca..44f238c8 100644 --- a/server/internal/db/api/db_sql_exec.go +++ b/server/internal/db/api/db_sql_exec.go @@ -26,7 +26,7 @@ func (d *DbSqlExec) ReqConfs() *req.Confs { } func (d *DbSqlExec) DbSqlExecs(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.DbSqlExecQuery](rc) + queryCond := req.BindQuery[entity.DbSqlExecQuery](rc) if statusStr := rc.Query("status"); statusStr != "" { queryCond.Status = collx.ArrayMap[string, int8](strings.Split(statusStr, ","), func(val string) int8 { return cast.ToInt8(val) diff --git a/server/internal/db/api/db_transfer.go b/server/internal/db/api/db_transfer.go index 53704515..353a43d9 100644 --- a/server/internal/db/api/db_transfer.go +++ b/server/internal/db/api/db_transfer.go @@ -61,7 +61,7 @@ func (d *DbTransferTask) ReqConfs() *req.Confs { } func (d *DbTransferTask) Tasks(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.DbTransferTaskQuery](rc) + queryCond := req.BindQuery[entity.DbTransferTaskQuery](rc) res, err := d.dbTransferTaskApp.GetPageList(queryCond) biz.ErrIsNil(err) @@ -78,7 +78,7 @@ func (d *DbTransferTask) Tasks(rc *req.Ctx) { } func (d *DbTransferTask) SaveTask(rc *req.Ctx) { - reqForm, task := req.BindJsonAndCopyTo[*form.DbTransferTaskForm, *entity.DbTransferTask](rc) + reqForm, task := req.BindJsonAndCopyTo[form.DbTransferTaskForm, entity.DbTransferTask](rc) rc.ReqParam = reqForm biz.ErrIsNil(d.dbTransferTaskApp.Save(rc.MetaCtx, task)) @@ -94,7 +94,7 @@ func (d *DbTransferTask) DeleteTask(rc *req.Ctx) { } func (d *DbTransferTask) ChangeStatus(rc *req.Ctx) { - form := req.BindJson[*form.DbTransferTaskStatusForm](rc) + form := req.BindJson[form.DbTransferTaskStatusForm](rc) rc.ReqParam = form task, err := d.dbTransferTaskApp.GetById(form.Id) @@ -117,7 +117,7 @@ func (d *DbTransferTask) Stop(rc *req.Ctx) { } func (d *DbTransferTask) Files(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.DbTransferFileQuery](rc) + queryCond := req.BindQuery[entity.DbTransferFileQuery](rc) res, err := d.dbTransferFileApp.GetPageList(queryCond) biz.ErrIsNil(err) @@ -137,7 +137,7 @@ func (d *DbTransferTask) FileDel(rc *req.Ctx) { } func (d *DbTransferTask) FileRun(rc *req.Ctx) { - fm := req.BindJson[*form.DbTransferFileRunForm](rc) + fm := req.BindJson[form.DbTransferFileRunForm](rc) rc.ReqParam = fm diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index f324ae0e..4434d1df 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -298,7 +298,7 @@ func (d *dbSqlExecAppImpl) FlowBizHandle(ctx context.Context, bizHandleParam *fl return nil, nil } - execSqlBizForm, err := jsonx.To[*FlowDbExecSqlBizForm](procinst.BizForm) + execSqlBizForm, err := jsonx.To[FlowDbExecSqlBizForm](procinst.BizForm) if err != nil { return nil, errorx.NewBizf("failed to parse the business form information: %s", err.Error()) } @@ -603,7 +603,7 @@ func (d *dbSqlExecAppImpl) doExec(ctx context.Context, dbConn *dbi.DbConn, sql s return &dto.DbSqlExecRes{ Columns: []*dbi.QueryColumn{ - {Name: "rowsAffected", Type: "number"}, + {Name: "rowsAffected", Key:"rowsAffected", Type: "number"}, }, Res: res, Sql: sql, diff --git a/server/internal/docker/api/container_conf.go b/server/internal/docker/api/container_conf.go index f469b332..943249d4 100644 --- a/server/internal/docker/api/container_conf.go +++ b/server/internal/docker/api/container_conf.go @@ -35,7 +35,7 @@ func (cc *ContainerConf) ReqConfs() *req.Confs { } func (cc *ContainerConf) GetContainerPage(rc *req.Ctx) { - condition := req.BindQuery[*entity.ContainerQuery](rc) + condition := req.BindQuery[entity.ContainerQuery](rc) tags := cc.tagTreeApp.GetAccountTags(rc.GetLoginAccount().Id, &tagentity.TagTreeQuery{ TypePaths: collx.AsArray(tagentity.NewTypePaths(tagentity.TagTypeContainer)), @@ -68,7 +68,7 @@ func (cc *ContainerConf) GetContainerPage(rc *req.Ctx) { } func (c *ContainerConf) Save(rc *req.Ctx) { - machineForm, container := req.BindJsonAndCopyTo[*form.ContainerSave, *entity.Container](rc) + machineForm, container := req.BindJsonAndCopyTo[form.ContainerSave, entity.Container](rc) rc.ReqParam = machineForm biz.ErrIsNil(c.containerApp.SaveContainer(rc.MetaCtx, &dto.SaveContainer{ diff --git a/server/internal/es/api/es_instance.go b/server/internal/es/api/es_instance.go index 7cf13321..76e5619d 100644 --- a/server/internal/es/api/es_instance.go +++ b/server/internal/es/api/es_instance.go @@ -51,7 +51,7 @@ func (d *Instance) ReqConfs() *req.Confs { } func (d *Instance) Instances(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.InstanceQuery](rc) + queryCond := req.BindQuery[entity.InstanceQuery](rc) // 只查询实例,兼容没有录入密码的实例 instTags := d.tagApp.GetAccountTags(rc.GetLoginAccount().Id, &tagentity.TagTreeQuery{ @@ -92,7 +92,7 @@ func (d *Instance) Instances(rc *req.Ctx) { } func (d *Instance) TestConn(rc *req.Ctx) { - fm, instance := req.BindJsonAndCopyTo[*form.InstanceForm, *entity.EsInstance](rc) + fm, instance := req.BindJsonAndCopyTo[form.InstanceForm, entity.EsInstance](rc) var ac *tagentity.ResourceAuthCert if len(fm.AuthCerts) > 0 { @@ -104,7 +104,7 @@ func (d *Instance) TestConn(rc *req.Ctx) { rc.ResData = res } func (d *Instance) SaveInstance(rc *req.Ctx) { - fm, instance := req.BindJsonAndCopyTo[*form.InstanceForm, *entity.EsInstance](rc) + fm, instance := req.BindJsonAndCopyTo[form.InstanceForm, entity.EsInstance](rc) rc.ReqParam = fm id, err := d.inst.SaveInst(rc.MetaCtx, &dto.SaveEsInstance{ diff --git a/server/internal/flow/api/procdef.go b/server/internal/flow/api/procdef.go index 887e2754..b85f5a04 100644 --- a/server/internal/flow/api/procdef.go +++ b/server/internal/flow/api/procdef.go @@ -48,7 +48,7 @@ func (p *Procdef) ReqConfs() *req.Confs { } func (p *Procdef) GetProcdefPage(rc *req.Ctx) { - cond, page := req.BindQueryAndPage[*entity.Procdef](rc) + cond, page := req.BindQueryAndPage[entity.Procdef](rc) res, err := p.procdefApp.GetPageList(cond, page) biz.ErrIsNil(err) @@ -87,7 +87,7 @@ func (p *Procdef) GetProcdef(rc *req.Ctx) { } func (a *Procdef) Save(rc *req.Ctx) { - form, procdef := req.BindJsonAndCopyTo[*form.Procdef, *entity.Procdef](rc) + form, procdef := req.BindJsonAndCopyTo[form.Procdef, entity.Procdef](rc) rc.ReqParam = form biz.ErrIsNil(a.procdefApp.SaveProcdef(rc.MetaCtx, &dto.SaveProcdef{ Procdef: procdef, @@ -97,7 +97,7 @@ func (a *Procdef) Save(rc *req.Ctx) { } func (a *Procdef) SaveFlowDef(rc *req.Ctx) { - form := req.BindJson[*form.ProcdefFlow](rc) + form := req.BindJson[form.ProcdefFlow](rc) rc.ReqParam = form biz.ErrIsNil(a.procdefApp.SaveFlowDef(rc.MetaCtx, &dto.SaveFlowDef{ diff --git a/server/internal/flow/api/procinst.go b/server/internal/flow/api/procinst.go index dfde105c..6817ee5f 100644 --- a/server/internal/flow/api/procinst.go +++ b/server/internal/flow/api/procinst.go @@ -35,7 +35,7 @@ func (p *Procinst) ReqConfs() *req.Confs { } func (p *Procinst) GetProcinstPage(rc *req.Ctx) { - cond := req.BindQuery[*entity.ProcinstQuery](rc) + cond := req.BindQuery[entity.ProcinstQuery](rc) // 非管理员只能获取自己申请的流程 if laId := rc.GetLoginAccount().Id; laId != consts.AdminId { cond.CreatorId = laId @@ -47,7 +47,7 @@ func (p *Procinst) GetProcinstPage(rc *req.Ctx) { } func (p *Procinst) ProcinstStart(rc *req.Ctx) { - startForm := req.BindJson[*form.ProcinstStart](rc) + startForm := req.BindJson[form.ProcinstStart](rc) _, err := p.procinstApp.StartProc(rc.MetaCtx, startForm.ProcdefId, &dto.StarProc{ BizType: startForm.BizType, BizKey: startForm.BizKey, diff --git a/server/internal/flow/api/procinst_task.go b/server/internal/flow/api/procinst_task.go index bb4a01b8..9d6699d5 100644 --- a/server/internal/flow/api/procinst_task.go +++ b/server/internal/flow/api/procinst_task.go @@ -41,7 +41,7 @@ func (p *ProcinstTask) ReqConfs() *req.Confs { } func (p *ProcinstTask) GetTasks(rc *req.Ctx) { - instTaskQuery := req.BindQuery[*entity.ProcinstTaskQuery](rc) + instTaskQuery := req.BindQuery[entity.ProcinstTaskQuery](rc) if laId := rc.GetLoginAccount().Id; laId != consts.AdminId { // 赋值操作人为当前登录账号 instTaskQuery.Assignee = fmt.Sprintf("%d", rc.GetLoginAccount().Id) @@ -74,7 +74,7 @@ func (p *ProcinstTask) GetTasks(rc *req.Ctx) { } func (p *ProcinstTask) PassTask(rc *req.Ctx) { - auditForm := req.BindJson[*form.ProcinstTaskAudit](rc) + auditForm := req.BindJson[form.ProcinstTaskAudit](rc) rc.ReqParam = auditForm la := rc.GetLoginAccount() @@ -84,7 +84,7 @@ func (p *ProcinstTask) PassTask(rc *req.Ctx) { } func (p *ProcinstTask) RejectTask(rc *req.Ctx) { - auditForm := req.BindJson[*form.ProcinstTaskAudit](rc) + auditForm := req.BindJson[form.ProcinstTaskAudit](rc) rc.ReqParam = auditForm la := rc.GetLoginAccount() @@ -94,7 +94,7 @@ func (p *ProcinstTask) RejectTask(rc *req.Ctx) { } func (p *ProcinstTask) BackTask(rc *req.Ctx) { - auditForm := req.BindJson[*form.ProcinstTaskAudit](rc) + auditForm := req.BindJson[form.ProcinstTaskAudit](rc) rc.ReqParam = auditForm la := rc.GetLoginAccount() diff --git a/server/internal/flow/application/node_aitask.go b/server/internal/flow/application/node_aitask.go index 757f703a..3c61c036 100644 --- a/server/internal/flow/application/node_aitask.go +++ b/server/internal/flow/application/node_aitask.go @@ -10,10 +10,9 @@ import ( "mayfly-go/internal/flow/infra/persistence" "mayfly-go/pkg/errorx" "mayfly-go/pkg/logx" + "mayfly-go/pkg/utils/collx" "mayfly-go/pkg/utils/jsonx" "time" - - "github.com/spf13/cast" ) /******************* AI任务节点 *******************/ @@ -69,13 +68,16 @@ func (u *AiTaskNodeBehavior) Execute(ctx *ExecutionCtx) error { flowNode := ctx.GetFlowNode() aitaskNode := ToAiTaskNode(flowNode) - aiagent, err := agent.NewAiAgent(ctx, agent.ToolTypeDb) + aiagent, err := agent.GetOpsExpertAgent(ctx, agent.ToolTypeDb) if err != nil { return err } auditRule := aitaskNode.AuditRule - sysPrompt := prompt.GetPrompt(prompt.FLOW_BIZ_AUDIT, auditRule) + sysPrompt, err := prompt.GetPrompt("flow_biz_audit.md", collx.Kvs("rule", auditRule)) + if err != nil { + return err + } procinst := ctx.Procinst now := time.Now() @@ -93,18 +95,18 @@ func (u *AiTaskNodeBehavior) Execute(ctx *ExecutionCtx) error { cancelCtx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Minute) defer cancelFunc() - res, err := aiagent.GetChatMsg(cancelCtx, sysPrompt, jsonx.ToStr(procinst.BizForm)) + res, err := aiagent.Run(cancelCtx, sysPrompt, jsonx.ToStr(procinst.BizForm)) if err != nil { suggestion = fmt.Sprintf("AI agent response failed: %v", err) logx.Error(suggestion) } else { - resJson, err := jsonx.ToMap(res) + resJson, err := agent.ParseLLMJSON2Map(res) if err != nil { suggestion = fmt.Sprintf("AI agent response parsing to JSON failed: %v, response: %s", err, res) logx.Error(suggestion) } else { - allowExecute = cast.ToBool(resJson["allowExecute"]) - suggestion = cast.ToString(resJson["suggestion"]) + allowExecute = resJson.GetBool("allowExecute") + suggestion = resJson.GetStr("suggestion") } } @@ -152,7 +154,6 @@ func (u *AiTaskNodeBehavior) Execute(ctx *ExecutionCtx) error { // 跳转至开始节点,重新修改提交 ctx.Execution.State = entity.ExectionStateSuspended // 执行流挂起 return executionApp.MoveTo(ctx, ctx.GetFlowDef().GetNodeByType(FlowNodeTypeStart)[0]) - } return u.Leave(ctx) }) diff --git a/server/internal/flow/domain/entity/procdef.go b/server/internal/flow/domain/entity/procdef.go index e44c851d..ae987729 100644 --- a/server/internal/flow/domain/entity/procdef.go +++ b/server/internal/flow/domain/entity/procdef.go @@ -61,7 +61,7 @@ func (p *Procdef) GetFlowDef() *FlowDef { if p.FlowDef == "" { return nil } - flow, err := jsonx.To[*FlowDef](p.FlowDef) + flow, err := jsonx.To[FlowDef](p.FlowDef) if err != nil { logx.ErrorTrace("parse flow def failed", err) return flow diff --git a/server/internal/flow/domain/entity/procinst.go b/server/internal/flow/domain/entity/procinst.go index a9f2d348..2e6692df 100644 --- a/server/internal/flow/domain/entity/procinst.go +++ b/server/internal/flow/domain/entity/procinst.go @@ -43,7 +43,7 @@ func (a *Procinst) SetEnd() { // GetProcdefFlow 获取流程定义信息 func (p *Procinst) GetFlowDef() *FlowDef { - flow, err := jsonx.To[*FlowDef](p.FlowDef) + flow, err := jsonx.To[FlowDef](p.FlowDef) if err != nil { logx.ErrorTrace("parse procdef flow failed", err) return flow diff --git a/server/internal/machine/api/machine.go b/server/internal/machine/api/machine.go index c27351ad..67e67e9a 100644 --- a/server/internal/machine/api/machine.go +++ b/server/internal/machine/api/machine.go @@ -77,7 +77,7 @@ func (m *Machine) ReqConfs() *req.Confs { } func (m *Machine) Machines(rc *req.Ctx) { - condition := req.BindQuery[*entity.MachineQuery](rc) + condition := req.BindQuery[entity.MachineQuery](rc) tags := m.tagTreeApp.GetAccountTags(rc.GetLoginAccount().Id, &tagentity.TagTreeQuery{ TypePaths: collx.AsArray(tagentity.NewTypePaths(tagentity.TagTypeMachine, tagentity.TagTypeAuthCert)), @@ -144,7 +144,7 @@ func (m *Machine) MachineStats(rc *req.Ctx) { // 保存机器信息 func (m *Machine) SaveMachine(rc *req.Ctx) { - machineForm, me := req.BindJsonAndCopyTo[*form.MachineForm, *entity.Machine](rc) + machineForm, me := req.BindJsonAndCopyTo[form.MachineForm, entity.Machine](rc) rc.ReqParam = machineForm @@ -156,7 +156,7 @@ func (m *Machine) SaveMachine(rc *req.Ctx) { } func (m *Machine) TestConn(rc *req.Ctx) { - machineForm, me := req.BindJsonAndCopyTo[*form.MachineForm, *entity.Machine](rc) + machineForm, me := req.BindJsonAndCopyTo[form.MachineForm, entity.Machine](rc) // 测试连接 biz.ErrIsNilAppendErr(m.machineApp.TestConn(rc.MetaCtx, me, machineForm.AuthCerts[0]), "connection error: %s") } diff --git a/server/internal/machine/api/machine_cmd_conf.go b/server/internal/machine/api/machine_cmd_conf.go index 3efd6b9f..839eb6a0 100644 --- a/server/internal/machine/api/machine_cmd_conf.go +++ b/server/internal/machine/api/machine_cmd_conf.go @@ -33,7 +33,7 @@ func (mcc *MachineCmdConf) ReqConfs() *req.Confs { } func (m *MachineCmdConf) MachineCmdConfs(rc *req.Ctx) { - cond := req.BindQuery[*entity.MachineCmdConf](rc) + cond := req.BindQuery[entity.MachineCmdConf](rc) var vos []*vo.MachineCmdConfVO err := m.machineCmdConfApp.ListByCondToAny(cond, &vos) @@ -47,7 +47,7 @@ func (m *MachineCmdConf) MachineCmdConfs(rc *req.Ctx) { } func (m *MachineCmdConf) Save(rc *req.Ctx) { - cmdForm, mcj := req.BindJsonAndCopyTo[*form.MachineCmdConfForm, *entity.MachineCmdConf](rc) + cmdForm, mcj := req.BindJsonAndCopyTo[form.MachineCmdConfForm, entity.MachineCmdConf](rc) rc.ReqParam = cmdForm err := m.machineCmdConfApp.SaveCmdConf(rc.MetaCtx, &dto.SaveMachineCmdConf{ diff --git a/server/internal/machine/api/machine_cronjob.go b/server/internal/machine/api/machine_cronjob.go index 9d8a9e45..35e55c99 100644 --- a/server/internal/machine/api/machine_cronjob.go +++ b/server/internal/machine/api/machine_cronjob.go @@ -43,7 +43,7 @@ func (mcj *MachineCronJob) ReqConfs() *req.Confs { } func (m *MachineCronJob) MachineCronJobs(rc *req.Ctx) { - cond, pageParam := req.BindQueryAndPage[*entity.MachineCronJob](rc) + cond, pageParam := req.BindQueryAndPage[entity.MachineCronJob](rc) pageRes, err := m.machineCronJobApp.GetPageList(cond, pageParam) biz.ErrIsNil(err) @@ -62,7 +62,7 @@ func (m *MachineCronJob) MachineCronJobs(rc *req.Ctx) { } func (m *MachineCronJob) Save(rc *req.Ctx) { - jobForm, mcj := req.BindJsonAndCopyTo[*form.MachineCronJobForm, *entity.MachineCronJob](rc) + jobForm, mcj := req.BindJsonAndCopyTo[form.MachineCronJobForm, entity.MachineCronJob](rc) rc.ReqParam = jobForm err := m.machineCronJobApp.SaveMachineCronJob(rc.MetaCtx, &dto.SaveMachineCronJob{ @@ -89,7 +89,7 @@ func (m *MachineCronJob) RunCronJob(rc *req.Ctx) { } func (m *MachineCronJob) CronJobExecs(rc *req.Ctx) { - cond, pageParam := req.BindQueryAndPage[*entity.MachineCronJobExec](rc) + cond, pageParam := req.BindQueryAndPage[entity.MachineCronJobExec](rc) res, err := m.machineCronJobApp.GetExecPageList(cond, pageParam) biz.ErrIsNil(err) rc.ResData = res diff --git a/server/internal/machine/api/machine_file.go b/server/internal/machine/api/machine_file.go index ce41e3dc..80cee091 100644 --- a/server/internal/machine/api/machine_file.go +++ b/server/internal/machine/api/machine_file.go @@ -91,7 +91,7 @@ func (m *MachineFile) MachineFiles(rc *req.Ctx) { } func (m *MachineFile) SaveMachineFiles(rc *req.Ctx) { - fileForm, entity := req.BindJsonAndCopyTo[*form.MachineFileForm, *entity.MachineFile](rc) + fileForm, entity := req.BindJsonAndCopyTo[form.MachineFileForm, entity.MachineFile](rc) rc.ReqParam = fileForm biz.ErrIsNil(m.machineFileApp.Save(rc.MetaCtx, entity)) @@ -104,7 +104,7 @@ func (m *MachineFile) DeleteFile(rc *req.Ctx) { /*** sftp相关操作 */ func (m *MachineFile) CreateFile(rc *req.Ctx) { - opForm := req.BindJson[*form.CreateFileForm](rc) + opForm := req.BindJson[form.CreateFileForm](rc) path := opForm.Path attrs := collx.Kvs("path", path) @@ -123,7 +123,7 @@ func (m *MachineFile) CreateFile(rc *req.Ctx) { } func (m *MachineFile) ReadFileContent(rc *req.Ctx) { - opForm := req.BindQuery[*dto.MachineFileOp](rc) + opForm := req.BindQuery[dto.MachineFileOp](rc) readPath := opForm.Path ctx := rc.MetaCtx @@ -155,7 +155,7 @@ func (m *MachineFile) ReadFileContent(rc *req.Ctx) { } func (m *MachineFile) DownloadFile(rc *req.Ctx) { - opForm := req.BindQuery[*dto.MachineFileOp](rc) + opForm := req.BindQuery[dto.MachineFileOp](rc) readPath := opForm.Path @@ -183,7 +183,7 @@ func (m *MachineFile) DownloadFile(rc *req.Ctx) { } func (m *MachineFile) GetDirEntry(rc *req.Ctx) { - opForm := req.BindQuery[*dto.MachineFileOp](rc) + opForm := req.BindQuery[dto.MachineFileOp](rc) readPath := opForm.Path rc.ReqParam = fmt.Sprintf("path: %s", readPath) @@ -222,7 +222,7 @@ func (m *MachineFile) GetDirEntry(rc *req.Ctx) { } func (m *MachineFile) GetDirSize(rc *req.Ctx) { - opForm := req.BindQuery[*dto.MachineFileOp](rc) + opForm := req.BindQuery[dto.MachineFileOp](rc) size, err := m.machineFileApp.GetDirSize(rc.MetaCtx, opForm) biz.ErrIsNil(err) @@ -230,14 +230,14 @@ func (m *MachineFile) GetDirSize(rc *req.Ctx) { } func (m *MachineFile) GetFileStat(rc *req.Ctx) { - opForm := req.BindQuery[*dto.MachineFileOp](rc) + opForm := req.BindQuery[dto.MachineFileOp](rc) res, err := m.machineFileApp.FileStat(rc.MetaCtx, opForm) biz.ErrIsNil(err, res) rc.ResData = res } func (m *MachineFile) WriteFileContent(rc *req.Ctx) { - opForm := req.BindJson[*form.WriteFileContentForm](rc) + opForm := req.BindJson[form.WriteFileContentForm](rc) path := opForm.Path mi, err := m.machineFileApp.WriteFileContent(rc.MetaCtx, opForm.MachineFileOp, []byte(opForm.Content)) @@ -421,7 +421,7 @@ func (m *MachineFile) UploadFolder(rc *req.Ctx) { } func (m *MachineFile) RemoveFile(rc *req.Ctx) { - opForm := req.BindJson[*form.RemoveFileForm](rc) + opForm := req.BindJson[form.RemoveFileForm](rc) mi, err := m.machineFileApp.RemoveFile(rc.MetaCtx, opForm.MachineFileOp, opForm.Paths...) rc.ReqParam = collx.Kvs("machine", mi, "path", opForm) @@ -429,21 +429,21 @@ func (m *MachineFile) RemoveFile(rc *req.Ctx) { } func (m *MachineFile) CopyFile(rc *req.Ctx) { - opForm := req.BindJson[*form.CopyFileForm](rc) + opForm := req.BindJson[form.CopyFileForm](rc) mi, err := m.machineFileApp.Copy(rc.MetaCtx, opForm.MachineFileOp, opForm.ToPath, opForm.Paths...) biz.ErrIsNilAppendErr(err, "file copy error: %s") rc.ReqParam = collx.Kvs("machine", mi, "cp", opForm) } func (m *MachineFile) MvFile(rc *req.Ctx) { - opForm := req.BindJson[*form.CopyFileForm](rc) + opForm := req.BindJson[form.CopyFileForm](rc) mi, err := m.machineFileApp.Mv(rc.MetaCtx, opForm.MachineFileOp, opForm.ToPath, opForm.Paths...) rc.ReqParam = collx.Kvs("machine", mi, "mv", opForm) biz.ErrIsNilAppendErr(err, "file move error: %s") } func (m *MachineFile) Rename(rc *req.Ctx) { - renameForm := req.BindJson[*form.RenameForm](rc) + renameForm := req.BindJson[form.RenameForm](rc) mi, err := m.machineFileApp.Rename(rc.MetaCtx, renameForm.MachineFileOp, renameForm.Newname) rc.ReqParam = collx.Kvs("machine", mi, "rename", renameForm) biz.ErrIsNilAppendErr(err, "file rename error: %s") diff --git a/server/internal/machine/api/machine_script.go b/server/internal/machine/api/machine_script.go index a7c57108..1c34b113 100644 --- a/server/internal/machine/api/machine_script.go +++ b/server/internal/machine/api/machine_script.go @@ -54,7 +54,7 @@ func (m *MachineScript) MachineScriptCategorys(rc *req.Ctx) { } func (m *MachineScript) SaveMachineScript(rc *req.Ctx) { - form, machineScript := req.BindJsonAndCopyTo[*form.MachineScriptForm, *entity.MachineScript](rc) + form, machineScript := req.BindJsonAndCopyTo[form.MachineScriptForm, entity.MachineScript](rc) rc.ReqParam = form biz.ErrIsNil(m.machineScriptApp.Save(rc.MetaCtx, machineScript)) diff --git a/server/internal/machine/infra/cache/machine_stats.go b/server/internal/machine/infra/cache/machine_stats.go index be70a623..b5195d02 100644 --- a/server/internal/machine/infra/cache/machine_stats.go +++ b/server/internal/machine/infra/cache/machine_stats.go @@ -20,5 +20,5 @@ func GetMachineStats(machineId uint64) (*mcm.Stats, error) { if cacheStr == "" { return nil, errors.New("不存在该值") } - return jsonx.To[*mcm.Stats](cacheStr) + return jsonx.To[mcm.Stats](cacheStr) } diff --git a/server/internal/mongo/api/mongo.go b/server/internal/mongo/api/mongo.go index ab9907ba..f579b8fd 100644 --- a/server/internal/mongo/api/mongo.go +++ b/server/internal/mongo/api/mongo.go @@ -68,7 +68,7 @@ func (ma *Mongo) ReqConfs() *req.Confs { } func (m *Mongo) Mongos(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.MongoQuery](rc) + queryCond := req.BindQuery[entity.MongoQuery](rc) // 不存在可访问标签id,即没有可操作数据 tags := m.tagTreeApp.GetAccountTags(rc.GetLoginAccount().Id, &tagentity.TagTreeQuery{ @@ -95,12 +95,12 @@ func (m *Mongo) Mongos(rc *req.Ctx) { } func (m *Mongo) TestConn(rc *req.Ctx) { - _, mongo := req.BindJsonAndCopyTo[*form.Mongo, *entity.Mongo](rc) + _, mongo := req.BindJsonAndCopyTo[form.Mongo, entity.Mongo](rc) biz.ErrIsNilAppendErr(m.mongoApp.TestConn(mongo), "connection error: %s") } func (m *Mongo) Save(rc *req.Ctx) { - form, mongo := req.BindJsonAndCopyTo[*form.Mongo, *entity.Mongo](rc) + form, mongo := req.BindJsonAndCopyTo[form.Mongo, entity.Mongo](rc) // 密码脱敏记录日志 form.Uri = func(str string) string { @@ -146,7 +146,7 @@ func (m *Mongo) Collections(rc *req.Ctx) { } func (m *Mongo) RunCommand(rc *req.Ctx) { - commandForm := req.BindJson[*form.MongoRunCommand](rc) + commandForm := req.BindJson[form.MongoRunCommand](rc) conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) @@ -176,7 +176,7 @@ func (m *Mongo) RunCommand(rc *req.Ctx) { } func (m *Mongo) FindCommand(rc *req.Ctx) { - commandForm := req.BindJson[*form.MongoFindCommand](rc) + commandForm := req.BindJson[form.MongoFindCommand](rc) conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) @@ -211,7 +211,7 @@ func (m *Mongo) FindCommand(rc *req.Ctx) { } func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) { - commandForm := req.BindJson[*form.MongoUpdateByIdCommand](rc) + commandForm := req.BindJson[form.MongoUpdateByIdCommand](rc) conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) @@ -235,7 +235,7 @@ func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) { } func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) { - commandForm := req.BindJson[*form.MongoUpdateByIdCommand](rc) + commandForm := req.BindJson[form.MongoUpdateByIdCommand](rc) conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) @@ -258,7 +258,7 @@ func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) { } func (m *Mongo) InsertOneCommand(rc *req.Ctx) { - commandForm := req.BindJson[*form.MongoInsertCommand](rc) + commandForm := req.BindJson[form.MongoInsertCommand](rc) conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) diff --git a/server/internal/msg/api/msg_channel.go b/server/internal/msg/api/msg_channel.go index 605b8a27..dbd244e1 100644 --- a/server/internal/msg/api/msg_channel.go +++ b/server/internal/msg/api/msg_channel.go @@ -36,7 +36,7 @@ func (m *MsgChannel) GetMsgChannels(rc *req.Ctx) { } func (m *MsgChannel) SaveMsgChannels(rc *req.Ctx) { - form, channel := req.BindJsonAndCopyTo[*form.MsgChannel, *entity.MsgChannel](rc) + form, channel := req.BindJsonAndCopyTo[form.MsgChannel, entity.MsgChannel](rc) rc.ReqParam = form err := m.msgChannelApp.SaveChannel(rc.MetaCtx, channel) biz.ErrIsNil(err) diff --git a/server/internal/msg/api/msg_tmpl.go b/server/internal/msg/api/msg_tmpl.go index c6897386..f186b62e 100644 --- a/server/internal/msg/api/msg_tmpl.go +++ b/server/internal/msg/api/msg_tmpl.go @@ -57,7 +57,7 @@ func (m *MsgTmpl) GetMsgTmplChannels(rc *req.Ctx) { } func (m *MsgTmpl) SaveMsgTmpl(rc *req.Ctx) { - form, channel := req.BindJsonAndCopyTo[*form.MsgTmpl, *dto.MsgTmplSave](rc) + form, channel := req.BindJsonAndCopyTo[form.MsgTmpl, dto.MsgTmplSave](rc) rc.ReqParam = form biz.ErrIsNil(m.msgTmplApp.SaveTmpl(rc.MetaCtx, channel)) } @@ -74,7 +74,7 @@ func (m *MsgTmpl) DelMsgTmpls(rc *req.Ctx) { func (m *MsgTmpl) SendMsg(rc *req.Ctx) { code := rc.PathParam("code") - form := req.BindJson[*form.SendMsg](rc) + form := req.BindJson[form.SendMsg](rc) rc.ReqParam = form diff --git a/server/internal/pkg/config/app.go b/server/internal/pkg/config/app.go index 4d9a6d16..e865077e 100644 --- a/server/internal/pkg/config/app.go +++ b/server/internal/pkg/config/app.go @@ -1,5 +1,5 @@ package config const ( - Version = "v1.10.9" + Version = "v1.10.10" ) diff --git a/server/internal/redis/api/cmd.go b/server/internal/redis/api/cmd.go index 2c07ad6b..667536e4 100644 --- a/server/internal/redis/api/cmd.go +++ b/server/internal/redis/api/cmd.go @@ -11,7 +11,7 @@ import ( ) func (r *Redis) RunCmd(rc *req.Ctx) { - cmdReq, runCmdParam := req.BindJsonAndCopyTo[*form.RunCmdForm, *dto.RunCmd](rc) + cmdReq, runCmdParam := req.BindJsonAndCopyTo[form.RunCmdForm, dto.RunCmd](rc) biz.IsTrue(len(cmdReq.Cmd) > 0, "redis cmd cannot be empty") redisConn := r.getRedisConn(rc) diff --git a/server/internal/redis/api/key.go b/server/internal/redis/api/key.go index b77827b5..3c75ce70 100644 --- a/server/internal/redis/api/key.go +++ b/server/internal/redis/api/key.go @@ -17,7 +17,7 @@ import ( func (r *Redis) ScanKeys(rc *req.Ctx) { ri := r.getRedisConn(rc) - form := req.BindJson[*form.RedisScanForm](rc) + form := req.BindJson[form.RedisScanForm](rc) cmd := ri.GetCmdable() ctx := context.Background() diff --git a/server/internal/redis/api/redis.go b/server/internal/redis/api/redis.go index f5d25779..45778eda 100644 --- a/server/internal/redis/api/redis.go +++ b/server/internal/redis/api/redis.go @@ -60,7 +60,7 @@ func (rs *Redis) ReqConfs() *req.Confs { } func (r *Redis) RedisList(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.RedisQuery](rc) + queryCond := req.BindQuery[entity.RedisQuery](rc) // 不存在可访问标签id,即没有可操作数据 tags := r.tagApp.GetAccountTags(rc.GetLoginAccount().Id, &tagentity.TagTreeQuery{ @@ -87,7 +87,7 @@ func (r *Redis) RedisList(rc *req.Ctx) { } func (r *Redis) TestConn(rc *req.Ctx) { - form, redis := req.BindJsonAndCopyTo[*form.Redis, *entity.Redis](rc) + form, redis := req.BindJsonAndCopyTo[form.Redis, entity.Redis](rc) authCert := &tagentity.ResourceAuthCert{ Username: form.Username, @@ -109,7 +109,7 @@ func (r *Redis) TestConn(rc *req.Ctx) { } func (r *Redis) Save(rc *req.Ctx) { - form, redis := req.BindJsonAndCopyTo[*form.Redis, *entity.Redis](rc) + form, redis := req.BindJsonAndCopyTo[form.Redis, entity.Redis](rc) redisParam := &dto.SaveRedis{ Redis: redis, diff --git a/server/internal/redis/application/redis.go b/server/internal/redis/application/redis.go index 66b924e6..f2ef256d 100644 --- a/server/internal/redis/application/redis.go +++ b/server/internal/redis/application/redis.go @@ -252,7 +252,7 @@ func (r *redisAppImpl) FlowBizHandle(ctx context.Context, bizHandleParam *flowap return nil, nil } - runCmdParam, err := jsonx.To[*FlowRedisRunCmdBizForm](procinst.BizForm) + runCmdParam, err := jsonx.To[FlowRedisRunCmdBizForm](procinst.BizForm) if err != nil { return nil, errorx.NewBizf("failed to parse the business form information: %s", err.Error()) } diff --git a/server/internal/sys/api/account.go b/server/internal/sys/api/account.go index 319283f4..ac9afc6d 100644 --- a/server/internal/sys/api/account.go +++ b/server/internal/sys/api/account.go @@ -109,7 +109,7 @@ func (a *Account) GetPermissions(rc *req.Ctx) { func (a *Account) ChangePassword(rc *req.Ctx) { ctx := rc.MetaCtx - form := req.BindJson[*form.AccountChangePasswordForm](rc) + form := req.BindJson[form.AccountChangePasswordForm](rc) originOldPwd, err := utils.DefaultRsaDecrypt(form.OldPassword, true) biz.ErrIsNilAppendErr(err, "Wrong to decrypt old password: %s") @@ -146,7 +146,7 @@ func (a *Account) AccountInfo(rc *req.Ctx) { // 更新个人账号信息 func (a *Account) UpdateAccount(rc *req.Ctx) { - form, updateAccount := req.BindJsonAndCopyTo[*form.AccountUpdateForm, *entity.Account](rc) + form, updateAccount := req.BindJsonAndCopyTo[form.AccountUpdateForm, entity.Account](rc) // 账号id为登录者账号 updateAccount.Id = rc.GetLoginAccount().Id rc.ReqParam = form @@ -212,7 +212,7 @@ func (a *Account) AccountDetail(rc *req.Ctx) { // @router /accounts func (a *Account) SaveAccount(rc *req.Ctx) { - form, account := req.BindJsonAndCopyTo[*form.AccountCreateForm, *entity.Account](rc) + form, account := req.BindJsonAndCopyTo[form.AccountCreateForm, entity.Account](rc) form.Password = "*****" rc.ReqParam = form @@ -308,7 +308,7 @@ func (a *Account) AccountResources(rc *req.Ctx) { // 关联账号角色 func (a *Account) RelateRole(rc *req.Ctx) { - form := req.BindJson[*form.AccountRoleForm](rc) + form := req.BindJson[form.AccountRoleForm](rc) rc.ReqParam = form biz.ErrIsNil(a.roleApp.RelateAccountRole(rc.MetaCtx, form.Id, form.RoleId, consts.AccountRoleRelateType(form.RelateType))) } diff --git a/server/internal/sys/api/config.go b/server/internal/sys/api/config.go index fdc2c4be..65b03270 100644 --- a/server/internal/sys/api/config.go +++ b/server/internal/sys/api/config.go @@ -55,7 +55,7 @@ func (c *Config) GetConfigValueByKey(rc *req.Ctx) { } func (c *Config) SaveConfig(rc *req.Ctx) { - form, config := req.BindJsonAndCopyTo[*form.ConfigForm, *entity.Config](rc) + form, config := req.BindJsonAndCopyTo[form.ConfigForm, entity.Config](rc) rc.ReqParam = form biz.ErrIsNil(c.configApp.Save(rc.MetaCtx, config)) } diff --git a/server/internal/sys/api/resource.go b/server/internal/sys/api/resource.go index bc10a327..3f3c94db 100644 --- a/server/internal/sys/api/resource.go +++ b/server/internal/sys/api/resource.go @@ -50,7 +50,7 @@ func (r *Resource) GetById(rc *req.Ctx) { } func (r *Resource) SaveResource(rc *req.Ctx) { - form, entity := req.BindJsonAndCopyTo[*form.ResourceForm, *entity.Resource](rc) + form, entity := req.BindJsonAndCopyTo[form.ResourceForm, entity.Resource](rc) rc.ReqParam = form diff --git a/server/internal/sys/api/role.go b/server/internal/sys/api/role.go index 98882e69..5810e668 100644 --- a/server/internal/sys/api/role.go +++ b/server/internal/sys/api/role.go @@ -39,7 +39,7 @@ func (r *Role) ReqConfs() *req.Confs { } func (r *Role) Roles(rc *req.Ctx) { - cond := req.BindQuery[*entity.RoleQuery](rc) + cond := req.BindQuery[entity.RoleQuery](rc) notIdsStr := rc.Query("notIds") if notIdsStr != "" { @@ -61,7 +61,7 @@ func (r *Role) Roles(rc *req.Ctx) { // 保存角色信息 func (r *Role) SaveRole(rc *req.Ctx) { - form, role := req.BindJsonAndCopyTo[*form.RoleForm, *entity.Role](rc) + form, role := req.BindJsonAndCopyTo[form.RoleForm, entity.Role](rc) rc.ReqParam = form r.roleApp.SaveRole(rc.MetaCtx, role) @@ -92,7 +92,7 @@ func (r *Role) RoleResource(rc *req.Ctx) { // 保存角色资源 func (r *Role) SaveResource(rc *req.Ctx) { - form := req.BindJson[*form.RoleResourceForm](rc) + form := req.BindJson[form.RoleResourceForm](rc) rc.ReqParam = form // 将,拼接的字符串进行切割并转换 @@ -105,7 +105,7 @@ func (r *Role) SaveResource(rc *req.Ctx) { // 查看角色关联的用户 func (r *Role) RoleAccount(rc *req.Ctx) { - cond := req.BindQuery[*entity.RoleAccountQuery](rc) + cond := req.BindQuery[entity.RoleAccountQuery](rc) cond.RoleId = uint64(rc.PathParamInt("id")) res, err := r.roleApp.GetRoleAccountPage(cond) biz.ErrIsNil(err) diff --git a/server/internal/sys/api/syslog.go b/server/internal/sys/api/syslog.go index e2058204..da1cebc6 100644 --- a/server/internal/sys/api/syslog.go +++ b/server/internal/sys/api/syslog.go @@ -21,7 +21,7 @@ func (s *Syslog) ReqConfs() *req.Confs { } func (r *Syslog) Syslogs(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.SysLogQuery](rc) + queryCond := req.BindQuery[entity.SysLogQuery](rc) res, err := r.syslogApp.GetPageList(queryCond, "create_time DESC") biz.ErrIsNil(err) rc.ResData = res diff --git a/server/internal/tag/api/resource_auth_cert.go b/server/internal/tag/api/resource_auth_cert.go index 1ac22740..e411c04e 100644 --- a/server/internal/tag/api/resource_auth_cert.go +++ b/server/internal/tag/api/resource_auth_cert.go @@ -72,7 +72,7 @@ func (r *ResourceAuthCert) GetCompleteAuthCert(rc *req.Ctx) { } func (c *ResourceAuthCert) SaveAuthCert(rc *req.Ctx) { - acForm, ac := req.BindJsonAndCopyTo[*form.AuthCertForm, *entity.ResourceAuthCert](rc) + acForm, ac := req.BindJsonAndCopyTo[form.AuthCertForm, entity.ResourceAuthCert](rc) // 脱敏记录日志 acForm.Ciphertext = "***" diff --git a/server/internal/tag/api/tag_tree.go b/server/internal/tag/api/tag_tree.go index 27992d81..ebe1d386 100644 --- a/server/internal/tag/api/tag_tree.go +++ b/server/internal/tag/api/tag_tree.go @@ -119,7 +119,7 @@ func (p *TagTree) ListByQuery(rc *req.Ctx) { } func (p *TagTree) SaveTagTree(rc *req.Ctx) { - tagForm, tagTree := req.BindJsonAndCopyTo[*form.TagTree, *entity.TagTree](rc) + tagForm, tagTree := req.BindJsonAndCopyTo[form.TagTree, entity.TagTree](rc) rc.ReqParam = fmt.Sprintf("tagTreeId: %d, tagName: %s, code: %s", tagTree.Id, tagTree.Name, tagTree.Code) @@ -131,7 +131,7 @@ func (p *TagTree) DelTagTree(rc *req.Ctx) { } func (p *TagTree) MovingTag(rc *req.Ctx) { - movingForm := req.BindJson[*form.MovingTag](rc) + movingForm := req.BindJson[form.MovingTag](rc) rc.ReqParam = movingForm biz.ErrIsNil(p.tagTreeApp.MovingTag(rc.MetaCtx, movingForm.FromPath, movingForm.ToPath)) } diff --git a/server/internal/tag/api/team.go b/server/internal/tag/api/team.go index 0aad5dc1..f3f56b20 100644 --- a/server/internal/tag/api/team.go +++ b/server/internal/tag/api/team.go @@ -46,7 +46,7 @@ func (t *Team) ReqConfs() *req.Confs { } func (p *Team) GetTeams(rc *req.Ctx) { - queryCond := req.BindQuery[*entity.TeamQuery](rc) + queryCond := req.BindQuery[entity.TeamQuery](rc) res, err := p.teamApp.GetPageList(queryCond) biz.ErrIsNil(err) @@ -60,7 +60,7 @@ func (p *Team) GetTeams(rc *req.Ctx) { } func (p *Team) SaveTeam(rc *req.Ctx) { - team := req.BindJson[*dto.SaveTeam](rc) + team := req.BindJson[dto.SaveTeam](rc) rc.ReqParam = team biz.ErrIsNil(p.teamApp.SaveTeam(rc.MetaCtx, team)) } @@ -87,7 +87,7 @@ func (p *Team) GetTeamMembers(rc *req.Ctx) { // 保存团队信息 func (p *Team) SaveTeamMember(rc *req.Ctx) { - teamMems := req.BindJson[*form.TeamMember](rc) + teamMems := req.BindJson[form.TeamMember](rc) teamId := teamMems.TeamId diff --git a/server/internal/tag/application/resouce_auth_cert.go b/server/internal/tag/application/resouce_auth_cert.go index 9862fa5f..c0ddf142 100644 --- a/server/internal/tag/application/resouce_auth_cert.go +++ b/server/internal/tag/application/resouce_auth_cert.go @@ -35,13 +35,13 @@ type ResourceAuthCert interface { GetResourceAuthCert(resourceType entity.TagType, resourceCode string) (*entity.ResourceAuthCert, error) // FillAuthCertByAcs 根据授权凭证列表填充资源的授权凭证信息 - // @param authCerts 授权凭证列表 - // @param resources 实现了entity.IAuthCert接口的资源信息 + // - authCerts 授权凭证列表 + // - resources 实现了entity.IAuthCert接口的资源信息 FillAuthCertByAcs(authCerts []*entity.ResourceAuthCert, resources ...entity.IAuthCert) // FillAuthCert 填充资源对应的授权凭证信息 - // @param resourceType 资源类型 - // @param resources 实现了entity.IAuthCert接口的资源信息 + // - resourceType 资源类型 + // - resources 实现了entity.IAuthCert接口的资源信息 FillAuthCert(resourceType int8, resources ...entity.IAuthCert) // FillAuthCertByAcNames 根据授权凭证名称填充资源对应的凭证信息 @@ -151,11 +151,15 @@ func (r *resourceAuthCertAppImpl) RelateAuthCert(ctx context.Context, params *dt unmodifyAc := name2AuthCert[unmodifyAcName] oldAuthCert := oldName2AuthCert[unmodifyAcName] - if !unmodifyAc.HasChanged(oldAuthCert) { - logx.DebugfContext(ctx, "RelateAuthCert[%d-%s] - Authorization credential [%s] No field changes", resourceType, resourceCode, unmodifyAcName) - continue + // 密文和密码前端不回显,故如果没修改,需要重新赋值,避免被修改为空 + if unmodifyAc.Ciphertext == "" { + unmodifyAc.Ciphertext = oldAuthCert.Ciphertext } - + passphrase := unmodifyAc.GetExtraString(entity.ExtraKeyPassphrase) + if passphrase == "" { + unmodifyAc.SetExtraValue(entity.ExtraKeyPassphrase, oldAuthCert.GetExtraString(entity.ExtraKeyPassphrase)) + } + // 如果修改了用户名,且该凭证关联至标签,则需要更新对应的标签名(资源授权凭证类型的标签名为username) if oldAuthCert.Username != unmodifyAc.Username { r.tagTreeApp.UpdateTagName(ctx, entity.TagTypeAuthCert, unmodifyAcName, unmodifyAc.Username) @@ -376,10 +380,6 @@ func (r *resourceAuthCertAppImpl) updateAuthCert(ctx context.Context, rac *entit return errorx.NewBiz("ac not found") } - if !oldRac.HasChanged(rac) { - return nil - } - if oldRac.Type == entity.AuthCertTypePublic { // 如果旧凭证为公共凭证,则不允许修改凭证类型 if rac.Type != entity.AuthCertTypePublic { diff --git a/server/internal/tag/domain/entity/resource_auth_cert.go b/server/internal/tag/domain/entity/resource_auth_cert.go index 0c761801..f7848be5 100644 --- a/server/internal/tag/domain/entity/resource_auth_cert.go +++ b/server/internal/tag/domain/entity/resource_auth_cert.go @@ -82,19 +82,6 @@ func (m *ResourceAuthCert) CiphertextClear() { m.SetExtraValue(ExtraKeyPassphrase, "") } -// HasChanged 与指定授权凭证比较是否有变更 -func (m *ResourceAuthCert) HasChanged(rac *ResourceAuthCert) bool { - if rac == nil { - return true - } - return m.Username != rac.Username || - (m.Ciphertext != rac.Ciphertext) || - m.CiphertextType != rac.CiphertextType || - m.Remark != rac.Remark || - m.Type != rac.Type || - m.GetExtraString(ExtraKeyPassphrase) != rac.GetExtraString(ExtraKeyPassphrase) -} - // 密文类型 type AuthCertCiphertextType int8 diff --git a/server/pkg/req/util.go b/server/pkg/req/util.go index 7d9b4025..ae452e1c 100644 --- a/server/pkg/req/util.go +++ b/server/pkg/req/util.go @@ -10,38 +10,38 @@ import ( ) // BindJson 绑定并校验请求结构体参数 -func BindJson[T any](rc *Ctx) T { - data := structx.NewInstance[T]() - if err := rc.BindJSON(data); err != nil { +func BindJson[T any](rc *Ctx) *T { + var data T + if err := rc.BindJSON(&data); err != nil { panic(ConvBindValidationError(data, err)) } else { - return data + return &data } } // BindJsonAndCopyTo 绑定请求体中的json至form结构体,并拷贝至指定结构体 -func BindJsonAndCopyTo[F, T any](rc *Ctx) (F, T) { +func BindJsonAndCopyTo[F, T any](rc *Ctx) (*F, *T) { f := BindJson[F](rc) return f, structx.CopyTo[T](f) } // BindQuery 绑定查询字符串到指定结构体 -func BindQuery[T any](rc *Ctx) T { - data := structx.NewInstance[T]() - if err := rc.BindQuery(data); err != nil { +func BindQuery[T any](rc *Ctx) *T { + var data T + if err := rc.BindQuery(&data); err != nil { panic(ConvBindValidationError(data, err)) } else { - return data + return &data } } // BindQueryAndPage 绑定查询字符串到指定结构体,并将分页信息也返回 -func BindQueryAndPage[T any](rc *Ctx) (T, model.PageParam) { - data := structx.NewInstance[T]() - if err := rc.BindQuery(data); err != nil { +func BindQueryAndPage[T any](rc *Ctx) (*T, model.PageParam) { + var data T + if err := rc.BindQuery(&data); err != nil { panic(ConvBindValidationError(data, err)) } else { - return data, rc.GetPageParam() + return &data, rc.GetPageParam() } } diff --git a/server/pkg/utils/jsonx/jsonx.go b/server/pkg/utils/jsonx/jsonx.go index c080f354..fdd55a8e 100644 --- a/server/pkg/utils/jsonx/jsonx.go +++ b/server/pkg/utils/jsonx/jsonx.go @@ -4,12 +4,11 @@ import ( "encoding/json" "mayfly-go/pkg/logx" "mayfly-go/pkg/utils/collx" - "mayfly-go/pkg/utils/structx" "github.com/tidwall/gjson" ) -// json字符串转map +// ToMap json字符串转map func ToMap(jsonStr string) (collx.M, error) { if jsonStr == "" { return map[string]any{}, nil @@ -17,13 +16,16 @@ func ToMap(jsonStr string) (collx.M, error) { return ToMapByBytes([]byte(jsonStr)) } -// json字符串转结构体, T需为指针类型 -func To[T any](jsonStr string) (T, error) { - res := structx.NewInstance[T]() - return res, json.Unmarshal([]byte(jsonStr), res) +// To json字符串转指定类型 +func To[T any](jsonStr string) (*T, error) { + var v T + if err := json.Unmarshal([]byte(jsonStr), &v); err != nil { + return nil, err + } + return &v, nil } -// json字节数组转map +// ToMapByBytes json字节数组转map func ToMapByBytes(bytes []byte) (collx.M, error) { var res map[string]any err := json.Unmarshal(bytes, &res) @@ -42,42 +44,42 @@ func ToStr(val any) string { // 根据json字节数组获取对应字段路径的string类型值 // -// @param fieldPath字段路径。如user.username等 +// - fieldPath字段路径。如user.username等 func GetStringByBytes(bytes []byte, fieldPath string) (string, error) { return gjson.GetBytes(bytes, fieldPath).String(), nil } // 根据json字符串获取对应字段路径的string类型值 // -// @param fieldPath字段路径。如user.username等 +// - fieldPath字段路径。如user.username等 func GetString(jsonStr string, fieldPath string) (string, error) { return gjson.Get(jsonStr, fieldPath).String(), nil } // 根据json字节数组获取对应字段路径的int类型值 // -// @param fieldPath字段路径。如user.age等 +// - fieldPath字段路径。如user.age等 func GetIntByBytes(bytes []byte, fieldPath string) (int64, error) { return gjson.GetBytes(bytes, fieldPath).Int(), nil } // 根据json字符串获取对应字段路径的int类型值 // -// @param fieldPath字段路径。如user.age等 +// - fieldPath字段路径。如user.age等 func GetInt(jsonStr string, fieldPath string) (int64, error) { return gjson.Get(jsonStr, fieldPath).Int(), nil } // 根据json字节数组获取对应字段路径的bool类型值 // -// @param fieldPath字段路径。如user.isDeleted等 +// - fieldPath字段路径。如user.isDeleted等 func GetBoolByBytes(bytes []byte, fieldPath string) (bool, error) { return gjson.GetBytes(bytes, fieldPath).Bool(), nil } // 根据json字符串获取对应字段路径的bool类型值 // -// @param fieldPath字段路径。如user.isDeleted等 +// - fieldPath字段路径。如user.isDeleted等 func GetBool(jsonStr string, fieldPath string) (bool, error) { return GetBoolByBytes([]byte(jsonStr), fieldPath) } diff --git a/server/pkg/utils/structx/reflect.go b/server/pkg/utils/structx/reflect.go index 791a86c2..acf5ea51 100644 --- a/server/pkg/utils/structx/reflect.go +++ b/server/pkg/utils/structx/reflect.go @@ -2,25 +2,6 @@ package structx import "reflect" -// NewInstance 创建泛型 T 的实例。如果 T 是指针类型,则创建其指向类型的实例并返回指针。 -func NewInstance[T any]() T { - var t T - - // 反射判断是否是指针类型,并且是否为 nil - if reflect.ValueOf(t).Kind() == reflect.Ptr { - // 创建 T 对应的非指针类型的实例,并取其地址作为新的 T - t = reflect.New(reflect.TypeOf(t).Elem()).Interface().(T) - } else if kind := reflect.TypeOf(t).Kind(); kind == reflect.Array || kind == reflect.Slice { - // 如果是数组或切片类型,创建一个新的切片(数组) - elemType := reflect.TypeOf(t).Elem() - newSlice := reflect.MakeSlice(reflect.SliceOf(elemType), 0, 0) - t = newSlice.Interface().(T) - } - - return t -} - - // IsZeroValue 检查字段是否为零值 func IsZeroValue(v reflect.Value) bool { switch v.Kind() { @@ -39,4 +20,4 @@ func IsZeroValue(v reflect.Value) bool { default: return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) } -} \ No newline at end of file +} diff --git a/server/pkg/utils/structx/structx.go b/server/pkg/utils/structx/structx.go index 081eb376..5e370a95 100644 --- a/server/pkg/utils/structx/structx.go +++ b/server/pkg/utils/structx/structx.go @@ -11,10 +11,10 @@ import ( ) // CopyTo 将fromValue转为T类型并返回 -func CopyTo[T any](fromValue any) T { - t := NewInstance[T]() - Copy(t, fromValue) - return t +func CopyTo[T any](fromValue any) *T { + var t T + Copy(&t, fromValue) + return &t } // CopySliceTo 将fromValue转为[]T类型并返回