package session

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"slices"
	"strings"
	"time"

	"github.com/go-redis/redis/v8"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"repositories.action2quare.com/ayo/gocommon"
	"repositories.action2quare.com/ayo/gocommon/logger"
)

const (
	// communication_channel_name_prefix prefixes the pub/sub channel used to
	// broadcast session invalidations (see deleteChannelName).
	communication_channel_name_prefix = "_sess_comm_chan_name"
)

// deleteChannelName returns the pub/sub channel on which session
// invalidations are published for the redis DB that client is bound to.
// Providers and consumers derive the name the same way so they meet on the
// same channel.
func deleteChannelName(client *redis.Client) string {
	return fmt.Sprintf("%s_%d_d", communication_channel_name_prefix, client.Options().DB)
}

// sessionRedis is a locally cached session: the authorization read from
// redis plus the local instant after which the copy must be re-read.
type sessionRedis struct {
	*Authorization
	expireAt time.Time
}

// provider_redis creates, queries, touches and revokes sessions stored as
// redis hashes that expire after ttl.
type provider_redis struct {
	redisClient   *redis.Client
	deleteChannel string
	ttl           time.Duration
	ctx           context.Context
}

// newProviderWithRedis connects to redisUrl and returns a Provider whose
// sessions expire after ttl. All redis calls are issued with ctx.
func newProviderWithRedis(ctx context.Context, redisUrl string, ttl time.Duration) (Provider, error) {
	redisClient, err := gocommon.NewRedisClient(redisUrl)
	if err != nil {
		return nil, err
	}

	return &provider_redis{
		redisClient:   redisClient,
		deleteChannel: deleteChannelName(redisClient),
		ttl:           ttl,
		ctx:           ctx,
	}, nil
}

// New revokes every existing session of input.Account, stores input under a
// fresh storage key with the provider's TTL and returns the matching public
// key.
func (p *provider_redis) New(input *Authorization) (string, error) {
	sks, err := p.RevokeAll(input.Account, false)
	if err != nil {
		return "", err
	}

	// Regenerate until the key collides with none of the keys just revoked,
	// so a revoked key is never handed out again.
	var newsk storagekey
	for {
		newsk = make_storagekey(input.Account)
		if !slices.Contains(sks, string(newsk)) {
			break
		}
	}

	// HSET and EXPIRE run in a single MULTI/EXEC so a failure between them
	// cannot leave a session hash without an expiry.
	_, err = p.redisClient.TxPipelined(p.ctx, func(pipe redis.Pipeliner) error {
		pipe.HSet(p.ctx, string(newsk), input.ToStrings())
		pipe.Expire(p.ctx, string(newsk), p.ttl)
		return nil
	})
	if err != nil {
		return "", err
	}

	return string(storagekey_to_publickey(newsk)), nil
}

// RevokeAll deletes every session key of account from redis and publishes an
// InvalidatedSession on the delete channel so consumers drop cached copies.
// It returns the deleted storage keys (empty when there were none).
//
// A failure to delete is returned. A failure to marshal or publish the
// invalidation is only logged, because the sessions are already gone from
// redis at that point.
func (p *provider_redis) RevokeAll(account primitive.ObjectID, infinite bool) ([]string, error) {
	// The match pattern assumes storage keys of an account begin with the
	// account's hex id.
	prefix := account.Hex()
	sks, err := p.scanKeys(prefix + "*")
	if err != nil {
		logger.Println("session provider delete :", prefix, err)
		return nil, err
	}
	if len(sks) == 0 {
		return sks, nil
	}

	if err := p.redisClient.Del(p.ctx, sks...).Err(); err != nil {
		logger.Println("session provider delete :", sks, err)
		return nil, err
	}

	invsess := InvalidatedSession{
		SessionKeys: sks,
		Account:     account,
		Infinite:    infinite,
	}
	data, err := json.Marshal(invsess)
	if err != nil {
		logger.Println("session provider delete marshal :", sks, err)
		return sks, nil
	}
	if err := p.redisClient.Publish(p.ctx, p.deleteChannel, string(data)).Err(); err != nil {
		logger.Println("session provider delete publish :", sks, err)
	}
	return sks, nil
}

