diff --git a/internal/db/models/origin_dao.go b/internal/db/models/origin_dao.go index a2a73066..8771406b 100644 --- a/internal/db/models/origin_dao.go +++ b/internal/db/models/origin_dao.go @@ -86,12 +86,45 @@ func (this *OriginDAO) FindOriginName(tx *dbs.Tx, id int64) (string, error) { } // 创建源站 -func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool) (originId int64, err error) { +func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32) (originId int64, err error) { op := NewOriginOperator() op.AdminId = adminId op.UserId = userId op.IsOn = isOn op.Name = name + + if connTimeout != nil { + connTimeoutJSON, err := connTimeout.AsJSON() + if err != nil { + return 0, err + } + op.ConnTimeout = connTimeoutJSON + } + if readTimeout != nil { + readTimeoutJSON, err := readTimeout.AsJSON() + if err != nil { + return 0, err + } + op.ReadTimeout = readTimeoutJSON + } + if idleTimeout != nil { + idleTimeoutJSON, err := idleTimeout.AsJSON() + if err != nil { + return 0, err + } + op.IdleTimeout = idleTimeoutJSON + } + if maxConns >= 0 { + op.MaxConns = maxConns + } else { + op.MaxConns = 0 + } + if maxIdleConns >= 0 { + op.MaxIdleConns = maxIdleConns + } else { + op.MaxIdleConns = 0 + } + op.Addr = addrJSON op.Description = description if weight < 0 { @@ -107,7 +140,7 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, nam } // 修改源站 -func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, addrJSON string, description string, weight int32, isOn bool) error { +func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32) error { if originId <= 0 { return errors.New("invalid originId") } @@ -120,6 +153,39 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, add weight = 0 } op.Weight = weight + + if connTimeout != nil { + connTimeoutJSON, err := connTimeout.AsJSON() + if err != nil { + return err + } + op.ConnTimeout = connTimeoutJSON + } + if readTimeout != nil { + readTimeoutJSON, err := readTimeout.AsJSON() + if err != nil { + return err + } + op.ReadTimeout = readTimeoutJSON + } + if idleTimeout != nil { + idleTimeoutJSON, err := idleTimeout.AsJSON() + if err != nil { + return err + } + op.IdleTimeout = idleTimeoutJSON + } + if maxConns >= 0 { + op.MaxConns = maxConns + } else { + op.MaxConns = 0 + } + if maxIdleConns >= 0 { + op.MaxIdleConns = maxIdleConns + } else { + op.MaxIdleConns = 0 + } + op.IsOn = isOn op.Version = dbs.SQL("version+1") err := this.Save(tx, op) diff --git a/internal/rpc/services/service_origin.go b/internal/rpc/services/service_origin.go index 0012ce8c..114b77e7 100644 --- a/internal/rpc/services/service_origin.go +++ b/internal/rpc/services/service_origin.go @@ -6,6 +6,7 @@ import ( "errors" "github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/iwind/TeaGo/maps" ) @@ -32,7 +33,32 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi tx := this.NullTx() - originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn) + // 校验参数 + var connTimeout = &shared.TimeDuration{} + if len(req.ConnTimeoutJSON) > 0 { + err = json.Unmarshal(req.ConnTimeoutJSON, connTimeout) + if err != nil { + return nil, err + } + } + + var readTimeout = &shared.TimeDuration{} + if len(req.ReadTimeoutJSON) > 0 { + err = json.Unmarshal(req.ReadTimeoutJSON, readTimeout) + if err != nil { + return nil, err + } + } + + var idleTimeout = &shared.TimeDuration{} + if len(req.IdleTimeoutJSON) > 0 { + err = json.Unmarshal(req.IdleTimeoutJSON, idleTimeout) + if err != nil { + return nil, err + } + } + + originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns) if err != nil { return nil, err } @@ -61,7 +87,32 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi tx := this.NullTx() - err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn) + // 校验参数 + var connTimeout = &shared.TimeDuration{} + if len(req.ConnTimeoutJSON) > 0 { + err = json.Unmarshal(req.ConnTimeoutJSON, connTimeout) + if err != nil { + return nil, err + } + } + + var readTimeout = &shared.TimeDuration{} + if len(req.ReadTimeoutJSON) > 0 { + err = json.Unmarshal(req.ReadTimeoutJSON, readTimeout) + if err != nil { + return nil, err + } + } + + var idleTimeout = &shared.TimeDuration{} + if len(req.IdleTimeoutJSON) > 0 { + err = json.Unmarshal(req.IdleTimeoutJSON, idleTimeout) + if err != nil { + return nil, err + } + } + + err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns) if err != nil { return nil, err }