diff --git a/cmd/sql-dump/main.go b/cmd/sql-dump/main.go index c4742cf2..ed31d6fc 100644 --- a/cmd/sql-dump/main.go +++ b/cmd/sql-dump/main.go @@ -18,7 +18,7 @@ func main() { fmt.Println("[ERROR]" + err.Error()) return } - results, err := setup.NewSQLDump().Dump(db) + results, err := setup.NewSQLDump().Dump(db, true) if err != nil { fmt.Println("[ERROR]" + err.Error()) return diff --git a/internal/setup/sql_dump.go b/internal/setup/sql_dump.go index 9166d2bc..72b0759c 100644 --- a/internal/setup/sql_dump.go +++ b/internal/setup/sql_dump.go @@ -8,6 +8,7 @@ import ( "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" "regexp" + "runtime" "strings" "sync" ) @@ -53,14 +54,20 @@ func NewSQLDump() *SQLDump { } // Dump 导出数据 -func (this *SQLDump) Dump(db *dbs.DB) (result *SQLDumpResult, err error) { +func (this *SQLDump) Dump(db *dbs.DB, includingRecords bool) (result *SQLDumpResult, err error) { result = &SQLDumpResult{} tableNames, err := db.TableNames() if err != nil { return result, err } - for _, tableName := range tableNames { + + fullTableMap, err := this.findFullTables(db, tableNames) + if err != nil { + return nil, err + } + + for tableName, table := range fullTableMap { // 忽略一些分表 if strings.HasPrefix(strings.ToLower(tableName), strings.ToLower("edgeHTTPAccessLogs_")) { continue @@ -69,10 +76,6 @@ func (this *SQLDump) Dump(db *dbs.DB) (result *SQLDumpResult, err error) { continue } - table, err := db.FindFullTable(tableName) - if err != nil { - return nil, err - } sqlTable := &SQLTable{ Name: table.Name, Engine: table.Engine, @@ -102,28 +105,30 @@ func (this *SQLDump) Dump(db *dbs.DB) (result *SQLDumpResult, err error) { // Records var records = []*SQLRecord{} - recordsTable := this.findRecordsTable(tableName) - if recordsTable != nil { - ones, _, err := db.FindOnes("SELECT * FROM " + tableName + " ORDER BY id ASC") - if err != nil { - return result, err - } - for _, one := range ones { - record := &SQLRecord{ - Id: one.GetInt64("id"), - Values: map[string]string{}, - UniqueFields: recordsTable.UniqueFields, - ExceptFields: recordsTable.ExceptFields, + if includingRecords { + recordsTable := this.findRecordsTable(tableName) + if recordsTable != nil { + ones, _, err := db.FindOnes("SELECT * FROM " + tableName + " ORDER BY id ASC") + if err != nil { + return result, err } - for k, v := range one { - // 需要排除的字段 - if lists.ContainsString(record.ExceptFields, k) { - continue + for _, one := range ones { + record := &SQLRecord{ + Id: one.GetInt64("id"), + Values: map[string]string{}, + UniqueFields: recordsTable.UniqueFields, + ExceptFields: recordsTable.ExceptFields, } + for k, v := range one { + // 需要排除的字段 + if lists.ContainsString(record.ExceptFields, k) { + continue + } - record.Values[k] = types.String(v) + record.Values[k] = types.String(v) + } + records = append(records, record) } - records = append(records, record) } } sqlTable.Records = records @@ -218,7 +223,7 @@ func (this *SQLDump) applyQueue(db *dbs.DB, newResult *SQLDumpResult, showLog bo } } - currentResult, err := this.Dump(db) + currentResult, err := this.Dump(db, false) if err != nil { return nil, err } @@ -343,7 +348,7 @@ func (this *SQLDump) applyQueue(db *dbs.DB, newResult *SQLDumpResult, showLog bo // + for _, record := range newTable.Records { var queryArgs = []string{} - var queryValues = []interface{}{} + var queryValues = []any{} var valueStrings = []string{} for _, field := range record.UniqueFields { queryArgs = append(queryArgs, field+"=?") @@ -390,8 +395,8 @@ func (this *SQLDump) applyQueue(db *dbs.DB, newResult *SQLDumpResult, showLog bo if showLog { fmt.Println("* record " + newTable.Name + " " + strings.Join(valueStrings, ", ")) } - args := []string{} - values := []interface{}{} + var args = []string{} + var values = []any{} for k, v := range record.Values { if k == "id" { continue @@ -418,6 +423,59 @@ func (this *SQLDump) applyQueue(db *dbs.DB, newResult *SQLDumpResult, showLog bo return } +// 查找所有表的完整信息 +func (this *SQLDump) findFullTables(db *dbs.DB, tableNames []string) (map[string]*dbs.Table, error) { + var fullTableMap = map[string]*dbs.Table{} + if len(tableNames) == 0 { + return fullTableMap, nil + } + + var locker = &sync.Mutex{} + var queue = make(chan string, len(tableNames)) + for _, tableName := range tableNames { + queue <- tableName + } + + var wg = &sync.WaitGroup{} + var concurrent = 8 + + if runtime.NumCPU() > 4 { + concurrent = 32 + } + + wg.Add(concurrent) + var lastErr error + for i := 0; i < concurrent; i++ { + go func() { + defer wg.Done() + + for { + select { + case tableName := <-queue: + table, err := db.FindFullTable(tableName) + if err != nil { + locker.Lock() + lastErr = err + locker.Unlock() + return + } + locker.Lock() + fullTableMap[tableName] = table + locker.Unlock() + default: + return + } + } + }() + } + wg.Wait() + if lastErr != nil { + return nil, lastErr + } + + return fullTableMap, nil +} + // 查找有记录的表 func (this *SQLDump) findRecordsTable(tableName string) *SQLRecordsTable { for _, table := range recordsTables { diff --git a/internal/setup/sql_dump_test.go b/internal/setup/sql_dump_test.go index 4e146e08..e28faee7 100644 --- a/internal/setup/sql_dump_test.go +++ b/internal/setup/sql_dump_test.go @@ -21,7 +21,7 @@ func TestSQLDump_Dump(t *testing.T) { }() dump := NewSQLDump() - result, err := dump.Dump(db) + result, err := dump.Dump(db, true) if err != nil { t.Fatal(err) } @@ -64,7 +64,7 @@ func TestSQLDump_Apply(t *testing.T) { }() var dump = NewSQLDump() - result, err := dump.Dump(db) + result, err := dump.Dump(db, true) if err != nil { t.Fatal(err) } diff --git a/internal/setup/sql_executor.go b/internal/setup/sql_executor.go index 71facc0b..f7508619 100644 --- a/internal/setup/sql_executor.go +++ b/internal/setup/sql_executor.go @@ -55,7 +55,7 @@ func (this *SQLExecutor) Run(showLog bool) error { _ = db.Close() }() - sqlDump := NewSQLDump() + var sqlDump = NewSQLDump() _, err = sqlDump.Apply(db, LatestSQLResult, showLog) if err != nil { return err