package RedisStorage import ( "bytes" "encoding/gob" "fmt" "github.com/RangelReale/osin" "github.com/garyburd/redigo/redis" "github.com/pkg/errors" "github.com/satori/go.uuid" ) func init() { gob.Register(map[string]interface{}{}) gob.Register(&osin.DefaultClient{}) gob.Register(osin.AuthorizeData{}) gob.Register(osin.AccessData{}) } var OsinServer *osin.Server var OAuth2RedisStorage *RedisStorage // RedisStorage implements "github.com/RangelReale/osin".RedisStorage type RedisStorage struct { Pool *redis.Pool KeyPrefix string } // Clone the storage if needed. For example, using mgo, you can clone the session with session.Clone // to avoid concurrent access problems. // This is to avoid cloning the connection at each method access. // Can return itself if not a problem. func (s *RedisStorage) Clone() osin.Storage { return s } // Close the resources the RedisStorage potentially holds (using Clone for example) func (s *RedisStorage) Close() {} // CreateClient inserts a new client func (s *RedisStorage) CreateClient(client osin.Client) error { conn := s.Pool.Get() if err := conn.Err(); err != nil { return err } defer conn.Close() payload, err := encode(client) if err != nil { return errors.Wrap(err, "failed to encode client") } _, err = conn.Do("SET", s.makeKey("client", client.GetId()), payload) return errors.Wrap(err, "failed to save client") } // GetClient gets a client by ID func (s *RedisStorage) GetClient(id string) (osin.Client, error) { conn := s.Pool.Get() if err := conn.Err(); err != nil { return nil, err } defer conn.Close() var ( rawClientGob interface{} err error ) if rawClientGob, err = conn.Do("GET", s.makeKey("client", id)); err != nil { return nil, errors.Wrap(err, "unable to GET client") } if rawClientGob == nil { return nil, nil } clientGob, _ := redis.Bytes(rawClientGob, err) var client osin.DefaultClient err = decode(clientGob, &client) return &client, errors.Wrap(err, "failed to decode client gob") } // UpdateClient updates a client func (s *RedisStorage) UpdateClient(client osin.Client) error { return errors.Wrap(s.CreateClient(client), "failed to update client") } // DeleteClient deletes given client func (s *RedisStorage) DeleteClient(client osin.Client) error { conn := s.Pool.Get() if err := conn.Err(); err != nil { return err } defer conn.Close() _, err := conn.Do("DEL", s.makeKey("client", client.GetId())) return errors.Wrap(err, "failed to delete client") } // SaveAuthorize saves authorize data. func (s *RedisStorage) SaveAuthorize(data *osin.AuthorizeData) (err error) { conn := s.Pool.Get() if err := conn.Err(); err != nil { return err } defer conn.Close() payload, err := encode(data) if err != nil { return errors.Wrap(err, "failed to encode data") } _, err = conn.Do("SETEX", s.makeKey("auth", data.Code), data.ExpiresIn, string(payload)) return errors.Wrap(err, "failed to set auth") } // LoadAuthorize looks up AuthorizeData by a code. // Client information MUST be loaded together. // Optionally can return error if expired. func (s *RedisStorage) LoadAuthorize(code string) (*osin.AuthorizeData, error) { conn := s.Pool.Get() if err := conn.Err(); err != nil { return nil, err } defer conn.Close() var ( rawAuthGob interface{} err error ) if rawAuthGob, err = conn.Do("GET", s.makeKey("auth", code)); err != nil { return nil, errors.Wrap(err, "unable to GET auth") } if rawAuthGob == nil { return nil, nil } authGob, _ := redis.Bytes(rawAuthGob, err) var auth osin.AuthorizeData err = decode(authGob, &auth) return &auth, errors.Wrap(err, "failed to decode auth") } // RemoveAuthorize revokes or deletes the authorization code. func (s *RedisStorage) RemoveAuthorize(code string) (err error) { conn := s.Pool.Get() if err := conn.Err(); err != nil { return err } defer conn.Close() _, err = conn.Do("DEL", s.makeKey("auth", code)) return errors.Wrap(err, "failed to delete auth") } // SaveAccess creates AccessData. func (s *RedisStorage) SaveAccess(data *osin.AccessData) (err error) { conn := s.Pool.Get() if err := conn.Err(); err != nil { return err } defer conn.Close() payload, err := encode(data) if err != nil { return errors.Wrap(err, "failed to encode access") } accessID := uuid.NewV4().String() if _, err := conn.Do("SETEX", s.makeKey("access", accessID), data.ExpiresIn, string(payload)); err != nil { return errors.Wrap(err, "failed to save access") } if _, err := conn.Do("SETEX", s.makeKey("access_token", data.AccessToken), data.ExpiresIn, accessID); err != nil { return errors.Wrap(err, "failed to register access token") } _, err = conn.Do("SETEX", s.makeKey("refresh_token", data.RefreshToken), data.ExpiresIn, accessID) return errors.Wrap(err, "failed to register refresh token") } // LoadAccess gets access data with given access token func (s *RedisStorage) LoadAccess(token string) (*osin.AccessData, error) { return s.loadAccessByKey(s.makeKey("access_token", token)) } // RemoveAccess deletes AccessData with given access token func (s *RedisStorage) RemoveAccess(token string) error { return s.removeAccessByKey(s.makeKey("access_token", token)) } // LoadRefresh gets access data with given refresh token func (s *RedisStorage) LoadRefresh(token string) (*osin.AccessData, error) { return s.loadAccessByKey(s.makeKey("refresh_token", token)) } // RemoveRefresh deletes AccessData with given refresh token func (s *RedisStorage) RemoveRefresh(token string) error { return s.removeAccessByKey(s.makeKey("refresh_token", token)) } func (s *RedisStorage) removeAccessByKey(key string) error { conn := s.Pool.Get() if err := conn.Err(); err != nil { return err } defer conn.Close() accessID, err := redis.String(conn.Do("GET", key)) if err != nil { return errors.Wrap(err, "failed to get access") } access, err := s.loadAccessByKey(key) if err != nil { return errors.Wrap(err, "unable to load access for removal") } if access == nil { return nil } accessKey := s.makeKey("access", accessID) if _, err := conn.Do("DEL", accessKey); err != nil { return errors.Wrap(err, "failed to delete access") } accessTokenKey := s.makeKey("access_token", access.AccessToken) if _, err := conn.Do("DEL", accessTokenKey); err != nil { return errors.Wrap(err, "failed to deregister access_token") } refreshTokenKey := s.makeKey("refresh_token", access.RefreshToken) _, err = conn.Do("DEL", refreshTokenKey) return errors.Wrap(err, "failed to deregister refresh_token") } func (s *RedisStorage) loadAccessByKey(key string) (*osin.AccessData, error) { conn := s.Pool.Get() if err := conn.Err(); err != nil { return nil, err } defer conn.Close() var ( rawAuthGob interface{} err error ) if rawAuthGob, err = conn.Do("GET", key); err != nil { return nil, errors.Wrap(err, "unable to GET auth") } if rawAuthGob == nil { return nil, nil } accessID, err := redis.String(conn.Do("GET", key)) if err != nil { return nil, errors.Wrap(err, "unable to get access ID") } accessIDKey := s.makeKey("access", accessID) accessGob, err := redis.Bytes(conn.Do("GET", accessIDKey)) if err != nil { return nil, errors.Wrap(err, "unable to get access gob") } var access osin.AccessData if err := decode(accessGob, &access); err != nil { return nil, errors.Wrap(err, "failed to decode access gob") } ttl, err := redis.Int(conn.Do("TTL", accessIDKey)) if err != nil { return nil, errors.Wrap(err, "unable to get access TTL") } access.ExpiresIn = int32(ttl) access.Client, err = s.GetClient(access.Client.GetId()) if err != nil { return nil, errors.Wrap(err, "unable to get client for access") } if access.AuthorizeData != nil && access.AuthorizeData.Client != nil { access.AuthorizeData.Client, err = s.GetClient(access.AuthorizeData.Client.GetId()) if err != nil { return nil, errors.Wrap(err, "unable to get client for access authorize data") } } return &access, nil } func (s *RedisStorage) makeKey(namespace, id string) string { return fmt.Sprintf("%s:%s:%s", s.KeyPrefix, namespace, id) } func encode(v interface{}) ([]byte, error) { var buf bytes.Buffer if err := gob.NewEncoder(&buf).Encode(v); err != nil { return nil, errors.Wrap(err, "unable to encode") } return buf.Bytes(), nil } func decode(data []byte, v interface{}) error { err := gob.NewDecoder(bytes.NewBuffer(data)).Decode(v) return errors.Wrap(err, "unable to decode") }