修复升级数据库时主键可能冲突的问题

This commit is contained in:
GoEdgeLab
2022-08-15 00:02:38 +08:00
parent 82b7eee2b7
commit 24c22ea624
2 changed files with 23 additions and 12 deletions

File diff suppressed because one or more lines are too long

View File

@@ -74,7 +74,7 @@ func (this *SQLDump) Dump(db *dbs.DB) (result *SQLDumpResult, err error) {
} }
// 字段 // 字段
fields := []*SQLField{} var fields = []*SQLField{}
for _, field := range table.Fields { for _, field := range table.Fields {
fields = append(fields, &SQLField{ fields = append(fields, &SQLField{
Name: field.Name, Name: field.Name,
@@ -84,7 +84,7 @@ func (this *SQLDump) Dump(db *dbs.DB) (result *SQLDumpResult, err error) {
sqlTable.Fields = fields sqlTable.Fields = fields
// 索引 // 索引
indexes := []*SQLIndex{} var indexes = []*SQLIndex{}
for _, index := range table.Indexes { for _, index := range table.Indexes {
indexes = append(indexes, &SQLIndex{ indexes = append(indexes, &SQLIndex{
Name: index.Name, Name: index.Name,
@@ -94,7 +94,7 @@ func (this *SQLDump) Dump(db *dbs.DB) (result *SQLDumpResult, err error) {
sqlTable.Indexes = indexes sqlTable.Indexes = indexes
// Records // Records
records := []*SQLRecord{} var records = []*SQLRecord{}
recordsTable := this.findRecordsTable(tableName) recordsTable := this.findRecordsTable(tableName)
if recordsTable != nil { if recordsTable != nil {
ones, _, err := db.FindOnes("SELECT * FROM " + tableName + " ORDER BY id ASC") ones, _, err := db.FindOnes("SELECT * FROM " + tableName + " ORDER BY id ASC")
@@ -151,7 +151,7 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult, showLog bool) (
// 对比字段 // 对比字段
// + // +
for _, newField := range newTable.Fields { for _, newField := range newTable.Fields {
oldField := oldTable.FindField(newField.Name) var oldField = oldTable.FindField(newField.Name)
if oldField == nil { if oldField == nil {
var op = "+ " + newTable.Name + " " + newField.Name var op = "+ " + newTable.Name + " " + newField.Name
ops = append(ops, op) ops = append(ops, op)
@@ -178,7 +178,7 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult, showLog bool) (
// 对比索引 // 对比索引
// + // +
for _, newIndex := range newTable.Indexes { for _, newIndex := range newTable.Indexes {
oldIndex := oldTable.FindIndex(newIndex.Name) var oldIndex = oldTable.FindIndex(newIndex.Name)
if oldIndex == nil { if oldIndex == nil {
var op = "+ index " + newTable.Name + " " + newIndex.Name var op = "+ index " + newTable.Name + " " + newIndex.Name
ops = append(ops, op) ops = append(ops, op)
@@ -214,7 +214,7 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult, showLog bool) (
// - // -
for _, oldIndex := range oldTable.Indexes { for _, oldIndex := range oldTable.Indexes {
newIndex := newTable.FindIndex(oldIndex.Name) var newIndex = newTable.FindIndex(oldIndex.Name)
if newIndex == nil { if newIndex == nil {
var op = "- index " + oldTable.Name + " " + oldIndex.Name var op = "- index " + oldTable.Name + " " + oldIndex.Name
ops = append(ops, op) ops = append(ops, op)
@@ -231,7 +231,7 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult, showLog bool) (
// 对比字段 // 对比字段
// - // -
for _, oldField := range oldTable.Fields { for _, oldField := range oldTable.Fields {
newField := newTable.FindField(oldField.Name) var newField = newTable.FindField(oldField.Name)
if newField == nil { if newField == nil {
var op = "- field " + oldTable.Name + " " + oldField.Name var op = "- field " + oldTable.Name + " " + oldField.Name
ops = append(ops, op) ops = append(ops, op)
@@ -257,7 +257,17 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult, showLog bool) (
queryValues = append(queryValues, record.Values[field]) queryValues = append(queryValues, record.Values[field])
valueStrings = append(valueStrings, record.Values[field]) valueStrings = append(valueStrings, record.Values[field])
} }
one, err := db.FindOne("SELECT * FROM "+newTable.Name+" WHERE "+strings.Join(queryArgs, " AND "), queryValues...)
var recordId int64
for field, recordValue := range record.Values {
if field == "id" {
recordId = types.Int64(recordValue)
break
}
}
queryValues = append(queryValues, recordId)
one, err := db.FindOne("SELECT * FROM "+newTable.Name+" WHERE (("+strings.Join(queryArgs, " AND ")+") OR id=?)", queryValues...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -266,9 +276,9 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult, showLog bool) (
if showLog { if showLog {
fmt.Println("+ record " + newTable.Name + " " + strings.Join(valueStrings, ", ")) fmt.Println("+ record " + newTable.Name + " " + strings.Join(valueStrings, ", "))
} }
params := []string{} var params = []string{}
args := []string{} var args = []string{}
values := []interface{}{} var values = []any{}
for k, v := range record.Values { for k, v := range record.Values {
// 需要排除的字段 // 需要排除的字段
if lists.ContainsString(record.ExceptFields, k) { if lists.ContainsString(record.ExceptFields, k) {
@@ -280,6 +290,7 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult, showLog bool) (
args = append(args, "?") args = append(args, "?")
values = append(values, v) values = append(values, v)
} }
_, err = db.Exec("INSERT INTO "+newTable.Name+" ("+strings.Join(params, ", ")+") VALUES ("+strings.Join(args, ", ")+")", values...) _, err = db.Exec("INSERT INTO "+newTable.Name+" ("+strings.Join(params, ", ")+") VALUES ("+strings.Join(args, ", ")+")", values...)
if err != nil { if err != nil {
return nil, err return nil, err