mirror of
				https://gitee.com/gitea/gitea
				synced 2025-11-04 08:30:25 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			343 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			343 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
/*
 | 
						|
Package gothic wraps common behaviour when using Goth. This makes it quick, and easy, to get up
 | 
						|
and running with Goth. Of course, if you want complete control over how things flow, in regards
 | 
						|
to the authentication process, feel free and use Goth directly.
 | 
						|
 | 
						|
See https://github.com/markbates/goth/examples/main.go to see this in action.
 | 
						|
*/
 | 
						|
package gothic
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"compress/gzip"
 | 
						|
	"crypto/rand"
 | 
						|
	"encoding/base64"
 | 
						|
	"errors"
 | 
						|
	"fmt"
 | 
						|
	"io"
 | 
						|
	"io/ioutil"
 | 
						|
	"net/http"
 | 
						|
	"net/url"
 | 
						|
	"os"
 | 
						|
	"strings"
 | 
						|
 | 
						|
	"github.com/gorilla/mux"
 | 
						|
	"github.com/gorilla/sessions"
 | 
						|
	"github.com/markbates/goth"
 | 
						|
)
 | 
						|
 | 
						|
// SessionName is the key used to access the session store.
 | 
						|
const SessionName = "_gothic_session"
 | 
						|
 | 
						|
// Store can/should be set by applications using gothic. The default is a cookie store.
 | 
						|
var Store sessions.Store
 | 
						|
var defaultStore sessions.Store
 | 
						|
 | 
						|
var keySet = false
 | 
						|
 | 
						|
func init() {
 | 
						|
	key := []byte(os.Getenv("SESSION_SECRET"))
 | 
						|
	keySet = len(key) != 0
 | 
						|
 | 
						|
	cookieStore := sessions.NewCookieStore([]byte(key))
 | 
						|
	cookieStore.Options.HttpOnly = true
 | 
						|
	Store = cookieStore
 | 
						|
	defaultStore = Store
 | 
						|
}
 | 
						|
 | 
						|
/*
 | 
						|
BeginAuthHandler is a convenience handler for starting the authentication process.
 | 
						|
It expects to be able to get the name of the provider from the query parameters
 | 
						|
as either "provider" or ":provider".
 | 
						|
 | 
						|
BeginAuthHandler will redirect the user to the appropriate authentication end-point
 | 
						|
for the requested provider.
 | 
						|
 | 
						|
See https://github.com/markbates/goth/examples/main.go to see this in action.
 | 
						|
*/
 | 
						|
func BeginAuthHandler(res http.ResponseWriter, req *http.Request) {
 | 
						|
	url, err := GetAuthURL(res, req)
 | 
						|
	if err != nil {
 | 
						|
		res.WriteHeader(http.StatusBadRequest)
 | 
						|
		fmt.Fprintln(res, err)
 | 
						|
		return
 | 
						|
	}
 | 
						|
 | 
						|
	http.Redirect(res, req, url, http.StatusTemporaryRedirect)
 | 
						|
}
 | 
						|
 | 
						|
// SetState returns the state string to associate with the given request.
// A non-empty "state" query parameter is used as-is; otherwise a fresh
// random value is generated. The state is sent to the provider and can be
// retrieved during the callback.
//
// SetState panics if the system's source of randomness cannot be read.
var SetState = func(req *http.Request) string {
	if state := req.URL.Query().Get("state"); len(state) > 0 {
		return state
	}

	// No state was passed in: generate a random base64-encoded nonce so
	// that the state on the auth URL is unguessable, preventing CSRF
	// attacks, as described in
	//
	// https://auth0.com/docs/protocols/oauth2/oauth-state#keep-reading
	nonce := make([]byte, 64)
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		panic("gothic: source of randomness unavailable: " + err.Error())
	}
	return base64.URLEncoding.EncodeToString(nonce)
}
 | 
						|
 | 
						|
// GetState reads the "state" query parameter that the provider sends back
// on the callback request. It returns an empty string when the parameter is
// absent. The value is used to prevent CSRF attacks, see
// http://tools.ietf.org/html/rfc6749#section-10.12
var GetState = func(req *http.Request) string {
	query := req.URL.Query()
	return query.Get("state")
}
 | 
						|
 | 
						|