// scanKeys returns every key matching pattern. It uses SCAN rather than KEYS
// so the redis server is not blocked while the keyspace is walked. SCAN may
// report a key more than once, so the result is de-duplicated.
func (p *provider_redis) scanKeys(pattern string) ([]string, error) {
	const scanBatch = 1000

	var keys []string
	seen := make(map[string]struct{})
	iter := p.redisClient.Scan(p.ctx, 0, pattern, scanBatch).Iterator()
	for iter.Next(p.ctx) {
		k := iter.Val()
		if _, dup := seen[k]; dup {
			continue
		}
		seen[k] = struct{}{}
		keys = append(keys, k)
	}
	if err := iter.Err(); err != nil {
		return nil, err
	}
	return keys, nil
}

// Query returns the authorization stored for public key pk. If the session
// does not exist (expired or revoked), it returns a zero Authorization and a
// nil error.
func (p *provider_redis) Query(pk string) (Authorization, error) {
	sk := publickey_to_storagekey(publickey(pk))
	src, err := p.redisClient.HGetAll(p.ctx, string(sk)).Result()
	if err != nil && !errors.Is(err, redis.Nil) {
		logger.Println("session provider query :", pk, err)
		return Authorization{}, err
	}
	if len(src) == 0 {
		// HGETALL on a missing key yields an empty map rather than redis.Nil.
		return Authorization{}, nil
	}
	return MakeAuthrizationFromStringMap(src), nil
}

// Touch resets the TTL of the session identified by pk to the provider's ttl.
// It reports whether the session key still existed in redis.
func (p *provider_redis) Touch(pk string) (bool, error) {
	sk := publickey_to_storagekey(publickey(pk))
	ok, err := p.redisClient.Expire(p.ctx, string(sk), p.ttl).Result()
	if errors.Is(err, redis.Nil) {
		// Already expired.
		logger.Println("session provider touch :", pk, err)
		return false, nil
	}
	if err != nil {
		logger.Println("session provider touch :", pk, err)
		return false, err
	}
	return ok, nil
}

// consumer_redis validates sessions against redis. It caches them locally in
// two cache stages, which are rotated every ttl, and listens on the delete
// channel for invalidations published by a provider.
type consumer_redis struct {
	consumer_common[*sessionRedis]
	redisClient   *redis.Client
	deleteChannel string
}

// newConsumerWithRedis connects to redisUrl, subscribes to the delete channel
// and starts a goroutine that rotates cache stages and applies invalidations
// until ctx is cancelled.
func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duration) (Consumer, error) {
	redisClient, err := gocommon.NewRedisClient(redisUrl)
	if err != nil {
		return nil, err
	}

	deleteChannel := deleteChannelName(redisClient)
	sub := redisClient.Subscribe(ctx, deleteChannel)
	// Wait for the subscription to be confirmed so a broken connection is
	// reported here instead of silently losing invalidation messages.
	if _, err := sub.Receive(ctx); err != nil {
		_ = sub.Close()
		_ = redisClient.Close()
		return nil, err
	}

	consumer := &consumer_redis{
		consumer_common: consumer_common[*sessionRedis]{
			ttl:       ttl,
			ctx:       ctx,
			stages:    [2]*cache_stage[*sessionRedis]{make_cache_stage[*sessionRedis](), make_cache_stage[*sessionRedis]()},
			startTime: time.Now(),
		},
		redisClient:   redisClient,
		deleteChannel: deleteChannel,
	}

	go consumer.runInvalidationLoop(ctx, sub)

	return consumer, nil
}

