You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

323 lines
8.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package RedisStorage
import (
"bytes"
"encoding/gob"
"fmt"
"github.com/RangelReale/osin"
"github.com/garyburd/redigo/redis"
"github.com/pkg/errors"
"github.com/satori/go.uuid"
)
// init registers with gob the concrete types this package expects to meet
// behind interface values when encoding or decoding stored OAuth2 data.
func init() {
	samples := []interface{}{
		map[string]interface{}{},
		&osin.DefaultClient{},
		osin.AuthorizeData{},
		osin.AccessData{},
	}
	for _, sample := range samples {
		gob.Register(sample)
	}
}
// Package-level handles. Nothing in this part of the file assigns them;
// presumably they are initialised by the application at start-up — TODO confirm.
var (
	// OsinServer is the shared osin server instance.
	OsinServer *osin.Server
	// OAuth2RedisStorage is the shared Redis-backed storage instance.
	OAuth2RedisStorage *RedisStorage
)
// RedisStorage is a Redis-backed store for OAuth2 clients, authorization
// codes and access/refresh tokens. It is returned as an osin.Storage by
// Clone, i.e. it implements "github.com/RangelReale/osin".Storage.
type RedisStorage struct {
// Pool supplies the Redis connections; every method obtains one with
// Pool.Get for the duration of the call.
Pool *redis.Pool
// KeyPrefix is the first segment of every Redis key built by makeKey
// ("<KeyPrefix>:<namespace>:<id>").
KeyPrefix string
}
// Clone returns the storage to use for a single request.
//
// Some backends need a per-request copy here (for example, with mgo one
// would clone the session) to avoid concurrent access problems without
// re-cloning at every method call. This implementation needs no copy: each
// method takes its own connection from Pool, so Clone returns the receiver
// itself.
func (s *RedisStorage) Clone() osin.Storage {
return s
}
// Close releases resources held by a storage obtained from Clone. Clone
// returns the receiver itself and acquires nothing, so Close is a no-op; in
// particular it does not close Pool.
func (s *RedisStorage) Close() {}
// CreateClient stores client in Redis under the "client" namespace, keyed by
// client.GetId(). The client is gob-encoded and written with SET, so an
// existing entry with the same ID is replaced and no expiry is set.
//
// It returns the connection's own error (unwrapped) when the pooled
// connection is unusable, and a wrapped error when encoding or the SET fails.
func (s *RedisStorage) CreateClient(client osin.Client) error {
	conn := s.Pool.Get()
	// Pool.Get always returns a connection that the caller must close, even
	// one that is already in an error state, so schedule Close before the
	// Err check instead of after it.
	defer conn.Close()
	if err := conn.Err(); err != nil {
		return err
	}

	payload, err := encode(client)
	if err != nil {
		return errors.Wrap(err, "failed to encode client")
	}

	_, err = conn.Do("SET", s.makeKey("client", client.GetId()), payload)
	// errors.Wrap returns nil for a nil error, so success yields nil.
	return errors.Wrap(err, "failed to save client")
}
// GetClient loads the client stored under the "client" namespace for id.
//
// It returns (nil, nil) when no such key exists. On any failure it returns a
// nil client together with the error, never a partially decoded client.
func (s *RedisStorage) GetClient(id string) (osin.Client, error) {
	conn := s.Pool.Get()
	// Pool.Get always returns a connection that must be closed, so schedule
	// Close before the Err check.
	defer conn.Close()
	if err := conn.Err(); err != nil {
		return nil, err
	}

	reply, err := conn.Do("GET", s.makeKey("client", id))
	if err != nil {
		return nil, errors.Wrap(err, "unable to GET client")
	}
	if reply == nil {
		// Key does not exist: report "not found" as a nil client, no error.
		return nil, nil
	}

	// Surface an unexpected reply type explicitly instead of letting it
	// show up later as an opaque gob decoding failure.
	clientGob, err := redis.Bytes(reply, nil)
	if err != nil {
		return nil, errors.Wrap(err, "unexpected reply for client")
	}

	var client osin.DefaultClient
	if err := decode(clientGob, &client); err != nil {
		return nil, errors.Wrap(err, "failed to decode client gob")
	}
	return &client, nil
}
// UpdateClient replaces the stored client by delegating to CreateClient,
// which overwrites the entry for the same client ID. Any failure is wrapped
// with an "update" context.
func (s *RedisStorage) UpdateClient(client osin.Client) error {
	if err := s.CreateClient(client); err != nil {
		return errors.Wrap(err, "failed to update client")
	}
	return nil
}
// DeleteClient removes the "client" namespace entry for client.GetId() with
// DEL. Deleting a client that is not stored is not an error.
func (s *RedisStorage) DeleteClient(client osin.Client) error {
	conn := s.Pool.Get()
	// Pool.Get always returns a connection that must be closed, so schedule
	// Close before the Err check.
	defer conn.Close()
	if err := conn.Err(); err != nil {
		return err
	}

	_, err := conn.Do("DEL", s.makeKey("client", client.GetId()))
	return errors.Wrap(err, "failed to delete client")
}
// SaveAuthorize gob-encodes data and stores it under the "auth" namespace,
// keyed by data.Code, with a Redis expiry of data.ExpiresIn seconds (SETEX).
func (s *RedisStorage) SaveAuthorize(data *osin.AuthorizeData) (err error) {
	conn := s.Pool.Get()
	// Pool.Get always returns a connection that must be closed, so schedule
	// Close before the Err check.
	defer conn.Close()
	if err = conn.Err(); err != nil {
		return err
	}

	payload, err := encode(data)
	if err != nil {
		return errors.Wrap(err, "failed to encode data")
	}

	_, err = conn.Do("SETEX", s.makeKey("auth", data.Code), data.ExpiresIn, string(payload))
	return errors.Wrap(err, "failed to set auth")
}
// LoadAuthorize looks up the AuthorizeData stored under the "auth" namespace
// for code.
//
// The client information comes from the stored gob payload itself; it is not
// re-read from the "client" namespace. Expired codes are removed by Redis, so
// they look like unknown codes: (nil, nil) is returned when the key does not
// exist. On any failure the returned data is nil.
func (s *RedisStorage) LoadAuthorize(code string) (*osin.AuthorizeData, error) {
	conn := s.Pool.Get()
	// Pool.Get always returns a connection that must be closed, so schedule
	// Close before the Err check.
	defer conn.Close()
	if err := conn.Err(); err != nil {
		return nil, err
	}

	reply, err := conn.Do("GET", s.makeKey("auth", code))
	if err != nil {
		return nil, errors.Wrap(err, "unable to GET auth")
	}
	if reply == nil {
		return nil, nil
	}

	// Surface an unexpected reply type explicitly instead of letting it
	// show up later as an opaque gob decoding failure.
	authGob, err := redis.Bytes(reply, nil)
	if err != nil {
		return nil, errors.Wrap(err, "unexpected reply for auth")
	}

	var auth osin.AuthorizeData
	if err := decode(authGob, &auth); err != nil {
		return nil, errors.Wrap(err, "failed to decode auth")
	}
	return &auth, nil
}
// RemoveAuthorize revokes the authorization code by deleting its "auth"
// namespace entry with DEL. Removing an unknown code is not an error.
func (s *RedisStorage) RemoveAuthorize(code string) (err error) {
	conn := s.Pool.Get()
	// Pool.Get always returns a connection that must be closed, so schedule
	// Close before the Err check.
	defer conn.Close()
	if err = conn.Err(); err != nil {
		return err
	}

	_, err = conn.Do("DEL", s.makeKey("auth", code))
	return errors.Wrap(err, "failed to delete auth")
}
// SaveAccess stores data and indexes it by its tokens.
//
// Three keys are written:
//   - "access:<uuid>"         the gob payload, expiring after data.ExpiresIn seconds
//   - "access_token:<token>"  the uuid, expiring after data.ExpiresIn seconds
//   - "refresh_token:<token>" the uuid, expiring after 30 days
//
// The refresh index is only written when data.RefreshToken is non-empty;
// otherwise every access without a refresh token would overwrite the same
// "refresh_token:" key.
//
// The writes are separate commands, not a transaction: a failure part-way
// leaves the earlier keys in place until they expire.
func (s *RedisStorage) SaveAccess(data *osin.AccessData) (err error) {
	// refreshTokenTTL is the lifetime, in seconds, of the refresh-token index
	// (30 days). It was deliberately decoupled from data.ExpiresIn (change by
	// Huang Hai, 2020-03-21): the earlier code gave the refresh token the
	// same lifetime as the access token.
	//
	// NOTE(review): the "access:<uuid>" payload still expires after
	// data.ExpiresIn, and loadAccessByKey fails when that payload is gone,
	// so a refresh token cannot actually be loaded once the payload has
	// expired. Confirm the intended lifetime of the payload.
	const refreshTokenTTL = 3600 * 24 * 30

	conn := s.Pool.Get()
	// Pool.Get always returns a connection that must be closed, so schedule
	// Close before the Err check.
	defer conn.Close()
	if err = conn.Err(); err != nil {
		return err
	}

	payload, err := encode(data)
	if err != nil {
		return errors.Wrap(err, "failed to encode access")
	}

	accessID := uuid.NewV4().String()

	if _, err = conn.Do("SETEX", s.makeKey("access", accessID), data.ExpiresIn, string(payload)); err != nil {
		return errors.Wrap(err, "failed to save access")
	}

	if _, err = conn.Do("SETEX", s.makeKey("access_token", data.AccessToken), data.ExpiresIn, accessID); err != nil {
		return errors.Wrap(err, "failed to register access token")
	}

	if data.RefreshToken == "" {
		// No refresh token was issued; nothing to index.
		return nil
	}

	_, err = conn.Do("SETEX", s.makeKey("refresh_token", data.RefreshToken), refreshTokenTTL, accessID)
	return errors.Wrap(err, "failed to register refresh token")
}
// LoadAccess resolves an access token to its stored AccessData through the
// "access_token" index.
func (s *RedisStorage) LoadAccess(token string) (*osin.AccessData, error) {
	indexKey := s.makeKey("access_token", token)
	return s.loadAccessByKey(indexKey)
}
// RemoveAccess deletes the AccessData reachable through the "access_token"
// index entry for token, together with its index entries.
func (s *RedisStorage) RemoveAccess(token string) error {
	indexKey := s.makeKey("access_token", token)
	return s.removeAccessByKey(indexKey)
}
// LoadRefresh resolves a refresh token to its stored AccessData through the
// "refresh_token" index.
func (s *RedisStorage) LoadRefresh(token string) (*osin.AccessData, error) {
	indexKey := s.makeKey("refresh_token", token)
	return s.loadAccessByKey(indexKey)
}
// RemoveRefresh deletes the AccessData reachable through the "refresh_token"
// index entry for token, together with its index entries.
func (s *RedisStorage) RemoveRefresh(token string) error {
	indexKey := s.makeKey("refresh_token", token)
	return s.removeAccessByKey(indexKey)
}
// removeAccessByKey deletes the access record that the index entry key
// ("access_token:<token>" or "refresh_token:<token>") points to, along with
// both of its index entries.
//
// A key that does not exist is treated as "nothing to remove" and yields nil.
// The deletions are separate DEL commands, not a transaction.
func (s *RedisStorage) removeAccessByKey(key string) error {
	conn := s.Pool.Get()
	// Pool.Get always returns a connection that must be closed, so schedule
	// Close before the Err check.
	defer conn.Close()
	if err := conn.Err(); err != nil {
		return err
	}

	accessID, err := redis.String(conn.Do("GET", key))
	if err == redis.ErrNil {
		// redis.String reports a missing key as ErrNil. Removing something
		// that is already gone is a no-op, matching the nil-access check below.
		return nil
	}
	if err != nil {
		return errors.Wrap(err, "failed to get access")
	}

	// The decoded record is needed to learn both token values so that each
	// index entry can be deleted, whichever one we were called with.
	access, err := s.loadAccessByKey(key)
	if err != nil {
		return errors.Wrap(err, "unable to load access for removal")
	}
	if access == nil {
		// The index entry disappeared (e.g. expired) since the GET above.
		return nil
	}

	if _, err := conn.Do("DEL", s.makeKey("access", accessID)); err != nil {
		return errors.Wrap(err, "failed to delete access")
	}

	if _, err := conn.Do("DEL", s.makeKey("access_token", access.AccessToken)); err != nil {
		return errors.Wrap(err, "failed to deregister access_token")
	}

	_, err = conn.Do("DEL", s.makeKey("refresh_token", access.RefreshToken))
	return errors.Wrap(err, "failed to deregister refresh_token")
}
// loadAccessByKey resolves an index entry ("access_token:<token>" or
// "refresh_token:<token>") to the AccessData it points to.
//
// Steps:
//  1. GET key to obtain the access ID; a missing key yields (nil, nil).
//  2. GET "access:<id>" and gob-decode the payload.
//  3. Overwrite ExpiresIn with the remaining TTL of "access:<id>".
//  4. Replace the embedded clients with the current "client" namespace
//     entries via GetClient, so later client edits are reflected.
//
// On any failure the returned data is nil.
func (s *RedisStorage) loadAccessByKey(key string) (*osin.AccessData, error) {
	conn := s.Pool.Get()
	// Pool.Get always returns a connection that must be closed, so schedule
	// Close before the Err check.
	defer conn.Close()
	if err := conn.Err(); err != nil {
		return nil, err
	}

	// A single GET both detects a missing key and supplies the access ID.
	// Issuing it twice costs a round trip and can fail spuriously if the key
	// expires between the two calls.
	reply, err := conn.Do("GET", key)
	if err != nil {
		return nil, errors.Wrap(err, "unable to GET auth")
	}
	if reply == nil {
		return nil, nil
	}
	accessID, err := redis.String(reply, nil)
	if err != nil {
		return nil, errors.Wrap(err, "unable to get access ID")
	}

	accessIDKey := s.makeKey("access", accessID)
	accessGob, err := redis.Bytes(conn.Do("GET", accessIDKey))
	if err != nil {
		return nil, errors.Wrap(err, "unable to get access gob")
	}

	var access osin.AccessData
	if err := decode(accessGob, &access); err != nil {
		return nil, errors.Wrap(err, "failed to decode access gob")
	}

	ttl, err := redis.Int(conn.Do("TTL", accessIDKey))
	if err != nil {
		return nil, errors.Wrap(err, "unable to get access TTL")
	}
	access.ExpiresIn = int32(ttl)

	// Guard against a payload stored without a client, as is already done
	// for AuthorizeData.Client below; calling GetId on a nil interface
	// would panic.
	if access.Client != nil {
		access.Client, err = s.GetClient(access.Client.GetId())
		if err != nil {
			return nil, errors.Wrap(err, "unable to get client for access")
		}
	}

	if access.AuthorizeData != nil && access.AuthorizeData.Client != nil {
		access.AuthorizeData.Client, err = s.GetClient(access.AuthorizeData.Client.GetId())
		if err != nil {
			return nil, errors.Wrap(err, "unable to get client for access authorize data")
		}
	}

	return &access, nil
}
// makeKey builds the Redis key "<KeyPrefix>:<namespace>:<id>". No escaping is
// applied, so a segment containing ':' is inserted verbatim.
func (s *RedisStorage) makeKey(namespace, id string) string {
	const layout = "%s:%s:%s"
	return fmt.Sprintf(layout, s.KeyPrefix, namespace, id)
}
// encode gob-encodes v and returns the resulting bytes. A failure is wrapped
// with "unable to encode" and no bytes are returned.
func encode(v interface{}) ([]byte, error) {
	out := new(bytes.Buffer)
	enc := gob.NewEncoder(out)
	if err := enc.Encode(v); err != nil {
		return nil, errors.Wrap(err, "unable to encode")
	}
	return out.Bytes(), nil
}
// decode gob-decodes data into v, which must be a pointer to the destination
// value. A failure is wrapped with "unable to decode"; success returns nil.
func decode(data []byte, v interface{}) error {
	dec := gob.NewDecoder(bytes.NewBuffer(data))
	if err := dec.Decode(v); err != nil {
		return errors.Wrap(err, "unable to decode")
	}
	return nil
}