/*
 | 
						|
GetAuthURL starts the authentication process with the requested provided.
 | 
						|
It will return a URL that should be used to send users to.
 | 
						|
 | 
						|
It expects to be able to get the name of the provider from the query parameters
 | 
						|
as either "provider" or ":provider".
 | 
						|
 | 
						|
I would recommend using the BeginAuthHandler instead of doing all of these steps
 | 
						|
yourself, but that's entirely up to you.
 | 
						|
*/
 | 
						|
func GetAuthURL(res http.ResponseWriter, req *http.Request) (string, error) {
 | 
						|
	if !keySet && defaultStore == Store {
 | 
						|
		fmt.Println("goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store.")
 | 
						|
	}
 | 
						|
 | 
						|
	providerName, err := GetProviderName(req)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	provider, err := goth.GetProvider(providerName)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	sess, err := provider.BeginAuth(SetState(req))
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	url, err := sess.GetAuthURL()
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	err = StoreInSession(providerName, sess.Marshal(), req, res)
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	return url, err
 | 
						|
}
 | 
						|
 | 
						|
/*
 | 
						|
CompleteUserAuth does what it says on the tin. It completes the authentication
 | 
						|
process and fetches all of the basic information about the user from the provider.
 | 
						|
 | 
						|
It expects to be able to get the name of the provider from the query parameters
 | 
						|
as either "provider" or ":provider".
 | 
						|
 | 
						|
See https://github.com/markbates/goth/examples/main.go to see this in action.
 | 
						|
*/
 | 
						|
var CompleteUserAuth = func(res http.ResponseWriter, req *http.Request) (goth.User, error) {
 | 
						|
	defer Logout(res, req)
 | 
						|
	if !keySet && defaultStore == Store {
 | 
						|
		fmt.Println("goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store.")
 | 
						|
	}
 | 
						|
 | 
						|
	providerName, err := GetProviderName(req)
 | 
						|
	if err != nil {
 | 
						|
		return goth.User{}, err
 | 
						|
	}
 | 
						|
 | 
						|
	provider, err := goth.GetProvider(providerName)
 | 
						|
	if err != nil {
 | 
						|
		return goth.User{}, err
 | 
						|
	}
 | 
						|
 | 
						|
	value, err := GetFromSession(providerName, req)
 | 
						|
	if err != nil {
 | 
						|
		return goth.User{}, err
 | 
						|
	}
 | 
						|
 | 
						|
	sess, err := provider.UnmarshalSession(value)
 | 
						|
	if err != nil {
 | 
						|
		return goth.User{}, err
 | 
						|
	}
 | 
						|
 | 
						|
	err = validateState(req, sess)
 | 
						|
	if err != nil {
 | 
						|
		return goth.User{}, err
 | 
						|
	}
 | 
						|
 | 
						|
	user, err := provider.FetchUser(sess)
 | 
						|
	if err == nil {
 | 
						|
		// user can be found with existing session data
 | 
						|
		return user, err
 | 
						|
	}
 | 
						|
 | 
						|
	// get new token and retry fetch
 | 
						|
	_, err = sess.Authorize(provider, req.URL.Query())
 | 
						|
	if err != nil {
 | 
						|
		return goth.User{}, err
 | 
						|
	}
 | 
						|
 | 
						|
	err = StoreInSession(providerName, sess.Marshal(), req, res)
 | 
						|
 | 
						|
	if err != nil {
 | 
						|
		return goth.User{}, err
 | 
						|
	}
 | 
						|
 | 
						|
	gu, err := provider.FetchUser(sess)
 | 
						|
	return gu, err
 | 
						|
}
 | 
						|
 | 
						|
// validateState ensures that the state token param from the original
 | 
						|
// AuthURL matches the one included in the current (callback) request.
 | 
						|
