376 lines
8.2 KiB
Go
376 lines
8.2 KiB
Go
package session
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"slices"
|
|
"time"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
"repositories.action2quare.com/ayo/gocommon"
|
|
"repositories.action2quare.com/ayo/gocommon/logger"
|
|
)
|
|
|
|
// communication_channel_name_prefix is the common prefix of the redis pub/sub
// channel names over which session revocations are broadcast from providers to
// consumers. The DB number and a "_d" suffix are appended to form the full name.
const communication_channel_name_prefix = "_sess_comm_chan_name"
|
|
|
|
type sessionRedis struct {
|
|
*Authorization
|
|
expireAt time.Time
|
|
}
|
|
|
|
type provider_redis struct {
|
|
redisClient *redis.Client
|
|
deleteChannel string
|
|
ttl time.Duration
|
|
ctx context.Context
|
|
}
|
|
|
|
func newProviderWithRedis(ctx context.Context, redisUrl string, ttl time.Duration) (Provider, error) {
|
|
redisClient, err := gocommon.NewRedisClient(redisUrl)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &provider_redis{
|
|
redisClient: redisClient,
|
|
deleteChannel: fmt.Sprintf("%s_%d_d", communication_channel_name_prefix, redisClient.Options().DB),
|
|
ttl: ttl,
|
|
ctx: ctx,
|
|
}, nil
|
|
}
|
|
|
|
func (p *provider_redis) New(input *Authorization) (string, error) {
|
|
sks, err := p.RevokeAll(input.Account, false)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
var newsk storagekey
|
|
for {
|
|
newsk = make_storagekey(input.Account)
|
|
duplicated := slices.Contains(sks, string(newsk))
|
|
if !duplicated {
|
|
break
|
|
}
|
|
}
|
|
|
|
_, err = p.redisClient.HSet(p.ctx, string(newsk), input.ToStrings()).Result()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
_, err = p.redisClient.Expire(p.ctx, string(newsk), p.ttl).Result()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
pk := storagekey_to_publickey(newsk)
|
|
return string(pk), err
|
|
}
|
|
|
|
func (p *provider_redis) RevokeAll(account primitive.ObjectID, infinite bool) ([]string, error) {
|
|
prefix := account.Hex()
|
|
sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result()
|
|
if err != nil {
|
|
logger.Println("session provider delete :", sks, err)
|
|
return nil, err
|
|
}
|
|
|
|
for _, sk := range sks {
|
|
if infinite {
|
|
p.redisClient.Publish(p.ctx, p.deleteChannel, "~"+sk).Result()
|
|
} else {
|
|
p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result()
|
|
}
|
|
}
|
|
|
|
return sks, nil
|
|
}
|
|
|
|
func (p *provider_redis) Query(pk string) (Authorization, error) {
|
|
sk := publickey_to_storagekey(publickey(pk))
|
|
src, err := p.redisClient.HGetAll(p.ctx, string(sk)).Result()
|
|
if err == redis.Nil {
|
|
logger.Println("session provider query :", pk, err)
|
|
return Authorization{}, nil
|
|
} else if err != nil {
|
|
logger.Println("session provider query :", pk, err)
|
|
return Authorization{}, err
|
|
}
|
|
|
|
auth := MakeAuthrizationFromStringMap(src)
|
|
return auth, nil
|
|
}
|
|
|
|
func (p *provider_redis) Touch(pk string) (bool, error) {
|
|
sk := publickey_to_storagekey(publickey(pk))
|
|
ok, err := p.redisClient.Expire(p.ctx, string(sk), p.ttl).Result()
|
|
|
|
if err == redis.Nil {
|
|
// 이미 만료됨
|
|
logger.Println("session provider touch :", pk, err)
|
|
return false, nil
|
|
} else if err != nil {
|
|
logger.Println("session provider touch :", pk, err)
|
|
return false, err
|
|
}
|
|
|
|
return ok, nil
|
|
}
|
|
|
|
type consumer_redis struct {
|
|
consumer_common[*sessionRedis]
|
|
redisClient *redis.Client
|
|
deleteChannel string
|
|
}
|
|
|
|
func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duration) (Consumer, error) {
|
|
redisClient, err := gocommon.NewRedisClient(redisUrl)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
deleteChannel := fmt.Sprintf("%s_%d_d", communication_channel_name_prefix, redisClient.Options().DB)
|
|
sub := redisClient.Subscribe(ctx, deleteChannel)
|
|
|
|
consumer := &consumer_redis{
|
|
consumer_common: consumer_common[*sessionRedis]{
|
|
ttl: ttl,
|
|
ctx: ctx,
|
|
stages: [2]*cache_stage[*sessionRedis]{make_cache_stage[*sessionRedis](), make_cache_stage[*sessionRedis]()},
|
|
startTime: time.Now(),
|
|
},
|
|
redisClient: redisClient,
|
|
deleteChannel: deleteChannel,
|
|
}
|
|
|
|
go func() {
|
|
stageswitch := time.Now().Add(ttl)
|
|
tickTimer := time.After(ttl)
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
|
|
case <-tickTimer:
|
|
consumer.changeStage()
|
|
stageswitch = stageswitch.Add(ttl)
|
|
tempttl := time.Until(stageswitch)
|
|
tickTimer = time.After(tempttl)
|
|
|
|
case msg := <-sub.Channel():
|
|
if msg == nil {
|
|
return
|
|
}
|
|
|
|
if len(msg.Payload) == 0 {
|
|
continue
|
|
}
|
|
|
|
switch msg.Channel {
|
|
case deleteChannel:
|
|
infinite := false
|
|
var sk string
|
|
if msg.Payload[0] == '~' {
|
|
sk = msg.Payload[1:]
|
|
infinite = true
|
|
} else {
|
|
sk = msg.Payload
|
|
infinite = false
|
|
}
|
|
old := consumer.delete(storagekey(sk))
|
|
if old != nil {
|
|
invsess := InvalidatedSession{
|
|
Account: old.Account,
|
|
Infinite: infinite,
|
|
}
|
|
|
|
for _, f := range consumer.onSessionInvalidated {
|
|
f(invsess)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return consumer, nil
|
|
}
|
|
|
|
func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) {
|
|
if old, deleted := c.stages[0].deleted[sk]; deleted {
|
|
return old, nil
|
|
}
|
|
|
|
if old, deleted := c.stages[1].deleted[sk]; deleted {
|
|
return old, nil
|
|
}
|
|
|
|
found, ok := c.stages[0].cache[sk]
|
|
if !ok {
|
|
found, ok = c.stages[1].cache[sk]
|
|
}
|
|
|
|
if ok {
|
|
if time.Now().Before(found.expireAt) {
|
|
// 만료전 세션
|
|
return found, nil
|
|
}
|
|
|
|
// 다른 Consumer가 Touch했을 수도 있으므로 redis에서 읽어본다.
|
|
}
|
|
|
|
payload, err := c.redisClient.HGetAll(c.ctx, string(sk)).Result()
|
|
if err != nil && err != redis.Nil {
|
|
logger.Println("consumer Query :", err)
|
|
return nil, err
|
|
}
|
|
|
|
if len(payload) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
ttl, err := c.redisClient.TTL(c.ctx, string(sk)).Result()
|
|
if err != nil {
|
|
logger.Println("consumer Query :", err)
|
|
return nil, err
|
|
}
|
|
if ttl < 0 {
|
|
ttl = time.Duration(time.Hour * 24)
|
|
}
|
|
|
|
auth := MakeAuthrizationFromStringMap(payload)
|
|
si := &sessionRedis{
|
|
Authorization: &auth,
|
|
expireAt: time.Now().Add(ttl),
|
|
}
|
|
|
|
if auth.Valid() {
|
|
c.add_internal(sk, si)
|
|
} else {
|
|
c.stages[0].deleted[sk] = si
|
|
}
|
|
|
|
return si, nil
|
|
}
|
|
|
|
// Sentinel errors for a session that has been revoked or has expired.
var (
	errRevoked = errors.New("session revoked")
	errExpired = errors.New("session expired")
)
|
|
|
|
func (c *consumer_redis) Query(pk string) Authorization {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
sk := publickey_to_storagekey(publickey(pk))
|
|
|
|
if _, deleted := c.stages[0].deleted[sk]; deleted {
|
|
return Authorization{}
|
|
}
|
|
|
|
if _, deleted := c.stages[1].deleted[sk]; deleted {
|
|
return Authorization{}
|
|
}
|
|
|
|
si, err := c.query_internal(sk)
|
|
if err != nil {
|
|
logger.Println("session consumer query :", pk, err)
|
|
return Authorization{}
|
|
}
|
|
|
|
if si == nil {
|
|
logger.Println("session consumer query(si nil) :", pk, nil)
|
|
return Authorization{}
|
|
}
|
|
|
|
if time.Now().After(si.expireAt) {
|
|
logger.Println("session consumer query(expired):", pk, nil)
|
|
return Authorization{}
|
|
}
|
|
|
|
return *si.Authorization
|
|
}
|
|
|
|
func (c *consumer_redis) Touch(pk string) (Authorization, error) {
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
sk := publickey_to_storagekey(publickey(pk))
|
|
|
|
if _, deleted := c.stages[0].deleted[sk]; deleted {
|
|
return Authorization{}, nil
|
|
}
|
|
|
|
if _, deleted := c.stages[1].deleted[sk]; deleted {
|
|
return Authorization{}, nil
|
|
}
|
|
|
|
ok, err := c.redisClient.Expire(c.ctx, string(sk), c.ttl).Result()
|
|
if err == redis.Nil {
|
|
logger.Println("session consumer touch :", pk, err)
|
|
|
|
return Authorization{}, nil
|
|
} else if err != nil {
|
|
logger.Println("session consumer touch :", pk, err)
|
|
return Authorization{}, err
|
|
}
|
|
|
|
if ok {
|
|
// redis에 살아있다.
|
|
si, err := c.query_internal(sk)
|
|
if err != nil {
|
|
logger.Println("session consumer touch(ok) :", pk, err)
|
|
return Authorization{}, err
|
|
}
|
|
|
|
if si == nil {
|
|
logger.Println("session consumer touch(ok, si nil) :", pk)
|
|
return Authorization{}, nil
|
|
}
|
|
|
|
return *si.Authorization, nil
|
|
}
|
|
|
|
return Authorization{}, nil
|
|
}
|
|
|
|
func (c *consumer_redis) Revoke(pk string) {
|
|
sk := publickey_to_storagekey(publickey(pk))
|
|
|
|
c.redisClient.Del(c.ctx, string(sk))
|
|
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
if sr, ok := c.stages[0].cache[sk]; ok {
|
|
c.stages[0].deleted[sk] = sr
|
|
}
|
|
|
|
if sr, ok := c.stages[1].cache[sk]; ok {
|
|
c.stages[1].deleted[sk] = sr
|
|
}
|
|
}
|
|
|
|
func (c *consumer_redis) IsRevoked(accid primitive.ObjectID) bool {
|
|
sk := make_storagekey(accid)
|
|
|
|
c.lock.Lock()
|
|
defer c.lock.Unlock()
|
|
|
|
if _, deleted := c.stages[0].deleted[sk]; deleted {
|
|
return true
|
|
}
|
|
|
|
if _, deleted := c.stages[1].deleted[sk]; deleted {
|
|
return true
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(InvalidatedSession)) {
|
|
c.onSessionInvalidated = append(c.onSessionInvalidated, cb)
|
|
}
|