feat: 实现 LDAP 登录

This commit is contained in:
kanzihuang
2023-08-23 22:09:41 +08:00
committed by Wanli
parent 2e969d46fb
commit 4e1350d1cc
13 changed files with 321 additions and 5174 deletions

View File

@@ -43,6 +43,7 @@ type Config struct {
Mysql *Mysql `yaml:"mysql"`
Redis *Redis `yaml:"redis"`
Log *Log `yaml:"log"`
Ldap *Ldap `yaml:"ldap"`
}
// 配置文件内容校验

45
server/pkg/config/ldap.go Normal file
View File

@@ -0,0 +1,45 @@
package config
// FieldMapping 表示用户属性和 LDAP 字段名之间的映射关系
type FieldMapping struct {
// Identifier 表示用户标识
Identifier string `yaml:"identifier,omitempty"`
// DisplayName 表示用户姓名
DisplayName string `yaml:"displayName,omitempty"`
// Email 表示 Email 地址
Email string `yaml:"email,omitempty"`
}
// SecurityProtocol 表示连接 LDAP 服务器的安全协议
type SecurityProtocol string
const (
// SecurityProtocolStartTLS 表示 StartTLS 安全协议
SecurityProtocolStartTLS SecurityProtocol = "starttls"
// SecurityProtocolLDAPS 表示 LDAPS 安全协议
SecurityProtocolLDAPS SecurityProtocol = "ldaps"
)
// Ldap 是 LDAP 服务配置
type Ldap struct {
// Enabled 表示是否启用 LDAP 登录
Enabled bool `yaml:"enabled"`
// Host 是 LDAP 服务地址, 如: "ldap.example.com"
Host string `yaml:"host"`
// Port 是 LDAP 服务端口号, 如: 389
Port int `yaml:"port"`
// SkipTLSVerify 控制客户端是否跳过 TLS 证书验证
SkipTLSVerify bool `yaml:"skipTlsVerify"`
// BindDN 是 LDAP 服务的管理员账号,如: "cn=admin,dc=example,dc=com"
BindDN string `yaml:"bindDn"`
// BindPassword 是 LDAP 服务的管理员密码
BindPassword string `yaml:"bindPassword"`
// BaseDN 是用户所在的 base DN, 如: "ou=users,dc=example,dc=com".
BaseDN string `yaml:"baseDn"`
// UserFilter 是过滤用户的方式, 如: "(uid=%s)".
UserFilter string `yaml:"userFilter"`
// SecurityProtocol 是连接使用的 LDAP 安全协议(为空不使用安全协议),如: StartTLS, LDAPS
SecurityProtocol SecurityProtocol `yaml:"securityProtocol"`
// FieldMapping 表示用户属性和 LDAP 字段名之间的映射关系
FieldMapping FieldMapping `yaml:"fieldMapping"`
}

107
server/pkg/ldap/ldap.go Normal file
View File

@@ -0,0 +1,107 @@
package ldap
import (
"crypto/tls"
"fmt"
"github.com/go-ldap/ldap/v3"
"github.com/pkg/errors"
"mayfly-go/pkg/config"
"strings"
)
type UserInfo struct {
UserName string
DisplayName string
Email string
}
func dial() (*ldap.Conn, error) {
conf := config.Conf.Ldap
addr := fmt.Sprintf("%s:%d", conf.Host, conf.Port)
tlsConfig := &tls.Config{
ServerName: conf.Host,
InsecureSkipVerify: conf.SkipTLSVerify,
}
if conf.SecurityProtocol == config.SecurityProtocolLDAPS {
conn, err := ldap.DialTLS("tcp", addr, tlsConfig)
if err != nil {
return nil, errors.Errorf("dial TLS: %v", err)
}
return conn, nil
}
conn, err := ldap.Dial("tcp", addr)
if err != nil {
return nil, errors.Errorf("dial: %v", err)
}
if conf.SecurityProtocol == config.SecurityProtocolStartTLS {
if err = conn.StartTLS(tlsConfig); err != nil {
_ = conn.Close()
return nil, errors.Errorf("start TLS: %v", err)
}
}
return conn, nil
}
// Connect 创建 LDAP 连接
func Connect() (*ldap.Conn, error) {
conn, err := dial()
if err != nil {
return nil, err
}
// Bind with a system account
conf := config.Conf.Ldap
err = conn.Bind(conf.BindDN, conf.BindPassword)
if err != nil {
_ = conn.Close()
return nil, errors.Errorf("bind: %v", err)
}
return conn, nil
}
// Authenticate 通过 LDAP 验证用户名密码
func Authenticate(username, password string) (*UserInfo, error) {
conn, err := Connect()
if err != nil {
return nil, errors.Errorf("connect: %v", err)
}
defer func() { _ = conn.Close() }()
conf := config.Conf.Ldap
sr, err := conn.Search(
ldap.NewSearchRequest(
conf.BaseDN,
ldap.ScopeWholeSubtree,
ldap.NeverDerefAliases,
0,
0,
false,
strings.ReplaceAll(conf.UserFilter, "%s", username),
[]string{"dn", conf.FieldMapping.Identifier, conf.FieldMapping.DisplayName, conf.FieldMapping.Email},
nil,
),
)
if err != nil {
return nil, errors.Errorf("search user DN: %v", err)
} else if len(sr.Entries) != 1 {
return nil, errors.Errorf("expect 1 user DN but got %d", len(sr.Entries))
}
entry := sr.Entries[0]
// Bind as the user to verify their password
err = conn.Bind(entry.DN, password)
if err != nil {
return nil, errors.Errorf("bind user: %v", err)
}
userName := entry.GetAttributeValue(conf.FieldMapping.Identifier)
if userName == "" {
return nil, errors.Errorf("the attribute %q is not found or has empty value", conf.FieldMapping.Identifier)
}
return &UserInfo{
UserName: userName,
DisplayName: entry.GetAttributeValue(conf.FieldMapping.DisplayName),
Email: entry.GetAttributeValue(conf.FieldMapping.Email),
}, nil
}