Files
gocommon/session/impl_redis.go

375 lines
8.2 KiB
Go

package session
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"slices"
	"strings"
	"time"

	"github.com/go-redis/redis/v8"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"repositories.action2quare.com/ayo/gocommon"
	"repositories.action2quare.com/ayo/gocommon/logger"
)
// communication_channel_name_prefix is the prefix of the redis pub/sub channel
// used to broadcast session invalidations; the redis DB index and a "_d"
// suffix are appended to it when the channel name is built.
const communication_channel_name_prefix = "_sess_comm_chan_name"
// sessionRedis is a consumer-side cache entry: the decoded session payload
// plus the local deadline after which the cached copy is re-read from redis.
type sessionRedis struct {
	*Authorization           // decoded session hash
	expireAt       time.Time // now + the key's redis TTL at load time (24h if the key has no TTL)
}
// provider_redis creates, queries, refreshes and revokes sessions stored as
// redis hashes keyed by storage key.
type provider_redis struct {
	redisClient   *redis.Client
	deleteChannel string        // pub/sub channel on which revocations are published
	ttl           time.Duration // expiry applied to session keys by New and Touch
	ctx           context.Context
}
// newProviderWithRedis connects to redisUrl and returns a Provider whose
// session keys expire after ttl. All redis calls use ctx.
func newProviderWithRedis(ctx context.Context, redisUrl string, ttl time.Duration) (Provider, error) {
	client, err := gocommon.NewRedisClient(redisUrl)
	if err != nil {
		return nil, err
	}

	// Redis pub/sub channels are global across logical DBs, so the DB index is
	// folded into the channel name to keep revocation traffic per-DB.
	channel := fmt.Sprintf("%s_%d_d", communication_channel_name_prefix, client.Options().DB)

	p := &provider_redis{
		redisClient:   client,
		deleteChannel: channel,
		ttl:           ttl,
		ctx:           ctx,
	}
	return p, nil
}
// New creates a session for input.Account and returns its public key.
//
// All existing sessions of the account are revoked first (RevokeAll with
// infinite=false). The new key is regenerated until it differs from every
// key that was just revoked. The session hash and its TTL are written in a
// single MULTI/EXEC transaction, so a failure part-way through can never
// leave a session key without an expiry.
func (p *provider_redis) New(input *Authorization) (string, error) {
	revoked, err := p.RevokeAll(input.Account, false)
	if err != nil {
		return "", err
	}

	// make_storagekey yields a different key per call; retry until it does not
	// collide with a key just revoked (presumably because consumers tombstone
	// revoked keys and would reject a reused one).
	var newsk storagekey
	for {
		newsk = make_storagekey(input.Account)
		if !slices.Contains(revoked, string(newsk)) {
			break
		}
	}

	key := string(newsk)
	if _, err := p.redisClient.TxPipelined(p.ctx, func(pipe redis.Pipeliner) error {
		pipe.HSet(p.ctx, key, input.ToStrings())
		pipe.Expire(p.ctx, key, p.ttl)
		return nil
	}); err != nil {
		return "", err
	}

	return string(storagekey_to_publickey(newsk)), nil
}
// RevokeAll deletes every session key belonging to account and publishes an
// InvalidatedSession on the delete channel so consumers can drop cached
// copies. It returns the deleted storage keys (possibly empty). infinite is
// forwarded unchanged in the published message.
//
// A failed key lookup or DEL is logged and returned. Once the keys are
// deleted, marshal/publish failures are only logged so that a completed
// deletion is not reported to the caller as a failure.
func (p *provider_redis) RevokeAll(account primitive.ObjectID, infinite bool) ([]string, error) {
	// ObjectID.Hex() is plain hex digits, so it needs no glob escaping.
	sks, err := p.scanKeysWithPrefix(account.Hex())
	if err != nil {
		logger.Println("session provider delete :", sks, err)
		return nil, err
	}
	if len(sks) == 0 {
		return sks, nil
	}

	if err := p.redisClient.Del(p.ctx, sks...).Err(); err != nil {
		logger.Println("session provider delete :", sks, err)
		return nil, err
	}

	data, err := json.Marshal(InvalidatedSession{
		SessionKeys: sks,
		Account:     account,
		Infinite:    infinite,
	})
	if err != nil {
		logger.Println("session provider delete marshal :", sks, err)
		return sks, nil
	}
	if err := p.redisClient.Publish(p.ctx, p.deleteChannel, string(data)).Err(); err != nil {
		logger.Println("session provider delete publish :", sks, err)
	}
	return sks, nil
}