// runInvalidationLoop rotates the cache stages every c.ttl and applies the
// invalidations received on sub. It stops when ctx is cancelled or the
// subscription channel is closed, and closes sub on exit.
func (c *consumer_redis) runInvalidationLoop(ctx context.Context, sub *redis.PubSub) {
	defer sub.Close()

	msgCh := sub.Channel()
	// Stage switches are scheduled against absolute deadlines so the period
	// does not drift by the time spent handling messages.
	stageswitch := time.Now().Add(c.ttl)
	tickTimer := time.After(c.ttl)
	for {
		select {
		case <-ctx.Done():
			return
		case <-tickTimer:
			c.changeStage()
			stageswitch = stageswitch.Add(c.ttl)
			tickTimer = time.After(time.Until(stageswitch))
		case msg, ok := <-msgCh:
			if !ok || msg == nil {
				return
			}
			if len(msg.Payload) == 0 || msg.Channel != c.deleteChannel {
				continue
			}
			c.applyInvalidation(msg.Payload)
		}
	}
}

// applyInvalidation decodes an InvalidatedSession payload, drops the listed
// session keys from the local cache and notifies registered callbacks.
func (c *consumer_redis) applyInvalidation(payload string) {
	var invsess InvalidatedSession
	if err := json.Unmarshal([]byte(payload), &invsess); err != nil {
		logger.Println("redis consumer deleteChannel unmarshal failed :", err)
		return
	}
	for _, sk := range invsess.SessionKeys {
		c.delete(storagekey(sk))
	}

	// Snapshot the callbacks under the lock (RegisterOnSessionInvalidated
	// may append concurrently). Invoke them without holding it so they can
	// call back into the consumer.
	c.lock.Lock()
	callbacks := slices.Clone(c.onSessionInvalidated)
	c.lock.Unlock()
	for _, f := range callbacks {
		f(invsess)
	}
}

// revokedLocally reports whether sk is in the deleted set of either cache
// stage. The caller must hold c.lock.
func (c *consumer_redis) revokedLocally(sk storagekey) bool {
	for _, stage := range c.stages {
		if _, deleted := stage.deleted[sk]; deleted {
			return true
		}
	}
	return false
}

// query_internal resolves sk using the local cache first, falling back to
// redis. It returns (nil, nil) when the session does not exist in redis. The
// caller must hold c.lock.
func (c *consumer_redis) query_internal(sk storagekey) (*sessionRedis, error) {
	for _, stage := range c.stages {
		if old, deleted := stage.deleted[sk]; deleted {
			return old, nil
		}
	}

	found, ok := c.stages[0].cache[sk]
	if !ok {
		found, ok = c.stages[1].cache[sk]
	}
	if ok && time.Now().Before(found.expireAt) {
		// Cached session that has not expired yet.
		return found, nil
	}
	// Not cached, or the cached copy is past its local expiry. Another
	// consumer may have touched the session, so re-read it from redis.

	// Read the hash and its TTL atomically: otherwise the key could expire
	// between the two calls and TTL would report a missing key.
	var hgetCmd *redis.StringStringMapCmd
	var ttlCmd *redis.DurationCmd
	_, err := c.redisClient.TxPipelined(c.ctx, func(pipe redis.Pipeliner) error {
		hgetCmd = pipe.HGetAll(c.ctx, string(sk))
		ttlCmd = pipe.TTL(c.ctx, string(sk))
		return nil
	})
	if err != nil && !errors.Is(err, redis.Nil) {
		logger.Println("consumer Query :", err)
		return nil, err
	}

	payload := hgetCmd.Val()
	if len(payload) == 0 {
		return nil, nil
	}

	ttl := ttlCmd.Val()
	if ttl < 0 {
		// The key exists (payload is non-empty) but has no expiry set.
		ttl = 24 * time.Hour
	}

	auth := MakeAuthrizationFromStringMap(payload)
	si := &sessionRedis{
		Authorization: &auth,
		expireAt:      time.Now().Add(ttl),
	}
	if auth.Valid() {
		c.add_internal(sk, si)
	} else {
		c.stages[0].deleted[sk] = si
	}
	return si, nil
}

