diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index 76b05391..1a8d8d3c 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -9,13 +9,17 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/db/models/clients" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/utils" + "github.com/TeaOSLab/EdgeAPI/internal/utils/domainutils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" timeutil "github.com/iwind/TeaGo/utils/time" + "net" + "net/url" "regexp" "strings" ) @@ -184,6 +188,261 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe return &pb.CreateServerResponse{ServerId: serverId}, nil } +// CreateBasicHTTPServer 快速创建基本的HTTP网站 +func (this *ServerService) CreateBasicHTTPServer(ctx context.Context, req *pb.CreateBasicHTTPServerRequest) (*pb.CreateBasicHTTPServerResponse, error) { + adminId, userId, err := this.ValidateAdminAndUser(ctx, true) + if err != nil { + return nil, err + } + + // 集群 + var tx = this.NullTx() + if userId > 0 { + req.UserId = userId + + nodeClusterId, err := models.SharedUserDAO.FindUserClusterId(tx, userId) + if err != nil { + return nil, err + } + req.NodeClusterId = nodeClusterId + } else if adminId > 0 && req.UserId > 0 && req.NodeClusterId <= 0 { + // check user + existUser, err := models.SharedUserDAO.Exist(tx, req.UserId) + if err != nil { + return nil, err + } + if !existUser { + return nil, errors.New("user id '" + types.String(req.UserId) + "' not found") + } + + nodeClusterId, err := models.SharedUserDAO.FindUserClusterId(tx, userId) + if err != nil { + return nil, err + } + req.NodeClusterId = nodeClusterId + } + + if req.NodeClusterId <= 0 { + return nil, errors.New("invalid 'nodeClusterId'") + } + + if len(req.Domains) == 0 { + return nil, errors.New("'domains' should not be empty") + } + var serverNames = []*serverconfigs.ServerNameConfig{} + for _, domain := range req.Domains { + if !domainutils.ValidateDomainFormat(domain) { + return nil, errors.New("invalid domain format '" + domain + "'") + } + serverNames = append(serverNames, &serverconfigs.ServerNameConfig{Name: strings.ToLower(domain)}) + } + serverNamesJSON, err := json.Marshal(serverNames) + if err != nil { + return nil, errors.New("encode 'serverNames' failed: " + err.Error()) + } + + // 是否需要审核 + var isAuditing = false + var auditingServerNamesJSON = []byte("[]") + if userId > 0 { + // 如果域名不为空的时候需要审核 + if len(serverNamesJSON) > 0 && string(serverNamesJSON) != "[]" { + globalConfig, err := models.SharedSysSettingDAO.ReadGlobalConfig(tx) + if err != nil { + return nil, err + } + if globalConfig != nil && globalConfig.HTTPAll.DomainAuditingIsOn { + isAuditing = true + serverNamesJSON = []byte("[]") + auditingServerNamesJSON = serverNamesJSON + } + } + } + + // HTTP + var httpConfig = &serverconfigs.HTTPProtocolConfig{ + BaseProtocol: serverconfigs.BaseProtocol{ + IsOn: true, + Listen: []*serverconfigs.NetworkAddressConfig{ + { + Protocol: "http", + PortRange: "80", + }, + }, + }, + } + httpJSON, err := json.Marshal(httpConfig) + if err != nil { + return nil, err + } + + // HTTPS + var certRefs = []*sslconfigs.SSLCertRef{} + for _, certId := range req.SslCertIds { + // 检查所有权 + if userId > 0 { + err = models.SharedSSLCertDAO.CheckUserCert(tx, certId, userId) + if err != nil { + return nil, errors.New("check cert permission failed: " + err.Error()) + } + } else { + existCert, err := models.SharedSSLCertDAO.Exist(tx, certId) + if err != nil { + return nil, err + } + if !existCert { + return nil, errors.New("cert '" + types.String(certId) + "' not found") + } + } + + certRefs = append(certRefs, &sslconfigs.SSLCertRef{ + IsOn: true, + CertId: certId, + }) + } + certRefsJSON, err := json.Marshal(certRefs) + if err != nil { + return nil, err + } + sslPolicyId, err := models.SharedSSLPolicyDAO.CreatePolicy(tx, adminId, req.UserId, false, false, "TLS 1.0", certRefsJSON, nil, false, 0, nil, false, nil) + if err != nil { + return nil, err + } + + var httpsConfig = &serverconfigs.HTTPSProtocolConfig{ + BaseProtocol: serverconfigs.BaseProtocol{ + IsOn: true, + Listen: []*serverconfigs.NetworkAddressConfig{ + { + Protocol: "https", + PortRange: "443", + }, + }, + }, + SSLPolicyRef: &sslconfigs.SSLPolicyRef{ + IsOn: true, + SSLPolicyId: sslPolicyId, + }, + } + httpsJSON, err := json.Marshal(httpsConfig) + if err != nil { + return nil, err + } + + // Reverse Proxy + var reverseProxyScheduleConfig = &serverconfigs.SchedulingConfig{ + Code: "random", + Options: nil, + } + reverseProxyScheduleJSON, err := json.Marshal(reverseProxyScheduleConfig) + + var primaryOrigins = []*serverconfigs.OriginRef{} + for _, originAddr := range req.OriginAddrs { + u, err := url.Parse(originAddr) + if err != nil { + return nil, errors.New("parse origin address '" + originAddr + "' failed: " + err.Error()) + } + if len(u.Scheme) == 0 || (u.Scheme != "http" && u.Scheme != "https" /** 特意不支持大写形式 **/) { + return nil, errors.New("invalid scheme in origin address '" + originAddr + "'") + } + + if len(u.Host) == 0 { + return nil, errors.New("invalid host address '" + originAddr + "', contains no host") + } + + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + err = nil // ignore error + + if domainutils.ValidateDomainFormat(u.Host) { // host with no port + host = u.Host + port = "" + } else { + return nil, errors.New("invalid host address '" + originAddr + "', invalid host format") + } + } + + if len(port) == 0 { + switch u.Scheme { + case "http": + port = "80" + case "https": + port = "443" + } + } + + var addr = &serverconfigs.NetworkAddressConfig{ + Protocol: serverconfigs.Protocol(u.Scheme), + Host: host, + PortRange: port, + } + addrJSON, err := json.Marshal(addr) + if err != nil { + return nil, err + } + + originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, req.UserId, "", addrJSON, nil, "", 10, true, nil, nil, nil, 0, 0, nil, nil, u.Host, false) + if err != nil { + return nil, err + } + primaryOrigins = append(primaryOrigins, &serverconfigs.OriginRef{ + IsOn: true, + OriginId: originId, + }) + } + primaryOriginsJSON, err := json.Marshal(primaryOrigins) + if err != nil { + return nil, err + } + + reverseProxyId, err := models.SharedReverseProxyDAO.CreateReverseProxy(tx, adminId, req.UserId, reverseProxyScheduleJSON, primaryOriginsJSON, nil) + if err != nil { + return nil, err + } + reverseProxyJSON, err := json.Marshal(&serverconfigs.ReverseProxyRef{ + IsPrior: false, + IsOn: true, + ReverseProxyId: reverseProxyId, + }) + if err != nil { + return nil, err + } + + // Web + webId, err := models.SharedHTTPWebDAO.CreateWeb(tx, adminId, req.UserId, nil) + if err != nil { + return nil, err + } + + // Enable websocket + if req.EnableWebsocket { + websocketId, err := models.SharedHTTPWebsocketDAO.CreateWebsocket(tx, nil, true, nil, true, "") + if err != nil { + return nil, err + } + websocketRef, err := json.Marshal(&serverconfigs.HTTPWebsocketRef{ + IsPrior: false, + IsOn: true, + WebsocketId: websocketId, + }) + if err != nil { + return nil, err + } + err = models.SharedHTTPWebDAO.UpdateWebsocket(tx, webId, websocketRef) + if err != nil { + return nil, err + } + } + + // finally, we create ... + serverId, err := models.SharedServerDAO.CreateServer(tx, adminId, req.UserId, serverconfigs.ServerTypeHTTPProxy, req.Domains[0], "", serverNamesJSON, isAuditing, auditingServerNamesJSON, httpJSON, httpsJSON, nil, nil, nil, nil, webId, reverseProxyJSON, req.NodeClusterId, nil, nil, nil, 0) + if err != nil { + return nil, err + } + + return &pb.CreateBasicHTTPServerResponse{ServerId: serverId}, nil +} + // UpdateServerBasic 修改服务基本信息 func (this *ServerService) UpdateServerBasic(ctx context.Context, req *pb.UpdateServerBasicRequest) (*pb.RPCSuccess, error) { // 校验请求 diff --git a/internal/utils/domainutils/utils.go b/internal/utils/domainutils/utils.go new file mode 100644 index 00000000..65f44c2c --- /dev/null +++ b/internal/utils/domainutils/utils.go @@ -0,0 +1,31 @@ +// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package domainutils + +import ( + "regexp" + "strings" +) + +// ValidateDomainFormat 校验域名格式 +func ValidateDomainFormat(domain string) bool { + pieces := strings.Split(domain, ".") + for _, piece := range pieces { + if piece == "-" || + strings.HasPrefix(piece, "-") || + strings.HasSuffix(piece, "-") || + //strings.Contains(piece, "--") || + len(piece) > 63 || + // 支持中文、大写字母、下划线 + !regexp.MustCompile(`^[\p{Han}_a-zA-Z0-9-]+$`).MatchString(piece) { + return false + } + } + + // 最后一段不能是全数字 + if regexp.MustCompile(`^(\d+)$`).MatchString(pieces[len(pieces)-1]) { + return false + } + + return true +} diff --git a/internal/utils/regexputils/expr.go b/internal/utils/regexputils/expr.go index 5a22a5a2..63f9a7fa 100644 --- a/internal/utils/regexputils/expr.go +++ b/internal/utils/regexputils/expr.go @@ -9,3 +9,7 @@ var ( YYYYMMDD = regexp.MustCompile(`^\d{8}$`) YYYYMM = regexp.MustCompile(`^\d{6}$`) ) + +var ( + HTTPProtocol = regexp.MustCompile("^(?i)(http|https)://") +)