// scanKeysWithPrefix returns every key starting with prefix. It uses SCAN
// rather than KEYS so the server is not blocked by a full-keyspace walk;
// SCAN may return a key more than once, so results are de-duplicated.
func (p *provider_redis) scanKeysWithPrefix(prefix string) ([]string, error) {
	keys := make([]string, 0)
	seen := make(map[string]struct{})
	iter := p.redisClient.Scan(p.ctx, 0, prefix+"*", 100).Iterator()
	for iter.Next(p.ctx) {
		key := iter.Val()
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		keys = append(keys, key)
	}
	if err := iter.Err(); err != nil {
		return nil, err
	}
	return keys, nil
}
// Query loads the session identified by the public key pk.
//
// A missing session yields a zero Authorization and a nil error. Redis
// failures are logged and returned. HGetAll reports a missing key as an empty
// map (not redis.Nil), so emptiness is what signals "not found"; the empty
// map is not decoded into an Authorization.
func (p *provider_redis) Query(pk string) (Authorization, error) {
	sk := publickey_to_storagekey(publickey(pk))
	src, err := p.redisClient.HGetAll(p.ctx, string(sk)).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		logger.Println("session provider query :", pk, err)
		return Authorization{}, err
	}
	if len(src) == 0 {
		logger.Println("session provider query(not found) :", pk)
		return Authorization{}, nil
	}
	return MakeAuthrizationFromStringMap(src), nil
}
// Touch resets the TTL of the session identified by pk to p.ttl.
// It returns the EXPIRE result (false when the key no longer exists) and any
// redis error, which is also logged.
func (p *provider_redis) Touch(pk string) (bool, error) {
	key := string(publickey_to_storagekey(publickey(pk)))
	extended, err := p.redisClient.Expire(p.ctx, key, p.ttl).Result()
	switch {
	case err == redis.Nil:
		// Already expired.
		logger.Println("session provider touch :", pk, err)
		return false, nil
	case err != nil:
		logger.Println("session provider touch :", pk, err)
		return false, err
	default:
		return extended, nil
	}
}
// consumer_redis validates sessions against redis. It caches them locally in
// the two cache stages of consumer_common (switched periodically by
// changeStage) and listens on deleteChannel for invalidations.
type consumer_redis struct {
	consumer_common[*sessionRedis]
	redisClient   *redis.Client
	deleteChannel string // pub/sub channel carrying InvalidatedSession messages
}
// newConsumerWithRedis connects to redisUrl and returns a Consumer that caches
// sessions and drops them when an InvalidatedSession is published on the
// per-DB delete channel.
//
// The SUBSCRIBE is confirmed before returning, so a failed subscription is
// reported as an error instead of silently disabling invalidation. A
// background goroutine runs until ctx is cancelled or the subscription
// channel closes. It calls changeStage every ttl, applies incoming
// invalidations and fires the registered callbacks. It closes the
// subscription on exit so the pub/sub connection is not leaked.
func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duration) (Consumer, error) {
	redisClient, err := gocommon.NewRedisClient(redisUrl)
	if err != nil {
		return nil, err
	}

	deleteChannel := fmt.Sprintf("%s_%d_d", communication_channel_name_prefix, redisClient.Options().DB)
	sub := redisClient.Subscribe(ctx, deleteChannel)
	if _, err := sub.Receive(ctx); err != nil {
		_ = sub.Close()
		_ = redisClient.Close()
		return nil, err
	}

	consumer := &consumer_redis{
		consumer_common: consumer_common[*sessionRedis]{
			ttl:       ttl,
			ctx:       ctx,
			stages:    [2]*cache_stage[*sessionRedis]{make_cache_stage[*sessionRedis](), make_cache_stage[*sessionRedis]()},
			startTime: time.Now(),
		},
		redisClient:   redisClient,
		deleteChannel: deleteChannel,
	}

	go func() {
		defer sub.Close()

		// Receive may not be used once Channel has been created, so the channel
		// is obtained only after the confirmation above, and only once.
		msgs := sub.Channel()

		// Stage switches are scheduled against an absolute deadline so time
		// spent handling messages does not accumulate as drift.
		stageSwitch := time.Now().Add(ttl)
		timer := time.NewTimer(ttl)
		defer timer.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-timer.C:
				consumer.changeStage()
				stageSwitch = stageSwitch.Add(ttl)
				timer.Reset(time.Until(stageSwitch))
			case msg, ok := <-msgs:
				if !ok || msg == nil {
					return
				}
				if len(msg.Payload) == 0 || msg.Channel != deleteChannel {
					continue
				}
				var invsess InvalidatedSession
				if err := json.Unmarshal([]byte(msg.Payload), &invsess); err != nil {
					logger.Println("redis consumer deleteChannel unmarshal failed :", err)
					continue
				}
				for _, sk := range invsess.SessionKeys {
					consumer.delete(storagekey(sk))
				}
				for _, f := range consumer.onSessionInvalidated {
					f(invsess)
				}
			}
		}
	}()

	return consumer, nil
}
// query_internal resolves sk to a session, consulting the local stages first
// and falling back to redis. It reads and writes the stage maps without
// locking; Query and Touch call it with c.lock held.
//
// It returns, in order of precedence:
//   - the tombstoned entry if sk is in either stage's deleted set;
//   - the cached entry if its local expireAt has not passed;
//   - otherwise a fresh entry loaded from redis (cached if auth.Valid(),
//     tombstoned in stage 0 if not);
//   - (nil, nil) if the key does not exist in redis, or (nil, err) on error.
func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) {
	if old, deleted := c.stages[0].deleted[sk]; deleted {
		return old, nil
	}
	if old, deleted := c.stages[1].deleted[sk]; deleted {
		return old, nil
	}

	found, ok := c.stages[0].cache[sk]
	if !ok {
		found, ok = c.stages[1].cache[sk]
	}
	if ok && time.Now().Before(found.expireAt) {
		// The session has not expired yet.
		return found, nil
	}
	// Not cached, or the local deadline passed. Another consumer may have
	// touched the key, so read it back from redis.

	key := string(sk)
	var (
		fields *redis.StringStringMapCmd
		ttlCmd *redis.DurationCmd
	)
	// Read the hash and its TTL atomically. As two separate round trips, the
	// key could expire or be deleted in between. Its TTL of -2 ("no such key")
	// would then fall into the ttl < 0 branch below and be cached for 24 hours.
	_, err := c.redisClient.TxPipelined(c.ctx, func(pipe redis.Pipeliner) error {
		fields = pipe.HGetAll(c.ctx, key)
		ttlCmd = pipe.TTL(c.ctx, key)
		return nil
	})
	if err != nil && err != redis.Nil {
		logger.Println("consumer Query :", err)
		return nil, err
	}

	payload := fields.Val()
	if len(payload) == 0 {
		return nil, nil
	}

	ttl := ttlCmd.Val()
	if ttl < 0 {
		// The key exists but has no expiry; use a fixed local lifetime.
		ttl = 24 * time.Hour
	}

	auth := MakeAuthrizationFromStringMap(payload)
	si := &sessionRedis{
		Authorization: &auth,
		expireAt:      time.Now().Add(ttl),
	}
	if auth.Valid() {
		c.add_internal(sk, si)
	} else {
		c.stages[0].deleted[sk] = si
	}
	return si, nil
}
// Sentinel errors for revoked and expired sessions.
// NOTE(review): neither is referenced in this file; confirm they are used
// elsewhere in the package.
var (
	errRevoked = errors.New("session revoked")
	errExpired = errors.New("session expired")
)
// Query returns the session for the public key pk, or a zero Authorization if
// it is tombstoned, missing, expired locally, or redis fails (failures other
// than tombstoning are logged).
func (c *consumer_redis) Query(pk string) Authorization {
	c.lock.Lock()
	defer c.lock.Unlock()

	sk := publickey_to_storagekey(publickey(pk))
	for _, stage := range c.stages {
		if _, revoked := stage.deleted[sk]; revoked {
			return Authorization{}
		}
	}

	si, err := c.query_internal(sk)
	switch {
	case err != nil:
		logger.Println("session consumer query :", pk, err)
		return Authorization{}
	case si == nil:
		logger.Println("session consumer query(si nil) :", pk, nil)
		return Authorization{}
	case time.Now().After(si.expireAt):
		logger.Println("session consumer query(expired):", pk, nil)
		return Authorization{}
	}
	return *si.Authorization
}
// Touch resets the redis TTL of the session pk to c.ttl and returns its
// authorization. A tombstoned or already-gone session yields a zero
// Authorization and nil error; redis failures are logged and returned.
func (c *consumer_redis) Touch(pk string) (Authorization, error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	sk := publickey_to_storagekey(publickey(pk))
	for _, stage := range c.stages {
		if _, revoked := stage.deleted[sk]; revoked {
			return Authorization{}, nil
		}
	}

	alive, err := c.redisClient.Expire(c.ctx, string(sk), c.ttl).Result()
	switch {
	case err == redis.Nil:
		logger.Println("session consumer touch :", pk, err)
		return Authorization{}, nil
	case err != nil:
		logger.Println("session consumer touch :", pk, err)
		return Authorization{}, err
	case !alive:
		return Authorization{}, nil
	}

	// The key is still alive in redis; load it or reuse the cached copy.
	si, err := c.query_internal(sk)
	if err != nil {
		logger.Println("session consumer touch(ok) :", pk, err)
		return Authorization{}, err
	}
	if si == nil {
		logger.Println("session consumer touch(ok, si nil) :", pk)
		return Authorization{}, nil
	}
	return *si.Authorization, nil
}
// Revoke deletes the session pk from redis and tombstones any locally cached
// copy in both stages, so later Query/Touch calls on this consumer reject it.
// A failed DEL is logged. The local tombstone is still applied so this
// process stops honoring the session regardless.
//
// NOTE(review): unlike provider RevokeAll, nothing is published on the delete
// channel. Other consumers keep returning their cached copy until its local
// expireAt passes; confirm this is intended.
func (c *consumer_redis) Revoke(pk string) {
	sk := publickey_to_storagekey(publickey(pk))
	if err := c.redisClient.Del(c.ctx, string(sk)).Err(); err != nil {
		logger.Println("session consumer revoke :", pk, err)
	}

	c.lock.Lock()
	defer c.lock.Unlock()
	for _, stage := range c.stages {
		if sr, ok := stage.cache[sk]; ok {
			stage.deleted[sk] = sr
		}
	}
}
// IsRevoked reports whether any session of accid is tombstoned in either
// local cache stage.
//
// Storage keys begin with the account's hex ID (provider RevokeAll finds an
// account's keys with accid.Hex()+"*"), so tombstoned keys are matched by that
// prefix. Looking up make_storagekey(accid) cannot work: it produces a fresh
// key on every call (see the retry loop in provider New).
func (c *consumer_redis) IsRevoked(accid primitive.ObjectID) bool {
	prefix := accid.Hex()

	c.lock.Lock()
	defer c.lock.Unlock()
	for _, stage := range c.stages {
		for sk := range stage.deleted {
			if strings.HasPrefix(string(sk), prefix) {
				return true
			}
		}
	}
	return false
}
// RegisterOnSessionInvalidated appends cb to the callbacks that the subscriber
// goroutine invokes, in registration order, for every InvalidatedSession
// message it successfully decodes from the delete channel.
//
// NOTE(review): the append is not guarded by c.lock while the subscriber
// goroutine ranges over the same slice. Register callbacks before
// invalidations can arrive, or add synchronization.
func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(InvalidatedSession)) {
	callbacks := append(c.onSessionInvalidated, cb)
	c.onSessionInvalidated = callbacks
}