You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

321 lines
8.3 KiB

package RedisStorage
import (
"bytes"
"encoding/gob"
"fmt"
"github.com/RangelReale/osin"
"github.com/garyburd/redigo/redis"
"github.com/pkg/errors"
"github.com/satori/go.uuid"
)
// init registers with gob the concrete types that may travel inside
// interface-typed fields of the stored osin records (for example the
// osin.Client held by AuthorizeData/AccessData). gob refuses to encode or
// decode an interface value whose dynamic type was never registered.
// NOTE(review): map[string]interface{} is presumably registered for
// UserData payloads — confirm against callers.
func init() {
	for _, sample := range []interface{}{
		map[string]interface{}{},
		&osin.DefaultClient{},
		osin.AuthorizeData{},
		osin.AccessData{},
	} {
		gob.Register(sample)
	}
}
var (
	// OsinServer is a package-level handle to an osin server. It is not
	// initialized in this file.
	// NOTE(review): presumably assigned during application start-up.
	OsinServer *osin.Server

	// OAuth2RedisStorage is a package-level handle to a RedisStorage. It is
	// not initialized in this file.
	// NOTE(review): presumably assigned during application start-up.
	OAuth2RedisStorage *RedisStorage
)
// RedisStorage implements "github.com/RangelReale/osin".Storage on top of a
// Redis connection pool. Records are gob-encoded and stored under keys of
// the form "<KeyPrefix>:<namespace>:<id>".
type RedisStorage struct {
	// Pool supplies the Redis connections; each storage method borrows one
	// and returns it before exiting.
	Pool *redis.Pool
	// KeyPrefix is prepended to every key this storage reads or writes.
	KeyPrefix string
}