var errRevoked = errors.New("session revoked")
var errExpired = errors.New("session expired")

// Query returns the authorization for public key pk. It returns a zero
// Authorization when the session is revoked, missing, expired or cannot be
// read.
func (c *consumer_redis) Query(pk string) Authorization {
	c.lock.Lock()
	defer c.lock.Unlock()

	sk := publickey_to_storagekey(publickey(pk))
	if c.revokedLocally(sk) {
		return Authorization{}
	}

	si, err := c.query_internal(sk)
	if err != nil {
		logger.Println("session consumer query :", pk, err)
		return Authorization{}
	}
	if si == nil {
		logger.Println("session consumer query(si nil) :", pk, nil)
		return Authorization{}
	}
	if time.Now().After(si.expireAt) {
		logger.Println("session consumer query(expired):", pk, nil)
		return Authorization{}
	}
	return *si.Authorization
}

// Touch resets the redis TTL of session pk to c.ttl and returns its
// authorization. It returns a zero Authorization (and nil error) when the
// session is revoked locally or no longer exists in redis.
func (c *consumer_redis) Touch(pk string) (Authorization, error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	sk := publickey_to_storagekey(publickey(pk))
	if c.revokedLocally(sk) {
		return Authorization{}, nil
	}

	ok, err := c.redisClient.Expire(c.ctx, string(sk), c.ttl).Result()
	if errors.Is(err, redis.Nil) {
		logger.Println("session consumer touch :", pk, err)
		return Authorization{}, nil
	}
	if err != nil {
		logger.Println("session consumer touch :", pk, err)
		return Authorization{}, err
	}
	if !ok {
		// The key no longer exists in redis.
		return Authorization{}, nil
	}

	// The session is still alive in redis.
	si, err := c.query_internal(sk)
	if err != nil {
		logger.Println("session consumer touch(ok) :", pk, err)
		return Authorization{}, err
	}
	if si == nil {
		logger.Println("session consumer touch(ok, si nil) :", pk)
		return Authorization{}, nil
	}
	// The redis TTL was just reset to c.ttl; extend the cached copy to match
	// so subsequent queries can be answered locally.
	si.expireAt = time.Now().Add(c.ttl)
	return *si.Authorization, nil
}

// Revoke deletes session pk from redis and marks any locally cached copy as
// deleted.
//
// NOTE(review): no invalidation is published here, so other consumers keep
// serving their cached copy until its local expiry.
func (c *consumer_redis) Revoke(pk string) {
	sk := publickey_to_storagekey(publickey(pk))
	if err := c.redisClient.Del(c.ctx, string(sk)).Err(); err != nil {
		logger.Println("session consumer revoke :", pk, err)
	}

	c.lock.Lock()
	defer c.lock.Unlock()

	for _, stage := range c.stages {
		if sr, ok := stage.cache[sk]; ok {
			stage.deleted[sk] = sr
		}
	}
}

// IsRevoked reports whether any session of account accid is marked as deleted
// in this consumer's cache stages.
//
// Storage keys of an account start with the account's hex id (RevokeAll
// matches them with the same prefix), so the deleted sets are scanned by
// prefix. Generating a key with make_storagekey would not work here, because
// it produces a fresh key on every call.
func (c *consumer_redis) IsRevoked(accid primitive.ObjectID) bool {
	prefix := accid.Hex()

	c.lock.Lock()
	defer c.lock.Unlock()

	for _, stage := range c.stages {
		for sk := range stage.deleted {
			if strings.HasPrefix(string(sk), prefix) {
				return true
			}
		}
	}
	return false
}

// RegisterOnSessionInvalidated adds cb to the callbacks invoked for every
// InvalidatedSession received on the delete channel.
func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(InvalidatedSession)) {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.onSessionInvalidated = append(c.onSessionInvalidated, cb)
}