mirror of
				https://gitee.com/gitea/gitea
				synced 2025-11-04 08:30:25 +08:00 
			
		
		
		
	Refactor auth package (#17962)
This commit is contained in:
		
							
								
								
									
										22
									
								
								models/auth/main_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								models/auth/main_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
			
		||||
// Copyright 2020 The Gitea Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a MIT-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"code.gitea.io/gitea/models/unittest"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestMain(m *testing.M) {
 | 
			
		||||
	unittest.MainTest(m, filepath.Join("..", ".."),
 | 
			
		||||
		"login_source.yml",
 | 
			
		||||
		"oauth2_application.yml",
 | 
			
		||||
		"oauth2_authorization_code.yml",
 | 
			
		||||
		"oauth2_grant.yml",
 | 
			
		||||
		"u2f_registration.yml",
 | 
			
		||||
	)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										564
									
								
								models/auth/oauth2.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										564
									
								
								models/auth/oauth2.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,564 @@
 | 
			
		||||
// Copyright 2019 The Gitea Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a MIT-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"code.gitea.io/gitea/models/db"
 | 
			
		||||
	"code.gitea.io/gitea/modules/secret"
 | 
			
		||||
	"code.gitea.io/gitea/modules/timeutil"
 | 
			
		||||
	"code.gitea.io/gitea/modules/util"
 | 
			
		||||
 | 
			
		||||
	uuid "github.com/google/uuid"
 | 
			
		||||
	"golang.org/x/crypto/bcrypt"
 | 
			
		||||
	"xorm.io/xorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// OAuth2Application represents an OAuth2 client (RFC 6749)
 | 
			
		||||
type OAuth2Application struct {
 | 
			
		||||
	ID           int64 `xorm:"pk autoincr"`
 | 
			
		||||
	UID          int64 `xorm:"INDEX"`
 | 
			
		||||
	Name         string
 | 
			
		||||
	ClientID     string `xorm:"unique"`
 | 
			
		||||
	ClientSecret string
 | 
			
		||||
	RedirectURIs []string           `xorm:"redirect_uris JSON TEXT"`
 | 
			
		||||
	CreatedUnix  timeutil.TimeStamp `xorm:"INDEX created"`
 | 
			
		||||
	UpdatedUnix  timeutil.TimeStamp `xorm:"INDEX updated"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	db.RegisterModel(new(OAuth2Application))
 | 
			
		||||
	db.RegisterModel(new(OAuth2AuthorizationCode))
 | 
			
		||||
	db.RegisterModel(new(OAuth2Grant))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TableName sets the table name to `oauth2_application`
 | 
			
		||||
func (app *OAuth2Application) TableName() string {
 | 
			
		||||
	return "oauth2_application"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PrimaryRedirectURI returns the first redirect uri or an empty string if empty
 | 
			
		||||
func (app *OAuth2Application) PrimaryRedirectURI() string {
 | 
			
		||||
	if len(app.RedirectURIs) == 0 {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	return app.RedirectURIs[0]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ContainsRedirectURI checks if redirectURI is allowed for app
 | 
			
		||||
func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool {
 | 
			
		||||
	return util.IsStringInSlice(redirectURI, app.RedirectURIs, true)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenerateClientSecret will generate the client secret and returns the plaintext and saves the hash at the database
 | 
			
		||||
func (app *OAuth2Application) GenerateClientSecret() (string, error) {
 | 
			
		||||
	clientSecret, err := secret.New()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	app.ClientSecret = string(hashedSecret)
 | 
			
		||||
	if _, err := db.GetEngine(db.DefaultContext).ID(app.ID).Cols("client_secret").Update(app); err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return clientSecret, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ValidateClientSecret validates the given secret by the hash saved in database
 | 
			
		||||
func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
 | 
			
		||||
	return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
 | 
			
		||||
func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
 | 
			
		||||
	return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
 | 
			
		||||
	grant = new(OAuth2Grant)
 | 
			
		||||
	if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !has {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	return grant, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateGrant generates a grant for an user
 | 
			
		||||
func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
 | 
			
		||||
	return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
 | 
			
		||||
	grant := &OAuth2Grant{
 | 
			
		||||
		ApplicationID: app.ID,
 | 
			
		||||
		UserID:        userID,
 | 
			
		||||
		Scope:         scope,
 | 
			
		||||
	}
 | 
			
		||||
	_, err := e.Insert(grant)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return grant, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
 | 
			
		||||
func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
 | 
			
		||||
	return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
 | 
			
		||||
	app = new(OAuth2Application)
 | 
			
		||||
	has, err := e.Where("client_id = ?", clientID).Get(app)
 | 
			
		||||
	if !has {
 | 
			
		||||
		return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
 | 
			
		||||
func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
 | 
			
		||||
	return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
 | 
			
		||||
	app = new(OAuth2Application)
 | 
			
		||||
	has, err := e.ID(id).Get(app)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if !has {
 | 
			
		||||
		return nil, ErrOAuthApplicationNotFound{ID: id}
 | 
			
		||||
	}
 | 
			
		||||
	return app, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
 | 
			
		||||
func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
 | 
			
		||||
	return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
 | 
			
		||||
	apps = make([]*OAuth2Application, 0)
 | 
			
		||||
	err = e.Where("uid = ?", userID).Find(&apps)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateOAuth2ApplicationOptions holds options to create an oauth2 application
 | 
			
		||||
type CreateOAuth2ApplicationOptions struct {
 | 
			
		||||
	Name         string
 | 
			
		||||
	UserID       int64
 | 
			
		||||
	RedirectURIs []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateOAuth2Application inserts a new oauth2 application
 | 
			
		||||
func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
 | 
			
		||||
	return createOAuth2Application(db.GetEngine(db.DefaultContext), opts)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
 | 
			
		||||
	clientID := uuid.New().String()
 | 
			
		||||
	app := &OAuth2Application{
 | 
			
		||||
		UID:          opts.UserID,
 | 
			
		||||
		Name:         opts.Name,
 | 
			
		||||
		ClientID:     clientID,
 | 
			
		||||
		RedirectURIs: opts.RedirectURIs,
 | 
			
		||||
	}
 | 
			
		||||
	if _, err := e.Insert(app); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return app, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application
 | 
			
		||||
type UpdateOAuth2ApplicationOptions struct {
 | 
			
		||||
	ID           int64
 | 
			
		||||
	Name         string
 | 
			
		||||
	UserID       int64
 | 
			
		||||
	RedirectURIs []string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateOAuth2Application updates an oauth2 application
 | 
			
		||||
func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
 | 
			
		||||
	ctx, committer, err := db.TxContext()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer committer.Close()
 | 
			
		||||
	sess := db.GetEngine(ctx)
 | 
			
		||||
 | 
			
		||||
	app, err := getOAuth2ApplicationByID(sess, opts.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if app.UID != opts.UserID {
 | 
			
		||||
		return nil, fmt.Errorf("UID mismatch")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	app.Name = opts.Name
 | 
			
		||||
	app.RedirectURIs = opts.RedirectURIs
 | 
			
		||||
 | 
			
		||||
	if err = updateOAuth2Application(sess, app); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	app.ClientSecret = ""
 | 
			
		||||
 | 
			
		||||
	return app, committer.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
 | 
			
		||||
	if _, err := e.ID(app.ID).Update(app); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func deleteOAuth2Application(sess db.Engine, id, userid int64) error {
 | 
			
		||||
	if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if deleted == 0 {
 | 
			
		||||
		return ErrOAuthApplicationNotFound{ID: id}
 | 
			
		||||
	}
 | 
			
		||||
	codes := make([]*OAuth2AuthorizationCode, 0)
 | 
			
		||||
	// delete correlating auth codes
 | 
			
		||||
	if err := sess.Join("INNER", "oauth2_grant",
 | 
			
		||||
		"oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	codeIDs := make([]int64, 0)
 | 
			
		||||
	for _, grant := range codes {
 | 
			
		||||
		codeIDs = append(codeIDs, grant.ID)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
 | 
			
		||||
func DeleteOAuth2Application(id, userid int64) error {
 | 
			
		||||
	ctx, committer, err := db.TxContext()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	defer committer.Close()
 | 
			
		||||
	if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return committer.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListOAuth2Applications returns a list of oauth2 applications belongs to given user.
 | 
			
		||||
func ListOAuth2Applications(uid int64, listOptions db.ListOptions) ([]*OAuth2Application, int64, error) {
 | 
			
		||||
	sess := db.GetEngine(db.DefaultContext).
 | 
			
		||||
		Where("uid=?", uid).
 | 
			
		||||
		Desc("id")
 | 
			
		||||
 | 
			
		||||
	if listOptions.Page != 0 {
 | 
			
		||||
		sess = db.SetSessionPagination(sess, &listOptions)
 | 
			
		||||
 | 
			
		||||
		apps := make([]*OAuth2Application, 0, listOptions.PageSize)
 | 
			
		||||
		total, err := sess.FindAndCount(&apps)
 | 
			
		||||
		return apps, total, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apps := make([]*OAuth2Application, 0, 5)
 | 
			
		||||
	total, err := sess.FindAndCount(&apps)
 | 
			
		||||
	return apps, total, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
// OAuth2AuthorizationCode is a code to obtain an access token in combination with the client secret once. It has a limited lifetime.
 | 
			
		||||
type OAuth2AuthorizationCode struct {
 | 
			
		||||
	ID                  int64        `xorm:"pk autoincr"`
 | 
			
		||||
	Grant               *OAuth2Grant `xorm:"-"`
 | 
			
		||||
	GrantID             int64
 | 
			
		||||
	Code                string `xorm:"INDEX unique"`
 | 
			
		||||
	CodeChallenge       string
 | 
			
		||||
	CodeChallengeMethod string
 | 
			
		||||
	RedirectURI         string
 | 
			
		||||
	ValidUntil          timeutil.TimeStamp `xorm:"index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TableName sets the table name to `oauth2_authorization_code`
 | 
			
		||||
func (code *OAuth2AuthorizationCode) TableName() string {
 | 
			
		||||
	return "oauth2_authorization_code"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty.
 | 
			
		||||
func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect *url.URL, err error) {
 | 
			
		||||
	if redirect, err = url.Parse(code.RedirectURI); err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	q := redirect.Query()
 | 
			
		||||
	if state != "" {
 | 
			
		||||
		q.Set("state", state)
 | 
			
		||||
	}
 | 
			
		||||
	q.Set("code", code.Code)
 | 
			
		||||
	redirect.RawQuery = q.Encode()
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Invalidate deletes the auth code from the database to invalidate this code
 | 
			
		||||
func (code *OAuth2AuthorizationCode) Invalidate() error {
 | 
			
		||||
	return code.invalidate(db.GetEngine(db.DefaultContext))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
 | 
			
		||||
	_, err := e.Delete(code)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
 | 
			
		||||
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
 | 
			
		||||
	return code.validateCodeChallenge(verifier)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool {
 | 
			
		||||
	switch code.CodeChallengeMethod {
 | 
			
		||||
	case "S256":
 | 
			
		||||
		// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
 | 
			
		||||
		h := sha256.Sum256([]byte(verifier))
 | 
			
		||||
		hashedVerifier := base64.RawURLEncoding.EncodeToString(h[:])
 | 
			
		||||
		return hashedVerifier == code.CodeChallenge
 | 
			
		||||
	case "plain":
 | 
			
		||||
		return verifier == code.CodeChallenge
 | 
			
		||||
	case "":
 | 
			
		||||
		return true
 | 
			
		||||
	default:
 | 
			
		||||
		// unsupported method -> return false
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetOAuth2AuthorizationByCode returns an authorization by its code
 | 
			
		||||
func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
 | 
			
		||||
	return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
 | 
			
		||||
	auth = new(OAuth2AuthorizationCode)
 | 
			
		||||
	if has, err := e.Where("code = ?", code).Get(auth); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !has {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	auth.Grant = new(OAuth2Grant)
 | 
			
		||||
	if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !has {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	return auth, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//////////////////////////////////////////////////////
 | 
			
		||||
 | 
			
		||||
// OAuth2Grant represents the permission of an user for a specific application to access resources
 | 
			
		||||
type OAuth2Grant struct {
 | 
			
		||||
	ID            int64              `xorm:"pk autoincr"`
 | 
			
		||||
	UserID        int64              `xorm:"INDEX unique(user_application)"`
 | 
			
		||||
	Application   *OAuth2Application `xorm:"-"`
 | 
			
		||||
	ApplicationID int64              `xorm:"INDEX unique(user_application)"`
 | 
			
		||||
	Counter       int64              `xorm:"NOT NULL DEFAULT 1"`
 | 
			
		||||
	Scope         string             `xorm:"TEXT"`
 | 
			
		||||
	Nonce         string             `xorm:"TEXT"`
 | 
			
		||||
	CreatedUnix   timeutil.TimeStamp `xorm:"created"`
 | 
			
		||||
	UpdatedUnix   timeutil.TimeStamp `xorm:"updated"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TableName sets the table name to `oauth2_grant`
 | 
			
		||||
func (grant *OAuth2Grant) TableName() string {
 | 
			
		||||
	return "oauth2_grant"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
 | 
			
		||||
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
 | 
			
		||||
	return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
 | 
			
		||||
	var codeSecret string
 | 
			
		||||
	if codeSecret, err = secret.New(); err != nil {
 | 
			
		||||
		return &OAuth2AuthorizationCode{}, err
 | 
			
		||||
	}
 | 
			
		||||
	code = &OAuth2AuthorizationCode{
 | 
			
		||||
		Grant:               grant,
 | 
			
		||||
		GrantID:             grant.ID,
 | 
			
		||||
		RedirectURI:         redirectURI,
 | 
			
		||||
		Code:                codeSecret,
 | 
			
		||||
		CodeChallenge:       codeChallenge,
 | 
			
		||||
		CodeChallengeMethod: codeChallengeMethod,
 | 
			
		||||
	}
 | 
			
		||||
	if _, err := e.Insert(code); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return code, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IncreaseCounter increases the counter and updates the grant
 | 
			
		||||
func (grant *OAuth2Grant) IncreaseCounter() error {
 | 
			
		||||
	return grant.increaseCount(db.GetEngine(db.DefaultContext))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
 | 
			
		||||
	_, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	updatedGrant, err := getOAuth2GrantByID(e, grant.ID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	grant.Counter = updatedGrant.Counter
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ScopeContains returns true if the grant scope contains the specified scope
 | 
			
		||||
func (grant *OAuth2Grant) ScopeContains(scope string) bool {
 | 
			
		||||
	for _, currentScope := range strings.Split(grant.Scope, " ") {
 | 
			
		||||
		if scope == currentScope {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetNonce updates the current nonce value of a grant
 | 
			
		||||
func (grant *OAuth2Grant) SetNonce(nonce string) error {
 | 
			
		||||
	return grant.setNonce(db.GetEngine(db.DefaultContext), nonce)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
 | 
			
		||||
	grant.Nonce = nonce
 | 
			
		||||
	_, err := e.ID(grant.ID).Cols("nonce").Update(grant)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetOAuth2GrantByID returns the grant with the given ID
 | 
			
		||||
func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
 | 
			
		||||
	return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
 | 
			
		||||
	grant = new(OAuth2Grant)
 | 
			
		||||
	if has, err := e.ID(id).Get(grant); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !has {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetOAuth2GrantsByUserID lists all grants of a certain user
 | 
			
		||||
func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
 | 
			
		||||
	return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
 | 
			
		||||
	type joinedOAuth2Grant struct {
 | 
			
		||||
		Grant       *OAuth2Grant       `xorm:"extends"`
 | 
			
		||||
		Application *OAuth2Application `xorm:"extends"`
 | 
			
		||||
	}
 | 
			
		||||
	var results *xorm.Rows
 | 
			
		||||
	var err error
 | 
			
		||||
	if results, err = e.
 | 
			
		||||
		Table("oauth2_grant").
 | 
			
		||||
		Where("user_id = ?", uid).
 | 
			
		||||
		Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
 | 
			
		||||
		Rows(new(joinedOAuth2Grant)); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer results.Close()
 | 
			
		||||
	grants := make([]*OAuth2Grant, 0)
 | 
			
		||||
	for results.Next() {
 | 
			
		||||
		joinedGrant := new(joinedOAuth2Grant)
 | 
			
		||||
		if err := results.Scan(joinedGrant); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		joinedGrant.Grant.Application = joinedGrant.Application
 | 
			
		||||
		grants = append(grants, joinedGrant.Grant)
 | 
			
		||||
	}
 | 
			
		||||
	return grants, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RevokeOAuth2Grant deletes the grant with grantID and userID
 | 
			
		||||
func RevokeOAuth2Grant(grantID, userID int64) error {
 | 
			
		||||
	return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
 | 
			
		||||
	_, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ErrOAuthClientIDInvalid will be thrown if client id cannot be found
 | 
			
		||||
type ErrOAuthClientIDInvalid struct {
 | 
			
		||||
	ClientID string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsErrOauthClientIDInvalid checks if an error is a ErrReviewNotExist.
 | 
			
		||||
func IsErrOauthClientIDInvalid(err error) bool {
 | 
			
		||||
	_, ok := err.(ErrOAuthClientIDInvalid)
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Error returns the error message
 | 
			
		||||
func (err ErrOAuthClientIDInvalid) Error() string {
 | 
			
		||||
	return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ErrOAuthApplicationNotFound will be thrown if id cannot be found
 | 
			
		||||
type ErrOAuthApplicationNotFound struct {
 | 
			
		||||
	ID int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsErrOAuthApplicationNotFound checks if an error is a ErrReviewNotExist.
 | 
			
		||||
func IsErrOAuthApplicationNotFound(err error) bool {
 | 
			
		||||
	_, ok := err.(ErrOAuthApplicationNotFound)
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Error returns the error message
 | 
			
		||||
func (err ErrOAuthApplicationNotFound) Error() string {
 | 
			
		||||
	return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources
 | 
			
		||||
func GetActiveOAuth2ProviderSources() ([]*Source, error) {
 | 
			
		||||
	sources := make([]*Source, 0, 1)
 | 
			
		||||
	if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return sources, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name
 | 
			
		||||
func GetActiveOAuth2SourceByName(name string) (*Source, error) {
 | 
			
		||||
	authSource := new(Source)
 | 
			
		||||
	has, err := db.GetEngine(db.DefaultContext).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
 | 
			
		||||
	if !has || err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return authSource, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										233
									
								
								models/auth/oauth2_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										233
									
								
								models/auth/oauth2_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,233 @@
 | 
			
		||||
// Copyright 2019 The Gitea Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a MIT-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"code.gitea.io/gitea/models/unittest"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
//////////////////// Application
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Application_GenerateClientSecret(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
 | 
			
		||||
	secret, err := app.GenerateClientSecret()
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.True(t, len(secret) > 0)
 | 
			
		||||
	unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) {
 | 
			
		||||
	assert.NoError(b, unittest.PrepareTestDatabase())
 | 
			
		||||
	app := unittest.AssertExistsAndLoadBean(b, &OAuth2Application{ID: 1}).(*OAuth2Application)
 | 
			
		||||
	for i := 0; i < b.N; i++ {
 | 
			
		||||
		_, _ = app.GenerateClientSecret()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Application_ContainsRedirectURI(t *testing.T) {
 | 
			
		||||
	app := &OAuth2Application{
 | 
			
		||||
		RedirectURIs: []string{"a", "b", "c"},
 | 
			
		||||
	}
 | 
			
		||||
	assert.True(t, app.ContainsRedirectURI("a"))
 | 
			
		||||
	assert.True(t, app.ContainsRedirectURI("b"))
 | 
			
		||||
	assert.True(t, app.ContainsRedirectURI("c"))
 | 
			
		||||
	assert.False(t, app.ContainsRedirectURI("d"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
 | 
			
		||||
	secret, err := app.GenerateClientSecret()
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.True(t, app.ValidateClientSecret([]byte(secret)))
 | 
			
		||||
	assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew")))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	app, err := GetOAuth2ApplicationByClientID("da7da3ba-9a13-4167-856f-3899de0b0138")
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
 | 
			
		||||
 | 
			
		||||
	app, err = GetOAuth2ApplicationByClientID("invalid client id")
 | 
			
		||||
	assert.Error(t, err)
 | 
			
		||||
	assert.Nil(t, app)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateOAuth2Application(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	app, err := CreateOAuth2Application(CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Equal(t, "newapp", app.Name)
 | 
			
		||||
	assert.Len(t, app.ClientID, 36)
 | 
			
		||||
	unittest.AssertExistsAndLoadBean(t, &OAuth2Application{Name: "newapp"})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Application_TableName(t *testing.T) {
 | 
			
		||||
	assert.Equal(t, "oauth2_application", new(OAuth2Application).TableName())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
 | 
			
		||||
	grant, err := app.GetGrantByUserID(1)
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Equal(t, int64(1), grant.UserID)
 | 
			
		||||
 | 
			
		||||
	grant, err = app.GetGrantByUserID(34923458)
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Nil(t, grant)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Application_CreateGrant(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
 | 
			
		||||
	grant, err := app.CreateGrant(2, "")
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.NotNil(t, grant)
 | 
			
		||||
	assert.Equal(t, int64(2), grant.UserID)
 | 
			
		||||
	assert.Equal(t, int64(1), grant.ApplicationID)
 | 
			
		||||
	assert.Equal(t, "", grant.Scope)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//////////////////// Grant
 | 
			
		||||
 | 
			
		||||
func TestGetOAuth2GrantByID(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	grant, err := GetOAuth2GrantByID(1)
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Equal(t, int64(1), grant.ID)
 | 
			
		||||
 | 
			
		||||
	grant, err = GetOAuth2GrantByID(34923458)
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Nil(t, grant)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant)
 | 
			
		||||
	assert.NoError(t, grant.IncreaseCounter())
 | 
			
		||||
	assert.Equal(t, int64(2), grant.Counter)
 | 
			
		||||
	unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Grant_ScopeContains(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Scope: "openid profile"}).(*OAuth2Grant)
 | 
			
		||||
	assert.True(t, grant.ScopeContains("openid"))
 | 
			
		||||
	assert.True(t, grant.ScopeContains("profile"))
 | 
			
		||||
	assert.False(t, grant.ScopeContains("profil"))
 | 
			
		||||
	assert.False(t, grant.ScopeContains("profile2"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant)
 | 
			
		||||
	code, err := grant.GenerateNewAuthorizationCode("https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.NotNil(t, code)
 | 
			
		||||
	assert.True(t, len(code.Code) > 32) // secret length > 32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2Grant_TableName(t *testing.T) {
 | 
			
		||||
	assert.Equal(t, "oauth2_grant", new(OAuth2Grant).TableName())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGetOAuth2GrantsByUserID(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	result, err := GetOAuth2GrantsByUserID(1)
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Len(t, result, 1)
 | 
			
		||||
	assert.Equal(t, int64(1), result[0].ID)
 | 
			
		||||
	assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
 | 
			
		||||
 | 
			
		||||
	result, err = GetOAuth2GrantsByUserID(34134)
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Empty(t, result)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRevokeOAuth2Grant(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	assert.NoError(t, RevokeOAuth2Grant(1, 1))
 | 
			
		||||
	unittest.AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//////////////////// Authorization Code
 | 
			
		||||
 | 
			
		||||
func TestGetOAuth2AuthorizationByCode(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	code, err := GetOAuth2AuthorizationByCode("authcode")
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.NotNil(t, code)
 | 
			
		||||
	assert.Equal(t, "authcode", code.Code)
 | 
			
		||||
	assert.Equal(t, int64(1), code.ID)
 | 
			
		||||
 | 
			
		||||
	code, err = GetOAuth2AuthorizationByCode("does not exist")
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Nil(t, code)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) {
 | 
			
		||||
	// test plain
 | 
			
		||||
	code := &OAuth2AuthorizationCode{
 | 
			
		||||
		CodeChallengeMethod: "plain",
 | 
			
		||||
		CodeChallenge:       "test123",
 | 
			
		||||
	}
 | 
			
		||||
	assert.True(t, code.ValidateCodeChallenge("test123"))
 | 
			
		||||
	assert.False(t, code.ValidateCodeChallenge("ierwgjoergjio"))
 | 
			
		||||
 | 
			
		||||
	// test S256
 | 
			
		||||
	code = &OAuth2AuthorizationCode{
 | 
			
		||||
		CodeChallengeMethod: "S256",
 | 
			
		||||
		CodeChallenge:       "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg",
 | 
			
		||||
	}
 | 
			
		||||
	assert.True(t, code.ValidateCodeChallenge("N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt"))
 | 
			
		||||
	assert.False(t, code.ValidateCodeChallenge("wiogjerogorewngoenrgoiuenorg"))
 | 
			
		||||
 | 
			
		||||
	// test unknown
 | 
			
		||||
	code = &OAuth2AuthorizationCode{
 | 
			
		||||
		CodeChallengeMethod: "monkey",
 | 
			
		||||
		CodeChallenge:       "foiwgjioriogeiogjerger",
 | 
			
		||||
	}
 | 
			
		||||
	assert.False(t, code.ValidateCodeChallenge("foiwgjioriogeiogjerger"))
 | 
			
		||||
 | 
			
		||||
	// test no code challenge
 | 
			
		||||
	code = &OAuth2AuthorizationCode{
 | 
			
		||||
		CodeChallengeMethod: "",
 | 
			
		||||
		CodeChallenge:       "foierjiogerogerg",
 | 
			
		||||
	}
 | 
			
		||||
	assert.True(t, code.ValidateCodeChallenge(""))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
 | 
			
		||||
	code := &OAuth2AuthorizationCode{
 | 
			
		||||
		RedirectURI: "https://example.com/callback",
 | 
			
		||||
		Code:        "thecode",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	redirect, err := code.GenerateRedirectURI("thestate")
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String())
 | 
			
		||||
 | 
			
		||||
	redirect, err = code.GenerateRedirectURI("")
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	code := unittest.AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode)
 | 
			
		||||
	assert.NoError(t, code.Invalidate())
 | 
			
		||||
	unittest.AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestOAuth2AuthorizationCode_TableName(t *testing.T) {
 | 
			
		||||
	assert.Equal(t, "oauth2_authorization_code", new(OAuth2AuthorizationCode).TableName())
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										126
									
								
								models/auth/session.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										126
									
								
								models/auth/session.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,126 @@
 | 
			
		||||
// Copyright 2020 The Gitea Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a MIT-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"code.gitea.io/gitea/models/db"
 | 
			
		||||
	"code.gitea.io/gitea/modules/timeutil"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Session represents a session compatible for go-chi session
 | 
			
		||||
type Session struct {
 | 
			
		||||
	Key    string             `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session
 | 
			
		||||
	Data   []byte             `xorm:"BLOB"`        // on MySQL this has a maximum size of 64Kb - this may need to be increased
 | 
			
		||||
	Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	db.RegisterModel(new(Session))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateSession updates the session with provided id
 | 
			
		||||
func UpdateSession(key string, data []byte) error {
 | 
			
		||||
	_, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{
 | 
			
		||||
		Data:   data,
 | 
			
		||||
		Expiry: timeutil.TimeStampNow(),
 | 
			
		||||
	})
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReadSession reads the data for the provided session
 | 
			
		||||
func ReadSession(key string) (*Session, error) {
 | 
			
		||||
	session := Session{
 | 
			
		||||
		Key: key,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx, committer, err := db.TxContext()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer committer.Close()
 | 
			
		||||
 | 
			
		||||
	if has, err := db.GetByBean(ctx, &session); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !has {
 | 
			
		||||
		session.Expiry = timeutil.TimeStampNow()
 | 
			
		||||
		if err := db.Insert(ctx, &session); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &session, committer.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ExistSession checks if a session exists
 | 
			
		||||
func ExistSession(key string) (bool, error) {
 | 
			
		||||
	session := Session{
 | 
			
		||||
		Key: key,
 | 
			
		||||
	}
 | 
			
		||||
	return db.GetEngine(db.DefaultContext).Get(&session)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DestroySession destroys a session
 | 
			
		||||
func DestroySession(key string) error {
 | 
			
		||||
	_, err := db.GetEngine(db.DefaultContext).Delete(&Session{
 | 
			
		||||
		Key: key,
 | 
			
		||||
	})
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RegenerateSession regenerates a session from the old id
 | 
			
		||||
func RegenerateSession(oldKey, newKey string) (*Session, error) {
 | 
			
		||||
	ctx, committer, err := db.TxContext()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer committer.Close()
 | 
			
		||||
 | 
			
		||||
	if has, err := db.GetByBean(ctx, &Session{
 | 
			
		||||
		Key: newKey,
 | 
			
		||||
	}); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if has {
 | 
			
		||||
		return nil, fmt.Errorf("session Key: %s already exists", newKey)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if has, err := db.GetByBean(ctx, &Session{
 | 
			
		||||
		Key: oldKey,
 | 
			
		||||
	}); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !has {
 | 
			
		||||
		if err := db.Insert(ctx, &Session{
 | 
			
		||||
			Key:    oldKey,
 | 
			
		||||
			Expiry: timeutil.TimeStampNow(),
 | 
			
		||||
		}); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s := Session{
 | 
			
		||||
		Key: newKey,
 | 
			
		||||
	}
 | 
			
		||||
	if _, err := db.GetByBean(ctx, &s); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &s, committer.Commit()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CountSessions returns the number of sessions
 | 
			
		||||
func CountSessions() (int64, error) {
 | 
			
		||||
	return db.GetEngine(db.DefaultContext).Count(&Session{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CleanupSessions cleans up expired sessions
 | 
			
		||||
func CleanupSessions(maxLifetime int64) error {
 | 
			
		||||
	_, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										397
									
								
								models/auth/source.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										397
									
								
								models/auth/source.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,397 @@
 | 
			
		||||
// Copyright 2014 The Gogs Authors. All rights reserved.
 | 
			
		||||
// Copyright 2019 The Gitea Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a MIT-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
 | 
			
		||||
	"code.gitea.io/gitea/models/db"
 | 
			
		||||
	"code.gitea.io/gitea/modules/log"
 | 
			
		||||
	"code.gitea.io/gitea/modules/timeutil"
 | 
			
		||||
 | 
			
		||||
	"xorm.io/xorm"
 | 
			
		||||
	"xorm.io/xorm/convert"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Type represents an login type.
 | 
			
		||||
type Type int
 | 
			
		||||
 | 
			
		||||
// Note: new type must append to the end of list to maintain compatibility.
 | 
			
		||||
const (
 | 
			
		||||
	NoType Type = iota
 | 
			
		||||
	Plain       // 1
 | 
			
		||||
	LDAP        // 2
 | 
			
		||||
	SMTP        // 3
 | 
			
		||||
	PAM         // 4
 | 
			
		||||
	DLDAP       // 5
 | 
			
		||||
	OAuth2      // 6
 | 
			
		||||
	SSPI        // 7
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// String returns the string name of the LoginType
 | 
			
		||||
func (typ Type) String() string {
 | 
			
		||||
	return Names[typ]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Int returns the int value of the LoginType
 | 
			
		||||
func (typ Type) Int() int {
 | 
			
		||||
	return int(typ)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Names contains the name of LoginType values.
 | 
			
		||||
var Names = map[Type]string{
 | 
			
		||||
	LDAP:   "LDAP (via BindDN)",
 | 
			
		||||
	DLDAP:  "LDAP (simple auth)", // Via direct bind
 | 
			
		||||
	SMTP:   "SMTP",
 | 
			
		||||
	PAM:    "PAM",
 | 
			
		||||
	OAuth2: "OAuth2",
 | 
			
		||||
	SSPI:   "SPNEGO with SSPI",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Config represents login config as far as the db is concerned
 | 
			
		||||
type Config interface {
 | 
			
		||||
	convert.Conversion
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SkipVerifiable configurations provide a IsSkipVerify to check if SkipVerify is set
 | 
			
		||||
type SkipVerifiable interface {
 | 
			
		||||
	IsSkipVerify() bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HasTLSer configurations provide a HasTLS to check if TLS can be enabled
 | 
			
		||||
type HasTLSer interface {
 | 
			
		||||
	HasTLS() bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UseTLSer configurations provide a HasTLS to check if TLS is enabled
 | 
			
		||||
type UseTLSer interface {
 | 
			
		||||
	UseTLS() bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys
 | 
			
		||||
type SSHKeyProvider interface {
 | 
			
		||||
	ProvidesSSHKeys() bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RegisterableSource configurations provide RegisterSource which needs to be run on creation
 | 
			
		||||
type RegisterableSource interface {
 | 
			
		||||
	RegisterSource() error
 | 
			
		||||
	UnregisterSource() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var registeredConfigs = map[Type]func() Config{}
 | 
			
		||||
 | 
			
		||||
// RegisterTypeConfig register a config for a provided type
 | 
			
		||||
func RegisterTypeConfig(typ Type, exemplar Config) {
 | 
			
		||||
	if reflect.TypeOf(exemplar).Kind() == reflect.Ptr {
 | 
			
		||||
		// Pointer:
 | 
			
		||||
		registeredConfigs[typ] = func() Config {
 | 
			
		||||
			return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config)
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Not a Pointer
 | 
			
		||||
	registeredConfigs[typ] = func() Config {
 | 
			
		||||
		return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SourceSettable configurations can have their authSource set on them
 | 
			
		||||
type SourceSettable interface {
 | 
			
		||||
	SetAuthSource(*Source)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Source represents an external way for authorizing users.
 | 
			
		||||
type Source struct {
 | 
			
		||||
	ID            int64 `xorm:"pk autoincr"`
 | 
			
		||||
	Type          Type
 | 
			
		||||
	Name          string             `xorm:"UNIQUE"`
 | 
			
		||||
	IsActive      bool               `xorm:"INDEX NOT NULL DEFAULT false"`
 | 
			
		||||
	IsSyncEnabled bool               `xorm:"INDEX NOT NULL DEFAULT false"`
 | 
			
		||||
	Cfg           convert.Conversion `xorm:"TEXT"`
 | 
			
		||||
 | 
			
		||||
	CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
 | 
			
		||||
	UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TableName xorm will read the table name from this method
 | 
			
		||||
func (Source) TableName() string {
 | 
			
		||||
	return "login_source"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	db.RegisterModel(new(Source))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BeforeSet is invoked from XORM before setting the value of a field of this object.
 | 
			
		||||
func (source *Source) BeforeSet(colName string, val xorm.Cell) {
 | 
			
		||||
	if colName == "type" {
 | 
			
		||||
		typ := Type(db.Cell2Int64(val))
 | 
			
		||||
		constructor, ok := registeredConfigs[typ]
 | 
			
		||||
		if !ok {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		source.Cfg = constructor()
 | 
			
		||||
		if settable, ok := source.Cfg.(SourceSettable); ok {
 | 
			
		||||
			settable.SetAuthSource(source)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TypeName return name of this login source type.
 | 
			
		||||
func (source *Source) TypeName() string {
 | 
			
		||||
	return Names[source.Type]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsLDAP returns true of this source is of the LDAP type.
 | 
			
		||||
func (source *Source) IsLDAP() bool {
 | 
			
		||||
	return source.Type == LDAP
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsDLDAP returns true of this source is of the DLDAP type.
 | 
			
		||||
func (source *Source) IsDLDAP() bool {
 | 
			
		||||
	return source.Type == DLDAP
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsSMTP returns true of this source is of the SMTP type.
 | 
			
		||||
func (source *Source) IsSMTP() bool {
 | 
			
		||||
	return source.Type == SMTP
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsPAM returns true of this source is of the PAM type.
 | 
			
		||||
func (source *Source) IsPAM() bool {
 | 
			
		||||
	return source.Type == PAM
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsOAuth2 returns true of this source is of the OAuth2 type.
 | 
			
		||||
func (source *Source) IsOAuth2() bool {
 | 
			
		||||
	return source.Type == OAuth2
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsSSPI returns true of this source is of the SSPI type.
 | 
			
		||||
func (source *Source) IsSSPI() bool {
 | 
			
		||||
	return source.Type == SSPI
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HasTLS returns true of this source supports TLS.
 | 
			
		||||
func (source *Source) HasTLS() bool {
 | 
			
		||||
	hasTLSer, ok := source.Cfg.(HasTLSer)
 | 
			
		||||
	return ok && hasTLSer.HasTLS()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UseTLS returns true of this source is configured to use TLS.
 | 
			
		||||
func (source *Source) UseTLS() bool {
 | 
			
		||||
	useTLSer, ok := source.Cfg.(UseTLSer)
 | 
			
		||||
	return ok && useTLSer.UseTLS()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SkipVerify returns true if this source is configured to skip SSL
 | 
			
		||||
// verification.
 | 
			
		||||
func (source *Source) SkipVerify() bool {
 | 
			
		||||
	skipVerifiable, ok := source.Cfg.(SkipVerifiable)
 | 
			
		||||
	return ok && skipVerifiable.IsSkipVerify()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateSource inserts a AuthSource in the DB if not already
 | 
			
		||||
// existing with the given name.
 | 
			
		||||
func CreateSource(source *Source) error {
 | 
			
		||||
	has, err := db.GetEngine(db.DefaultContext).Where("name=?", source.Name).Exist(new(Source))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if has {
 | 
			
		||||
		return ErrSourceAlreadyExist{source.Name}
 | 
			
		||||
	}
 | 
			
		||||
	// Synchronization is only available with LDAP for now
 | 
			
		||||
	if !source.IsLDAP() {
 | 
			
		||||
		source.IsSyncEnabled = false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = db.GetEngine(db.DefaultContext).Insert(source)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !source.IsActive {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if settable, ok := source.Cfg.(SourceSettable); ok {
 | 
			
		||||
		settable.SetAuthSource(source)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	registerableSource, ok := source.Cfg.(RegisterableSource)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = registerableSource.RegisterSource()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		// remove the AuthSource in case of errors while registering configuration
 | 
			
		||||
		if _, err := db.GetEngine(db.DefaultContext).Delete(source); err != nil {
 | 
			
		||||
			log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Sources returns a slice of all login sources found in DB.
 | 
			
		||||
func Sources() ([]*Source, error) {
 | 
			
		||||
	auths := make([]*Source, 0, 6)
 | 
			
		||||
	return auths, db.GetEngine(db.DefaultContext).Find(&auths)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SourcesByType returns all sources of the specified type
 | 
			
		||||
func SourcesByType(loginType Type) ([]*Source, error) {
 | 
			
		||||
	sources := make([]*Source, 0, 1)
 | 
			
		||||
	if err := db.GetEngine(db.DefaultContext).Where("type = ?", loginType).Find(&sources); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return sources, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllActiveSources returns all active sources
 | 
			
		||||
func AllActiveSources() ([]*Source, error) {
 | 
			
		||||
	sources := make([]*Source, 0, 5)
 | 
			
		||||
	if err := db.GetEngine(db.DefaultContext).Where("is_active = ?", true).Find(&sources); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return sources, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ActiveSources returns all active sources of the specified type
 | 
			
		||||
func ActiveSources(tp Type) ([]*Source, error) {
 | 
			
		||||
	sources := make([]*Source, 0, 1)
 | 
			
		||||
	if err := db.GetEngine(db.DefaultContext).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return sources, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsSSPIEnabled returns true if there is at least one activated login
 | 
			
		||||
// source of type LoginSSPI
 | 
			
		||||
func IsSSPIEnabled() bool {
 | 
			
		||||
	if !db.HasEngine {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	sources, err := ActiveSources(SSPI)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Error("ActiveSources: %v", err)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return len(sources) > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetSourceByID returns login source by given ID.
 | 
			
		||||
func GetSourceByID(id int64) (*Source, error) {
 | 
			
		||||
	source := new(Source)
 | 
			
		||||
	if id == 0 {
 | 
			
		||||
		source.Cfg = registeredConfigs[NoType]()
 | 
			
		||||
		// Set this source to active
 | 
			
		||||
		// FIXME: allow disabling of db based password authentication in future
 | 
			
		||||
		source.IsActive = true
 | 
			
		||||
		return source, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	has, err := db.GetEngine(db.DefaultContext).ID(id).Get(source)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !has {
 | 
			
		||||
		return nil, ErrSourceNotExist{id}
 | 
			
		||||
	}
 | 
			
		||||
	return source, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateSource updates a Source record in DB.
 | 
			
		||||
func UpdateSource(source *Source) error {
 | 
			
		||||
	var originalSource *Source
 | 
			
		||||
	if source.IsOAuth2() {
 | 
			
		||||
		// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
 | 
			
		||||
		var err error
 | 
			
		||||
		if originalSource, err = GetSourceByID(source.ID); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(source)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !source.IsActive {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if settable, ok := source.Cfg.(SourceSettable); ok {
 | 
			
		||||
		settable.SetAuthSource(source)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	registerableSource, ok := source.Cfg.(RegisterableSource)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = registerableSource.RegisterSource()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		// restore original values since we cannot update the provider it self
 | 
			
		||||
		if _, err := db.GetEngine(db.DefaultContext).ID(source.ID).AllCols().Update(originalSource); err != nil {
 | 
			
		||||
			log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CountSources returns number of login sources.
 | 
			
		||||
func CountSources() int64 {
 | 
			
		||||
	count, _ := db.GetEngine(db.DefaultContext).Count(new(Source))
 | 
			
		||||
	return count
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ErrSourceNotExist represents a "SourceNotExist" kind of error.
 | 
			
		||||
type ErrSourceNotExist struct {
 | 
			
		||||
	ID int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsErrSourceNotExist checks if an error is a ErrSourceNotExist.
 | 
			
		||||
func IsErrSourceNotExist(err error) bool {
 | 
			
		||||
	_, ok := err.(ErrSourceNotExist)
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (err ErrSourceNotExist) Error() string {
 | 
			
		||||
	return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error.
 | 
			
		||||
type ErrSourceAlreadyExist struct {
 | 
			
		||||
	Name string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist.
 | 
			
		||||
func IsErrSourceAlreadyExist(err error) bool {
 | 
			
		||||
	_, ok := err.(ErrSourceAlreadyExist)
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (err ErrSourceAlreadyExist) Error() string {
 | 
			
		||||
	return fmt.Sprintf("login source already exists [name: %s]", err.Name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ErrSourceInUse represents a "SourceInUse" kind of error.
 | 
			
		||||
type ErrSourceInUse struct {
 | 
			
		||||
	ID int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsErrSourceInUse checks if an error is a ErrSourceInUse.
 | 
			
		||||
func IsErrSourceInUse(err error) bool {
 | 
			
		||||
	_, ok := err.(ErrSourceInUse)
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (err ErrSourceInUse) Error() string {
 | 
			
		||||
	return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										60
									
								
								models/auth/source_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										60
									
								
								models/auth/source_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,60 @@
 | 
			
		||||
// Copyright 2019 The Gitea Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a MIT-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"code.gitea.io/gitea/models/db"
 | 
			
		||||
	"code.gitea.io/gitea/models/unittest"
 | 
			
		||||
	"code.gitea.io/gitea/modules/json"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
	"xorm.io/xorm/schemas"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TestSource struct {
 | 
			
		||||
	Provider                      string
 | 
			
		||||
	ClientID                      string
 | 
			
		||||
	ClientSecret                  string
 | 
			
		||||
	OpenIDConnectAutoDiscoveryURL string
 | 
			
		||||
	IconURL                       string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FromDB fills up a LDAPConfig from serialized format.
 | 
			
		||||
func (source *TestSource) FromDB(bs []byte) error {
 | 
			
		||||
	return json.Unmarshal(bs, &source)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ToDB exports a LDAPConfig to a serialized format.
 | 
			
		||||
func (source *TestSource) ToDB() ([]byte, error) {
 | 
			
		||||
	return json.Marshal(source)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDumpAuthSource(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
 | 
			
		||||
	authSourceSchema, err := db.TableInfo(new(Source))
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	RegisterTypeConfig(OAuth2, new(TestSource))
 | 
			
		||||
 | 
			
		||||
	CreateSource(&Source{
 | 
			
		||||
		Type:     OAuth2,
 | 
			
		||||
		Name:     "TestSource",
 | 
			
		||||
		IsActive: false,
 | 
			
		||||
		Cfg: &TestSource{
 | 
			
		||||
			Provider: "ConvertibleSourceName",
 | 
			
		||||
			ClientID: "42",
 | 
			
		||||
		},
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	sb := new(strings.Builder)
 | 
			
		||||
 | 
			
		||||
	db.DumpTables([]*schemas.Table{authSourceSchema}, sb)
 | 
			
		||||
 | 
			
		||||
	assert.Contains(t, sb.String(), `"Provider":"ConvertibleSourceName"`)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										156
									
								
								models/auth/twofactor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										156
									
								
								models/auth/twofactor.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,156 @@
 | 
			
		||||
// Copyright 2017 The Gitea Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a MIT-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/md5"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"crypto/subtle"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"code.gitea.io/gitea/models/db"
 | 
			
		||||
	"code.gitea.io/gitea/modules/secret"
 | 
			
		||||
	"code.gitea.io/gitea/modules/setting"
 | 
			
		||||
	"code.gitea.io/gitea/modules/timeutil"
 | 
			
		||||
	"code.gitea.io/gitea/modules/util"
 | 
			
		||||
 | 
			
		||||
	"github.com/pquerna/otp/totp"
 | 
			
		||||
	"golang.org/x/crypto/pbkdf2"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
//
 | 
			
		||||
// Two-factor authentication
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication.
 | 
			
		||||
type ErrTwoFactorNotEnrolled struct {
 | 
			
		||||
	UID int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled.
 | 
			
		||||
func IsErrTwoFactorNotEnrolled(err error) bool {
 | 
			
		||||
	_, ok := err.(ErrTwoFactorNotEnrolled)
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (err ErrTwoFactorNotEnrolled) Error() string {
 | 
			
		||||
	return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TwoFactor represents a two-factor authentication token.
 | 
			
		||||
type TwoFactor struct {
 | 
			
		||||
	ID               int64 `xorm:"pk autoincr"`
 | 
			
		||||
	UID              int64 `xorm:"UNIQUE"`
 | 
			
		||||
	Secret           string
 | 
			
		||||
	ScratchSalt      string
 | 
			
		||||
	ScratchHash      string
 | 
			
		||||
	LastUsedPasscode string             `xorm:"VARCHAR(10)"`
 | 
			
		||||
	CreatedUnix      timeutil.TimeStamp `xorm:"INDEX created"`
 | 
			
		||||
	UpdatedUnix      timeutil.TimeStamp `xorm:"INDEX updated"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	db.RegisterModel(new(TwoFactor))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenerateScratchToken recreates the scratch token the user is using.
 | 
			
		||||
func (t *TwoFactor) GenerateScratchToken() (string, error) {
 | 
			
		||||
	token, err := util.RandomString(8)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	t.ScratchSalt, _ = util.RandomString(10)
 | 
			
		||||
	t.ScratchHash = HashToken(token, t.ScratchSalt)
 | 
			
		||||
	return token, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HashToken return the hashable salt
 | 
			
		||||
func HashToken(token, salt string) string {
 | 
			
		||||
	tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New)
 | 
			
		||||
	return fmt.Sprintf("%x", tempHash)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// VerifyScratchToken verifies if the specified scratch token is valid.
 | 
			
		||||
func (t *TwoFactor) VerifyScratchToken(token string) bool {
 | 
			
		||||
	if len(token) == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	tempHash := HashToken(token, t.ScratchSalt)
 | 
			
		||||
	return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (t *TwoFactor) getEncryptionKey() []byte {
 | 
			
		||||
	k := md5.Sum([]byte(setting.SecretKey))
 | 
			
		||||
	return k[:]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetSecret sets the 2FA secret.
 | 
			
		||||
func (t *TwoFactor) SetSecret(secretString string) error {
 | 
			
		||||
	secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	t.Secret = base64.StdEncoding.EncodeToString(secretBytes)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ValidateTOTP validates the provided passcode.
 | 
			
		||||
func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
 | 
			
		||||
	decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	secretStr := string(secretBytes)
 | 
			
		||||
	return totp.Validate(passcode, secretStr), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewTwoFactor creates a new two-factor authentication token.
 | 
			
		||||
func NewTwoFactor(t *TwoFactor) error {
 | 
			
		||||
	_, err := db.GetEngine(db.DefaultContext).Insert(t)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateTwoFactor updates a two-factor authentication token.
 | 
			
		||||
func UpdateTwoFactor(t *TwoFactor) error {
 | 
			
		||||
	_, err := db.GetEngine(db.DefaultContext).ID(t.ID).AllCols().Update(t)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetTwoFactorByUID returns the two-factor authentication token associated with
 | 
			
		||||
// the user, if any.
 | 
			
		||||
func GetTwoFactorByUID(uid int64) (*TwoFactor, error) {
 | 
			
		||||
	twofa := &TwoFactor{}
 | 
			
		||||
	has, err := db.GetEngine(db.DefaultContext).Where("uid=?", uid).Get(twofa)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !has {
 | 
			
		||||
		return nil, ErrTwoFactorNotEnrolled{uid}
 | 
			
		||||
	}
 | 
			
		||||
	return twofa, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HasTwoFactorByUID returns the two-factor authentication token associated with
 | 
			
		||||
// the user, if any.
 | 
			
		||||
func HasTwoFactorByUID(uid int64) (bool, error) {
 | 
			
		||||
	return db.GetEngine(db.DefaultContext).Where("uid=?", uid).Exist(&TwoFactor{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
 | 
			
		||||
func DeleteTwoFactorByID(id, userID int64) error {
 | 
			
		||||
	cnt, err := db.GetEngine(db.DefaultContext).ID(id).Delete(&TwoFactor{
 | 
			
		||||
		UID: userID,
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	} else if cnt != 1 {
 | 
			
		||||
		return ErrTwoFactorNotEnrolled{userID}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										154
									
								
								models/auth/u2f.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										154
									
								
								models/auth/u2f.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,154 @@
 | 
			
		||||
// Copyright 2018 The Gitea Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a MIT-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"code.gitea.io/gitea/models/db"
 | 
			
		||||
	"code.gitea.io/gitea/modules/log"
 | 
			
		||||
	"code.gitea.io/gitea/modules/timeutil"
 | 
			
		||||
 | 
			
		||||
	"github.com/tstranex/u2f"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ____ ________________________________              .__          __                 __  .__
 | 
			
		||||
// |    |   \_____  \_   _____/\______   \ ____   ____ |__| _______/  |_____________ _/  |_|__| ____   ____
 | 
			
		||||
// |    |   //  ____/|    __)   |       _// __ \ / ___\|  |/  ___/\   __\_  __ \__  \\   __\  |/  _ \ /    \
 | 
			
		||||
// |    |  //       \|     \    |    |   \  ___// /_/  >  |\___ \  |  |  |  | \// __ \|  | |  (  <_> )   |  \
 | 
			
		||||
// |______/ \_______ \___  /    |____|_  /\___  >___  /|__/____  > |__|  |__|  (____  /__| |__|\____/|___|  /
 | 
			
		||||
// \/   \/            \/     \/_____/         \/                   \/                    \/
 | 
			
		||||
 | 
			
		||||
// ErrU2FRegistrationNotExist represents a "ErrU2FRegistrationNotExist" kind of error.
 | 
			
		||||
type ErrU2FRegistrationNotExist struct {
 | 
			
		||||
	ID int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (err ErrU2FRegistrationNotExist) Error() string {
 | 
			
		||||
	return fmt.Sprintf("U2F registration does not exist [id: %d]", err.ID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IsErrU2FRegistrationNotExist checks if an error is a ErrU2FRegistrationNotExist.
 | 
			
		||||
func IsErrU2FRegistrationNotExist(err error) bool {
 | 
			
		||||
	_, ok := err.(ErrU2FRegistrationNotExist)
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// U2FRegistration represents the registration data and counter of a security key
 | 
			
		||||
type U2FRegistration struct {
 | 
			
		||||
	ID          int64 `xorm:"pk autoincr"`
 | 
			
		||||
	Name        string
 | 
			
		||||
	UserID      int64 `xorm:"INDEX"`
 | 
			
		||||
	Raw         []byte
 | 
			
		||||
	Counter     uint32             `xorm:"BIGINT"`
 | 
			
		||||
	CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
 | 
			
		||||
	UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	db.RegisterModel(new(U2FRegistration))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TableName returns a better table name for U2FRegistration
 | 
			
		||||
func (reg U2FRegistration) TableName() string {
 | 
			
		||||
	return "u2f_registration"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Parse will convert the db entry U2FRegistration to an u2f.Registration struct
 | 
			
		||||
func (reg *U2FRegistration) Parse() (*u2f.Registration, error) {
 | 
			
		||||
	r := new(u2f.Registration)
 | 
			
		||||
	return r, r.UnmarshalBinary(reg.Raw)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (reg *U2FRegistration) updateCounter(e db.Engine) error {
 | 
			
		||||
	_, err := e.ID(reg.ID).Cols("counter").Update(reg)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateCounter will update the database value of counter
 | 
			
		||||
func (reg *U2FRegistration) UpdateCounter() error {
 | 
			
		||||
	return reg.updateCounter(db.GetEngine(db.DefaultContext))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// U2FRegistrationList is a list of *U2FRegistration
 | 
			
		||||
type U2FRegistrationList []*U2FRegistration
 | 
			
		||||
 | 
			
		||||
// ToRegistrations will convert all U2FRegistrations to u2f.Registrations
 | 
			
		||||
func (list U2FRegistrationList) ToRegistrations() []u2f.Registration {
 | 
			
		||||
	regs := make([]u2f.Registration, 0, len(list))
 | 
			
		||||
	for _, reg := range list {
 | 
			
		||||
		r, err := reg.Parse()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Error("parsing u2f registration: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		regs = append(regs, *r)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return regs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getU2FRegistrationsByUID(e db.Engine, uid int64) (U2FRegistrationList, error) {
 | 
			
		||||
	regs := make(U2FRegistrationList, 0)
 | 
			
		||||
	return regs, e.Where("user_id = ?", uid).Find(®s)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetU2FRegistrationByID returns U2F registration by id
 | 
			
		||||
func GetU2FRegistrationByID(id int64) (*U2FRegistration, error) {
 | 
			
		||||
	return getU2FRegistrationByID(db.GetEngine(db.DefaultContext), id)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getU2FRegistrationByID(e db.Engine, id int64) (*U2FRegistration, error) {
 | 
			
		||||
	reg := new(U2FRegistration)
 | 
			
		||||
	if found, err := e.ID(id).Get(reg); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else if !found {
 | 
			
		||||
		return nil, ErrU2FRegistrationNotExist{ID: id}
 | 
			
		||||
	}
 | 
			
		||||
	return reg, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetU2FRegistrationsByUID returns all U2F registrations of the given user
 | 
			
		||||
func GetU2FRegistrationsByUID(uid int64) (U2FRegistrationList, error) {
 | 
			
		||||
	return getU2FRegistrationsByUID(db.GetEngine(db.DefaultContext), uid)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HasU2FRegistrationsByUID returns whether a given user has U2F registrations
 | 
			
		||||
func HasU2FRegistrationsByUID(uid int64) (bool, error) {
 | 
			
		||||
	return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&U2FRegistration{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func createRegistration(e db.Engine, userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) {
 | 
			
		||||
	raw, err := reg.MarshalBinary()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	r := &U2FRegistration{
 | 
			
		||||
		UserID:  userID,
 | 
			
		||||
		Name:    name,
 | 
			
		||||
		Counter: 0,
 | 
			
		||||
		Raw:     raw,
 | 
			
		||||
	}
 | 
			
		||||
	_, err = e.InsertOne(r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return r, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CreateRegistration will create a new U2FRegistration from the given Registration
 | 
			
		||||
func CreateRegistration(userID int64, name string, reg *u2f.Registration) (*U2FRegistration, error) {
 | 
			
		||||
	return createRegistration(db.GetEngine(db.DefaultContext), userID, name, reg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteRegistration will delete U2FRegistration
 | 
			
		||||
func DeleteRegistration(reg *U2FRegistration) error {
 | 
			
		||||
	return deleteRegistration(db.GetEngine(db.DefaultContext), reg)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func deleteRegistration(e db.Engine, reg *U2FRegistration) error {
 | 
			
		||||
	_, err := e.Delete(reg)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										100
									
								
								models/auth/u2f_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								models/auth/u2f_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,100 @@
 | 
			
		||||
// Copyright 2020 The Gitea Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a MIT-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package auth
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"code.gitea.io/gitea/models/unittest"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/assert"
 | 
			
		||||
	"github.com/tstranex/u2f"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestGetU2FRegistrationByID(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
 | 
			
		||||
	res, err := GetU2FRegistrationByID(1)
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Equal(t, "U2F Key", res.Name)
 | 
			
		||||
 | 
			
		||||
	_, err = GetU2FRegistrationByID(342432)
 | 
			
		||||
	assert.Error(t, err)
 | 
			
		||||
	assert.True(t, IsErrU2FRegistrationNotExist(err))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestGetU2FRegistrationsByUID(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
 | 
			
		||||
	res, err := GetU2FRegistrationsByUID(32)
 | 
			
		||||
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Len(t, res, 1)
 | 
			
		||||
	assert.Equal(t, "U2F Key", res[0].Name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestU2FRegistration_TableName(t *testing.T) {
 | 
			
		||||
	assert.Equal(t, "u2f_registration", U2FRegistration{}.TableName())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestU2FRegistration_UpdateCounter(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration)
 | 
			
		||||
	reg.Counter = 1
 | 
			
		||||
	assert.NoError(t, reg.UpdateCounter())
 | 
			
		||||
	unittest.AssertExistsIf(t, true, &U2FRegistration{ID: 1, Counter: 1})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestU2FRegistration_UpdateLargeCounter(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration)
 | 
			
		||||
	reg.Counter = 0xffffffff
 | 
			
		||||
	assert.NoError(t, reg.UpdateCounter())
 | 
			
		||||
	unittest.AssertExistsIf(t, true, &U2FRegistration{ID: 1, Counter: 0xffffffff})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestCreateRegistration(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
 | 
			
		||||
	res, err := CreateRegistration(1, "U2F Created Key", &u2f.Registration{Raw: []byte("Test")})
 | 
			
		||||
	assert.NoError(t, err)
 | 
			
		||||
	assert.Equal(t, "U2F Created Key", res.Name)
 | 
			
		||||
	assert.Equal(t, []byte("Test"), res.Raw)
 | 
			
		||||
 | 
			
		||||
	unittest.AssertExistsIf(t, true, &U2FRegistration{Name: "U2F Created Key", UserID: 1})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDeleteRegistration(t *testing.T) {
 | 
			
		||||
	assert.NoError(t, unittest.PrepareTestDatabase())
 | 
			
		||||
	reg := unittest.AssertExistsAndLoadBean(t, &U2FRegistration{ID: 1}).(*U2FRegistration)
 | 
			
		||||
 | 
			
		||||
	assert.NoError(t, DeleteRegistration(reg))
 | 
			
		||||
	unittest.AssertNotExistsBean(t, &U2FRegistration{ID: 1})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const validU2FRegistrationResponseHex = "0504b174bc49c7ca254b70d2e5c207cee9cf174820ebd77ea3c65508c26da51b657c1cc6b952f8621697936482da0a6d3d3826a59095daf6cd7c03e2e60385d2f6d9402a552dfdb7477ed65fd84133f86196010b2215b57da75d315b7b9e8fe2e3925a6019551bab61d16591659cbaf00b4950f7abfe6660e2e006f76868b772d70c253082013c3081e4a003020102020a47901280001155957352300a06082a8648ce3d0403023017311530130603550403130c476e756262792050696c6f74301e170d3132303831343138323933325a170d3133303831343138323933325a3031312f302d0603550403132650696c6f74476e756262792d302e342e312d34373930313238303030313135353935373335323059301306072a8648ce3d020106082a8648ce3d030107034200048d617e65c9508e64bcc5673ac82a6799da3c1446682c258c463fffdf58dfd2fa3e6c378b53d795c4a4dffb4199edd7862f23abaf0203b4b8911ba0569994e101300a06082a8648ce3d0403020347003044022060cdb6061e9c22262d1aac1d96d8c70829b2366531dda268832cb836bcd30dfa0220631b1459f09e6330055722c8d89b7f48883b9089b88d60d1d9795902b30410df304502201471899bcc3987e62e8202c9b39c33c19033f7340352dba80fcab017db9230e402210082677d673d891933ade6f617e5dbde2e247e70423fd5ad7804a6d3d3961ef871"
 | 
			
		||||
 | 
			
		||||
func TestToRegistrations_SkipInvalidItemsWithoutCrashing(t *testing.T) {
 | 
			
		||||
	regKeyRaw, _ := hex.DecodeString(validU2FRegistrationResponseHex)
 | 
			
		||||
	regs := U2FRegistrationList{
 | 
			
		||||
		&U2FRegistration{ID: 1},
 | 
			
		||||
		&U2FRegistration{ID: 2, Name: "U2F Key", UserID: 2, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	actual := regs.ToRegistrations()
 | 
			
		||||
	assert.Len(t, actual, 1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestToRegistrations(t *testing.T) {
 | 
			
		||||
	regKeyRaw, _ := hex.DecodeString(validU2FRegistrationResponseHex)
 | 
			
		||||
	regs := U2FRegistrationList{
 | 
			
		||||
		&U2FRegistration{ID: 1, Name: "U2F Key", UserID: 1, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800},
 | 
			
		||||
		&U2FRegistration{ID: 2, Name: "U2F Key", UserID: 2, Counter: 0, Raw: regKeyRaw, CreatedUnix: 946684800, UpdatedUnix: 946684800},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	actual := regs.ToRegistrations()
 | 
			
		||||
	assert.Len(t, actual, 2)
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user