// Compile-time check that *RedisStorage satisfies osin.Storage.
var _ osin.Storage = (*RedisStorage)(nil)
// Clone satisfies osin.Storage by handing back the receiver itself.
// Storages with per-session state (e.g. an mgo session) would copy it here;
// RedisStorage keeps none, because every method borrows its own connection
// from Pool, so sharing a single instance avoids needless copies.
func (s *RedisStorage) Clone() osin.Storage {
	var shared osin.Storage = s
	return shared
}
// Close satisfies osin.Storage. It releases nothing: each method already
// returns its connection to Pool, and the Pool itself is not closed here.
func (s *RedisStorage) Close() {
	// Intentionally empty.
}
// CreateClient gob-encodes client and stores it under
// "<prefix>:client:<id>" with SET, overwriting any existing entry that has
// the same id.
func (s *RedisStorage) CreateClient(client osin.Client) error {
	c := s.Pool.Get()
	if connErr := c.Err(); connErr != nil {
		return connErr
	}
	defer c.Close()

	blob, encErr := encode(client)
	if encErr != nil {
		return errors.Wrap(encErr, "failed to encode client")
	}

	key := s.makeKey("client", client.GetId())
	_, setErr := c.Do("SET", key, blob)
	// errors.Wrap(nil, ...) is nil, so success falls through as nil.
	return errors.Wrap(setErr, "failed to save client")
}
// GetClient loads the client stored under "<prefix>:client:<id>".
//
// It returns (nil, nil) when no client with that id exists. On any failure
// it returns a nil client together with the error; it never returns a
// partially decoded client.
func (s *RedisStorage) GetClient(id string) (osin.Client, error) {
	conn := s.Pool.Get()
	if err := conn.Err(); err != nil {
		return nil, err
	}
	defer conn.Close()

	// redis.Bytes maps a nil reply (missing key) to redis.ErrNil and reports
	// unexpected reply types, instead of silently dropping that error.
	clientGob, err := redis.Bytes(conn.Do("GET", s.makeKey("client", id)))
	if err == redis.ErrNil {
		return nil, nil
	}
	if err != nil {
		return nil, errors.Wrap(err, "unable to GET client")
	}

	var client osin.DefaultClient
	if err := decode(clientGob, &client); err != nil {
		return nil, errors.Wrap(err, "failed to decode client gob")
	}
	return &client, nil
}
// UpdateClient replaces the stored client. Because CreateClient uses SET,
// which overwrites, an update is simply a re-create with extra error context.
func (s *RedisStorage) UpdateClient(client osin.Client) error {
	createErr := s.CreateClient(client)
	return errors.Wrap(createErr, "failed to update client")
}
// DeleteClient removes the entry stored under "<prefix>:client:<id>".
// Deleting an id that does not exist is not an error (DEL simply removes 0 keys).
func (s *RedisStorage) DeleteClient(client osin.Client) error {
	c := s.Pool.Get()
	if connErr := c.Err(); connErr != nil {
		return connErr
	}
	defer c.Close()

	key := s.makeKey("client", client.GetId())
	_, delErr := c.Do("DEL", key)
	return errors.Wrap(delErr, "failed to delete client")
}
// SaveAuthorize gob-encodes data and stores it under "<prefix>:auth:<code>"
// with SETEX, so Redis expires it after data.ExpiresIn seconds.
func (s *RedisStorage) SaveAuthorize(data *osin.AuthorizeData) (err error) {
	c := s.Pool.Get()
	if err = c.Err(); err != nil {
		return err
	}
	defer c.Close()

	var payload []byte
	if payload, err = encode(data); err != nil {
		return errors.Wrap(err, "failed to encode data")
	}

	key := s.makeKey("auth", data.Code)
	_, err = c.Do("SETEX", key, data.ExpiresIn, string(payload))
	return errors.Wrap(err, "failed to set auth")
}
// LoadAuthorize looks up AuthorizeData by authorization code.
//
// Client information MUST be loaded together. Here it comes back from the
// gob blob written by SaveAuthorize, which serializes the whole record.
// Codes stored with SETEX disappear once their TTL elapses, so an expired
// code is reported as not found.
//
// It returns (nil, nil) when no data exists for code. On any failure it
// returns nil data together with the error; it never returns partially
// decoded data.
func (s *RedisStorage) LoadAuthorize(code string) (*osin.AuthorizeData, error) {
	conn := s.Pool.Get()
	if err := conn.Err(); err != nil {
		return nil, err
	}
	defer conn.Close()

	authGob, err := redis.Bytes(conn.Do("GET", s.makeKey("auth", code)))
	if err == redis.ErrNil {
		return nil, nil
	}
	if err != nil {
		return nil, errors.Wrap(err, "unable to GET auth")
	}

	var auth osin.AuthorizeData
	if err := decode(authGob, &auth); err != nil {
		return nil, errors.Wrap(err, "failed to decode auth")
	}
	return &auth, nil
}
// RemoveAuthorize deletes the authorization data stored under
// "<prefix>:auth:<code>". Removing an unknown code is not an error.
func (s *RedisStorage) RemoveAuthorize(code string) (err error) {
	c := s.Pool.Get()
	if err = c.Err(); err != nil {
		return err
	}
	defer c.Close()

	key := s.makeKey("auth", code)
	if _, err = c.Do("DEL", key); err != nil {
		return errors.Wrap(err, "failed to delete auth")
	}
	return nil
}
// SaveAccess persists AccessData and indexes it by its tokens.
//
// The gob-encoded record is stored under "<prefix>:access:<uuid>" using a
// freshly generated v4 UUID. "<prefix>:access_token:<token>" and, only when
// a refresh token was issued, "<prefix>:refresh_token:<token>" map to that
// UUID. Every key is written with SETEX and data.ExpiresIn as its TTL.
//
// NOTE(review): the refresh-token index shares the access token's TTL, so
// refresh tokens cannot outlive their access token — confirm that is
// intended. The three writes are also not atomic.
func (s *RedisStorage) SaveAccess(data *osin.AccessData) (err error) {
	conn := s.Pool.Get()
	if err := conn.Err(); err != nil {
		return err
	}
	defer conn.Close()

	payload, err := encode(data)
	if err != nil {
		return errors.Wrap(err, "failed to encode access")
	}

	accessID := uuid.NewV4().String()
	if _, err := conn.Do("SETEX", s.makeKey("access", accessID), data.ExpiresIn, payload); err != nil {
		return errors.Wrap(err, "failed to save access")
	}
	if _, err := conn.Do("SETEX", s.makeKey("access_token", data.AccessToken), data.ExpiresIn, accessID); err != nil {
		return errors.Wrap(err, "failed to register access token")
	}

	// Without this guard, every access lacking a refresh token would write
	// the same "<prefix>:refresh_token:" key, each one clobbering the last.
	if data.RefreshToken == "" {
		return nil
	}
	_, err = conn.Do("SETEX", s.makeKey("refresh_token", data.RefreshToken), data.ExpiresIn, accessID)
	return errors.Wrap(err, "failed to register refresh token")
}
// LoadAccess resolves an access token to its AccessData through the
// "<prefix>:access_token:<token>" index.
func (s *RedisStorage) LoadAccess(token string) (*osin.AccessData, error) {
	indexKey := s.makeKey("access_token", token)
	return s.loadAccessByKey(indexKey)
}
// RemoveAccess deletes the AccessData reachable through the
// "<prefix>:access_token:<token>" index, along with its token index entries.
func (s *RedisStorage) RemoveAccess(token string) error {
	indexKey := s.makeKey("access_token", token)
	return s.removeAccessByKey(indexKey)
}
// LoadRefresh resolves a refresh token to its AccessData through the
// "<prefix>:refresh_token:<token>" index.
func (s *RedisStorage) LoadRefresh(token string) (*osin.AccessData, error) {
	indexKey := s.makeKey("refresh_token", token)
	return s.loadAccessByKey(indexKey)
}
// RemoveRefresh deletes the AccessData reachable through the
// "<prefix>:refresh_token:<token>" index, along with its token index entries.
func (s *RedisStorage) RemoveRefresh(token string) error {
	indexKey := s.makeKey("refresh_token", token)
	return s.removeAccessByKey(indexKey)
}
// removeAccessByKey deletes the access record reached through the token
// index key (an access_token or refresh_token key), together with both of
// its token index entries.
//
// Removing a token that does not exist, or has already expired, is a no-op.
// If the index key still exists but its record is gone, the stale index key
// is cleaned up.
//
// NOTE(review): this holds one pooled connection while loadAccessByKey
// borrows another, so a pool capped at one active connection with Wait set
// would block here.
func (s *RedisStorage) removeAccessByKey(key string) error {
	conn := s.Pool.Get()
	if err := conn.Err(); err != nil {
		return err
	}
	defer conn.Close()

	accessID, err := redis.String(conn.Do("GET", key))
	if err == redis.ErrNil {
		// Index key missing (never existed or TTL elapsed): nothing to remove.
		return nil
	}
	if err != nil {
		return errors.Wrap(err, "failed to get access")
	}

	access, err := s.loadAccessByKey(key)
	if err != nil {
		return errors.Wrap(err, "unable to load access for removal")
	}

	// Always drop the record and the index key we were given. The sibling
	// token keys are only known once the record has been decoded.
	keys := []interface{}{s.makeKey("access", accessID), key}
	if access != nil {
		keys = append(keys, s.makeKey("access_token", access.AccessToken))
		if access.RefreshToken != "" {
			keys = append(keys, s.makeKey("refresh_token", access.RefreshToken))
		}
	}

	// A single multi-key DEL removes everything atomically in one round trip.
	_, err = conn.Do("DEL", keys...)
	return errors.Wrap(err, "failed to delete access")
}
// loadAccessByKey resolves a token index key (an access_token or
// refresh_token key holding an access UUID) to the stored AccessData.
//
// It returns (nil, nil) when the index key, or the access record it points
// to, no longer exists. The embedded clients are re-read through GetClient,
// so callers see the current client records rather than the copies
// serialized at save time.
func (s *RedisStorage) loadAccessByKey(key string) (*osin.AccessData, error) {
	conn := s.Pool.Get()
	if err := conn.Err(); err != nil {
		return nil, err
	}
	defer conn.Close()

	// Read the index once. Two separate GETs could straddle the key's expiry
	// and turn "not found" into a spurious error.
	accessID, err := redis.String(conn.Do("GET", key))
	if err == redis.ErrNil {
		return nil, nil
	}
	if err != nil {
		return nil, errors.Wrap(err, "unable to get access ID")
	}

	accessGob, err := redis.Bytes(conn.Do("GET", s.makeKey("access", accessID)))
	if err == redis.ErrNil {
		// The index outlived its record (e.g. a partial delete): treat as absent.
		return nil, nil
	}
	if err != nil {
		return nil, errors.Wrap(err, "unable to get access gob")
	}

	var access osin.AccessData
	if err := decode(accessGob, &access); err != nil {
		return nil, errors.Wrap(err, "failed to decode access gob")
	}

	// ExpiresIn is deliberately left as saved. osin derives expiry as
	// CreatedAt + ExpiresIn, so replacing it with the remaining Redis TTL
	// would count the elapsed time twice and expire tokens early.

	if access.Client != nil {
		if access.Client, err = s.GetClient(access.Client.GetId()); err != nil {
			return nil, errors.Wrap(err, "unable to get client for access")
		}
	}
	if access.AuthorizeData != nil && access.AuthorizeData.Client != nil {
		if access.AuthorizeData.Client, err = s.GetClient(access.AuthorizeData.Client.GetId()); err != nil {
			return nil, errors.Wrap(err, "unable to get client for access authorize data")
		}
	}
	return &access, nil
}
// makeKey builds the Redis key "<KeyPrefix>:<namespace>:<id>". fmt.Sprint
// inserts no separators between string operands, so the colons are the
// only delimiters.
func (s *RedisStorage) makeKey(namespace, id string) string {
	return fmt.Sprint(s.KeyPrefix, ":", namespace, ":", id)
}
// encode serializes v with encoding/gob and returns the resulting bytes.
// Interface-typed fields inside v must have had their concrete types
// registered with gob.Register beforehand.
func encode(v interface{}) ([]byte, error) {
	out := new(bytes.Buffer)
	enc := gob.NewEncoder(out)
	if encErr := enc.Encode(v); encErr != nil {
		return nil, errors.Wrap(encErr, "unable to encode")
	}
	return out.Bytes(), nil
}
// decode deserializes gob-encoded data into v, which must be a pointer.
// It returns nil on success; errors.Wrap passes a nil error through unchanged.
func decode(data []byte, v interface{}) error {
	dec := gob.NewDecoder(bytes.NewReader(data))
	decErr := dec.Decode(v)
	return errors.Wrap(decErr, "unable to decode")
}