func validateState(req *http.Request, sess goth.Session) error {
 | 
						|
	rawAuthURL, err := sess.GetAuthURL()
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	authURL, err := url.Parse(rawAuthURL)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	originalState := authURL.Query().Get("state")
 | 
						|
	if originalState != "" && (originalState != req.URL.Query().Get("state")) {
 | 
						|
		return errors.New("state token mismatch")
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Logout invalidates a user session.
 | 
						|
func Logout(res http.ResponseWriter, req *http.Request) error {
 | 
						|
	session, err := Store.Get(req, SessionName)
 | 
						|
	if err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	session.Options.MaxAge = -1
 | 
						|
	session.Values = make(map[interface{}]interface{})
 | 
						|
	err = session.Save(req, res)
 | 
						|
	if err != nil {
 | 
						|
		return errors.New("Could not delete user session ")
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// GetProviderName is a function used to get the name of a provider
 | 
						|
// for a given request. By default, this provider is fetched from
 | 
						|
// the URL query string. If you provide it in a different way,
 | 
						|
// assign your own function to this variable that returns the provider
 | 
						|
// name for your request.
 | 
						|
var GetProviderName = getProviderName
 | 
						|
 | 
						|
func getProviderName(req *http.Request) (string, error) {
 | 
						|
 | 
						|
	// get all the used providers
 | 
						|
	providers := goth.GetProviders()
 | 
						|
 | 
						|
	// loop over the used providers, if we already have a valid session for any provider (ie. user is already logged-in with a provider), then return that provider name
 | 
						|
	for _, provider := range providers {
 | 
						|
		p := provider.Name()
 | 
						|
		session, _ := Store.Get(req, p+SessionName)
 | 
						|
		value := session.Values[p]
 | 
						|
		if _, ok := value.(string); ok {
 | 
						|
			return p, nil
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	// try to get it from the url param "provider"
 | 
						|
	if p := req.URL.Query().Get("provider"); p != "" {
 | 
						|
		return p, nil
 | 
						|
	}
 | 
						|
 | 
						|
	// try to get it from the url param ":provider"
 | 
						|
	if p := req.URL.Query().Get(":provider"); p != "" {
 | 
						|
		return p, nil
 | 
						|
	}
 | 
						|
 | 
						|
	// try to get it from the context's value of "provider" key
 | 
						|
	if p, ok := mux.Vars(req)["provider"]; ok {
 | 
						|
		return p, nil
 | 
						|
	}
 | 
						|
 | 
						|
	//  try to get it from the go-context's value of "provider" key
 | 
						|
	if p, ok := req.Context().Value("provider").(string); ok {
 | 
						|
		return p, nil
 | 
						|
	}
 | 
						|
 | 
						|
	// if not found then return an empty string with the corresponding error
 | 
						|
	return "", errors.New("you must select a provider")
 | 
						|
}
 | 
						|
 | 
						|
// StoreInSession stores a specified key/value pair in the session.
 | 
						|
func StoreInSession(key string, value string, req *http.Request, res http.ResponseWriter) error {
 | 
						|
	session, _ := Store.New(req, SessionName)
 | 
						|
 | 
						|
	if err := updateSessionValue(session, key, value); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	return session.Save(req, res)
 | 
						|
}
 | 
						|
 | 
						|
// GetFromSession retrieves a previously-stored value from the session.
 | 
						|
// If no value has previously been stored at the specified key, it will return an error.
 | 
						|
func GetFromSession(key string, req *http.Request) (string, error) {
 | 
						|
	session, _ := Store.Get(req, SessionName)
 | 
						|
	value, err := getSessionValue(session, key)
 | 
						|
	if err != nil {
 | 
						|
		return "", errors.New("could not find a matching session for this request")
 | 
						|
	}
 | 
						|
 | 
						|
	return value, nil
 | 
						|
}
 | 
						|
 | 
						|
func getSessionValue(session *sessions.Session, key string) (string, error) {
 | 
						|
	value := session.Values[key]
 | 
						|
	if value == nil {
 | 
						|
		return "", fmt.Errorf("could not find a matching session for this request")
 | 
						|
	}
 | 
						|
 | 
						|
	rdata := strings.NewReader(value.(string))
 | 
						|
	r, err := gzip.NewReader(rdata)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	s, err := ioutil.ReadAll(r)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
 | 
						|
	return string(s), nil
 | 
						|
}
 | 
						|
 | 
						|
func updateSessionValue(session *sessions.Session, key, value string) error {
 | 
						|
	var b bytes.Buffer
 | 
						|
	gz := gzip.NewWriter(&b)
 | 
						|
	if _, err := gz.Write([]byte(value)); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	if err := gz.Flush(); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	if err := gz.Close(); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
 | 
						|
	session.Values[key] = b.String()
 | 
						|
	return nil
 | 
						|
}
 |