From 0392966760844b2bb83b879163d195368d0f1b84 Mon Sep 17 00:00:00 2001 From: mountain Date: Thu, 11 Sep 2025 09:38:38 +0900 Subject: [PATCH] =?UTF-8?q?=EC=84=B8=EC=85=98=20invalidate=EB=90=A0=20?= =?UTF-8?q?=EB=95=8C=20=EC=A0=84=EB=8B=AC=ED=95=98=EB=8A=94=20=EC=9D=B8?= =?UTF-8?q?=EC=9E=90=EB=A5=BC=20=EA=B5=AC=EC=A1=B0=EC=B2=B4=EB=A1=9C=20?= =?UTF-8?q?=EB=B3=80=EA=B2=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- session/common.go | 24 +-- session/consumer_common.go | 4 +- session/impl_mongo.go | 383 ------------------------------------ session/impl_redis.go | 56 +++--- session/session_test.go | 2 +- wshandler/wshandler.go | 4 +- wshandler/wshandler_peer.go | 4 +- 7 files changed, 45 insertions(+), 432 deletions(-) delete mode 100644 session/impl_mongo.go diff --git a/session/common.go b/session/common.go index 901a68d..1ce578d 100644 --- a/session/common.go +++ b/session/common.go @@ -14,8 +14,7 @@ import ( ) type Authorization struct { - Account primitive.ObjectID `bson:"a" json:"a"` - invalidated string + Account primitive.ObjectID `bson:"a" json:"a"` // by authorization provider Platform string `bson:"p" json:"p"` @@ -30,13 +29,12 @@ func (auth *Authorization) ToStrings() []string { "p", auth.Platform, "u", auth.Uid, "al", auth.Alias, - "inv", auth.invalidated, "ct", strconv.FormatInt(auth.CreatedTime, 10), } } func (auth *Authorization) Valid() bool { - return len(auth.invalidated) == 0 && !auth.Account.IsZero() + return !auth.Account.IsZero() } func MakeAuthrizationFromStringMap(src map[string]string) Authorization { @@ -47,24 +45,28 @@ func MakeAuthrizationFromStringMap(src map[string]string) Authorization { Platform: src["p"], Uid: src["u"], Alias: src["al"], - invalidated: src["inv"], CreatedTime: ct, } } type Provider interface { New(*Authorization) (string, error) - RevokeAll(primitive.ObjectID) error + RevokeAll(primitive.ObjectID, bool) ([]string, error) Query(string) (Authorization, error) Touch(string) (bool, error) } +type InvalidatedSession struct { + Account primitive.ObjectID + Infinite bool +} + type Consumer interface { Query(string) Authorization Touch(string) (Authorization, error) IsRevoked(primitive.ObjectID) bool Revoke(string) - RegisterOnSessionInvalidated(func(primitive.ObjectID)) + RegisterOnSessionInvalidated(func(InvalidatedSession)) } type storagekey string @@ -120,10 +122,6 @@ var errInvalidScheme = errors.New("storageAddr is not valid scheme") var errSessionStorageMissing = errors.New("session_storageis missing") func NewConsumer(ctx context.Context, storageAddr string, ttl time.Duration) (Consumer, error) { - if strings.HasPrefix(storageAddr, "mongodb") { - return newConsumerWithMongo(ctx, storageAddr, ttl) - } - if strings.HasPrefix(storageAddr, "redis") { return newConsumerWithRedis(ctx, storageAddr, ttl) } @@ -143,10 +141,6 @@ func NewConsumerWithConfig(ctx context.Context, cfg SessionConfig) (Consumer, er } func NewProvider(ctx context.Context, storageAddr string, ttl time.Duration) (Provider, error) { - if strings.HasPrefix(storageAddr, "mongodb") { - return newProviderWithMongo(ctx, storageAddr, ttl) - } - if strings.HasPrefix(storageAddr, "redis") { return newProviderWithRedis(ctx, storageAddr, ttl) } diff --git a/session/consumer_common.go b/session/consumer_common.go index 1932cfe..bbce101 100644 --- a/session/consumer_common.go +++ b/session/consumer_common.go @@ -4,8 +4,6 @@ import ( "context" "sync" "time" - - "go.mongodb.org/mongo-driver/bson/primitive" ) type cache_stage[T any] struct { @@ -26,7 +24,7 @@ type consumer_common[T any] struct { ctx context.Context stages [2]*cache_stage[T] startTime time.Time - onSessionInvalidated []func(primitive.ObjectID) + onSessionInvalidated []func(InvalidatedSession) } func (c *consumer_common[T]) add_internal(sk storagekey, si T) { diff --git a/session/impl_mongo.go b/session/impl_mongo.go deleted file mode 100644 index b6011b3..0000000 --- a/session/impl_mongo.go +++ /dev/null @@ -1,383 +0,0 @@ -package session - -import ( - "context" - "time" - - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" - "repositories.action2quare.com/ayo/gocommon" - "repositories.action2quare.com/ayo/gocommon/logger" -) - -const ( - session_collection_name = gocommon.CollectionName("session") -) - -type provider_mongo struct { - mongoClient gocommon.MongoClient -} - -type sessionMongo struct { - Id primitive.ObjectID `bson:"_id,omitempty"` - Auth *Authorization `bson:"auth"` - Key storagekey `bson:"key"` - Ts primitive.DateTime `bson:"_ts"` -} - -func newProviderWithMongo(ctx context.Context, mongoUrl string, ttl time.Duration) (Provider, error) { - mc, err := gocommon.NewMongoClient(ctx, mongoUrl) - if err != nil { - return nil, err - } - - if err = mc.MakeUniqueIndices(session_collection_name, map[string]bson.D{ - "key": {{Key: "key", Value: 1}}, - }); err != nil { - return nil, err - } - - if err := mc.MakeExpireIndex(session_collection_name, int32(ttl.Seconds())); err != nil { - return nil, err - } - - return &provider_mongo{ - mongoClient: mc, - }, nil -} - -func (p *provider_mongo) New(input *Authorization) (string, error) { - sk := make_storagekey(input.Account) - - _, _, err := p.mongoClient.Update(session_collection_name, bson.M{ - "_id": input.Account, - }, bson.M{ - "$set": sessionMongo{ - Auth: input, - Key: sk, - Ts: primitive.NewDateTimeFromTime(time.Now().UTC()), - }, - }, options.Update().SetUpsert(true)) - - return string(storagekey_to_publickey(sk)), err -} - -func (p *provider_mongo) RevokeAll(acc primitive.ObjectID) error { - _, err := p.mongoClient.Delete(session_collection_name, bson.M{ - "_id": acc, - }) - return err -} - -func (p *provider_mongo) Query(pk string) (Authorization, error) { - sk := publickey_to_storagekey(publickey(pk)) - var auth Authorization - err := p.mongoClient.FindOneAs(session_collection_name, bson.M{ - "key": sk, - }, &auth) - - return auth, err -} - -func (p *provider_mongo) Touch(pk string) (bool, error) { - sk := publickey_to_storagekey(publickey(pk)) - worked, _, err := p.mongoClient.Update(session_collection_name, bson.M{ - "key": sk, - }, bson.M{ - "$currentDate": bson.M{ - "_ts": bson.M{"$type": "date"}, - }, - }, options.Update().SetUpsert(false)) - - if err != nil { - logger.Println("provider Touch :", err) - return false, err - } - - return worked, nil -} - -type consumer_mongo struct { - consumer_common[*sessionMongo] - ids map[primitive.ObjectID]storagekey - mongoClient gocommon.MongoClient - ttl time.Duration -} - -type sessionPipelineDocument struct { - OperationType string `bson:"operationType"` - DocumentKey struct { - Id primitive.ObjectID `bson:"_id"` - } `bson:"documentKey"` - Session *sessionMongo `bson:"fullDocument"` -} - -func newConsumerWithMongo(ctx context.Context, mongoUrl string, ttl time.Duration) (Consumer, error) { - mc, err := gocommon.NewMongoClient(ctx, mongoUrl) - if err != nil { - return nil, err - } - - consumer := &consumer_mongo{ - consumer_common: consumer_common[*sessionMongo]{ - ttl: ttl, - ctx: ctx, - stages: [2]*cache_stage[*sessionMongo]{make_cache_stage[*sessionMongo](), make_cache_stage[*sessionMongo]()}, - startTime: time.Now(), - }, - ids: make(map[primitive.ObjectID]storagekey), - ttl: ttl, - mongoClient: mc, - } - - go func() { - matchStage := bson.D{ - { - Key: "$match", Value: bson.D{ - {Key: "operationType", Value: bson.D{ - {Key: "$in", Value: bson.A{ - "delete", - "insert", - "update", - }}, - }}, - }, - }} - projectStage := bson.D{ - { - Key: "$project", Value: bson.D{ - {Key: "documentKey", Value: 1}, - {Key: "operationType", Value: 1}, - {Key: "fullDocument", Value: 1}, - }, - }, - } - - var stream *mongo.ChangeStream - nextswitch := time.Now().Add(ttl) - for { - if stream == nil { - stream, err = mc.Watch(session_collection_name, mongo.Pipeline{matchStage, projectStage}) - if err != nil { - logger.Error("watchAuthCollection watch failed :", err) - time.Sleep(time.Minute) - continue - } - } - - changed := stream.TryNext(ctx) - if ctx.Err() != nil { - logger.Error("watchAuthCollection stream.TryNext failed. process should be restarted! :", ctx.Err().Error()) - break - } - - if changed { - var data sessionPipelineDocument - if err := stream.Decode(&data); err == nil { - ot := data.OperationType - switch ot { - case "insert": - consumer.add(data.Session.Key, data.DocumentKey.Id, data.Session) - case "update": - if data.Session == nil { - if old := consumer.deleteById(data.DocumentKey.Id); old != nil { - for _, f := range consumer.onSessionInvalidated { - f(old.Auth.Account) - } - } - } else { - consumer.add(data.Session.Key, data.DocumentKey.Id, data.Session) - } - case "delete": - if old := consumer.deleteById(data.DocumentKey.Id); old != nil { - for _, f := range consumer.onSessionInvalidated { - f(old.Auth.Account) - } - } - } - } else { - logger.Error("watchAuthCollection stream.Decode failed :", err) - } - } else if stream.Err() != nil || stream.ID() == 0 { - select { - case <-ctx.Done(): - logger.Println("watchAuthCollection is done") - stream.Close(ctx) - return - - case <-time.After(time.Second): - logger.Error("watchAuthCollection stream error :", stream.Err()) - stream.Close(ctx) - stream = nil - } - } else { - time.Sleep(time.Second) - } - - now := time.Now() - for now.After(nextswitch) { - consumer.changeStage() - nextswitch = nextswitch.Add(ttl) - } - } - }() - - return consumer, nil -} - -func (c *consumer_mongo) query_internal(sk storagekey) (*sessionMongo, bool, error) { - if _, deleted := c.stages[0].deleted[sk]; deleted { - return nil, false, nil - } - - if _, deleted := c.stages[1].deleted[sk]; deleted { - return nil, false, nil - } - - found, ok := c.stages[0].cache[sk] - if !ok { - found, ok = c.stages[1].cache[sk] - } - - if ok { - return found, false, nil - } - - var si sessionMongo - err := c.mongoClient.FindOneAs(session_collection_name, bson.M{ - "key": sk, - }, &si) - - if err != nil { - logger.Println("consumer Query :", err) - return nil, false, err - } - - if len(si.Key) > 0 { - siptr := &si - c.add_internal(sk, siptr) - return siptr, true, nil - } - return nil, false, nil -} - -func (c *consumer_mongo) Query(pk string) Authorization { - c.lock.Lock() - defer c.lock.Unlock() - - sk := publickey_to_storagekey(publickey(pk)) - si, _, err := c.query_internal(sk) - if err != nil { - return Authorization{} - } - - if si == nil { - return Authorization{} - } - - if time.Now().After(si.Ts.Time().Add(c.ttl)) { - return Authorization{} - } - - return *si.Auth -} - -func (c *consumer_mongo) Touch(pk string) (Authorization, error) { - c.lock.Lock() - defer c.lock.Unlock() - - sk := publickey_to_storagekey(publickey(pk)) - worked, _, err := c.mongoClient.Update(session_collection_name, bson.M{ - "key": sk, - }, bson.M{ - "$currentDate": bson.M{ - "_ts": bson.M{"$type": "date"}, - }, - }, options.Update().SetUpsert(false)) - - if err != nil { - logger.Println("consumer Touch :", err) - return Authorization{}, err - } - - if !worked { - // 이미 만료되서 사라짐 - return Authorization{}, nil - } - - si, added, err := c.query_internal(sk) - if err != nil { - return Authorization{}, err - } - - if si == nil { - return Authorization{}, nil - } - - if !added { - var doc sessionMongo - err := c.mongoClient.FindOneAs(session_collection_name, bson.M{ - "key": sk, - }, &doc) - - if err != nil { - logger.Println("consumer Query :", err) - return Authorization{}, err - } - - if len(si.Key) > 0 { - c.add_internal(sk, &doc) - c.ids[doc.Id] = sk - - return *doc.Auth, nil - } - } - - return *si.Auth, nil -} - -func (c *consumer_mongo) Revoke(pk string) { - sk := publickey_to_storagekey(publickey(pk)) - _, err := c.mongoClient.Delete(session_collection_name, bson.M{ - "key": sk, - }) - - if err == nil { - for id, v := range c.ids { - if v == sk { - delete(c.ids, id) - break - } - } - } -} - -func (c *consumer_mongo) IsRevoked(id primitive.ObjectID) bool { - _, ok := c.ids[id] - return !ok -} - -func (c *consumer_mongo) add(sk storagekey, id primitive.ObjectID, si *sessionMongo) { - c.lock.Lock() - defer c.lock.Unlock() - - c.consumer_common.add_internal(sk, si) - c.ids[id] = sk -} - -func (c *consumer_mongo) deleteById(id primitive.ObjectID) (old *sessionMongo) { - c.lock.Lock() - defer c.lock.Unlock() - - if sk, ok := c.ids[id]; ok { - old = c.consumer_common.delete_internal(sk) - delete(c.ids, id) - } - return -} - -func (c *consumer_mongo) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) { - c.onSessionInvalidated = append(c.onSessionInvalidated, cb) -} diff --git a/session/impl_redis.go b/session/impl_redis.go index fe5e0db..077746b 100644 --- a/session/impl_redis.go +++ b/session/impl_redis.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "time" "github.com/go-redis/redis/v8" @@ -43,31 +44,18 @@ func newProviderWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio } func (p *provider_redis) New(input *Authorization) (string, error) { - newsk := make_storagekey(input.Account) - prefix := input.Account.Hex() - sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result() + sks, err := p.RevokeAll(input.Account, false) if err != nil { - logger.Println("session provider delete :", sks, err) return "", err } - p.redisClient.Del(p.ctx, sks...) - for _, sk := range sks { - p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result() - } - + var newsk storagekey for { - duplicated := false - for _, sk := range sks { - if sk == string(newsk) { - duplicated = true - break - } - } + newsk = make_storagekey(input.Account) + duplicated := slices.Contains(sks, string(newsk)) if !duplicated { break } - newsk = make_storagekey(input.Account) } _, err = p.redisClient.HSet(p.ctx, string(newsk), input.ToStrings()).Result() @@ -82,20 +70,23 @@ func (p *provider_redis) New(input *Authorization) (string, error) { return string(pk), err } -func (p *provider_redis) RevokeAll(account primitive.ObjectID) error { +func (p *provider_redis) RevokeAll(account primitive.ObjectID, infinite bool) ([]string, error) { prefix := account.Hex() sks, err := p.redisClient.Keys(p.ctx, prefix+"*").Result() if err != nil { logger.Println("session provider delete :", sks, err) - return err + return nil, err } for _, sk := range sks { - p.redisClient.HSet(p.ctx, sk, "inv", "true") - p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result() + if infinite { + p.redisClient.Publish(p.ctx, p.deleteChannel, "~"+sk).Result() + } else { + p.redisClient.Publish(p.ctx, p.deleteChannel, sk).Result() + } } - return nil + return sks, nil } func (p *provider_redis) Query(pk string) (Authorization, error) { @@ -181,11 +172,24 @@ func newConsumerWithRedis(ctx context.Context, redisUrl string, ttl time.Duratio switch msg.Channel { case deleteChannel: - sk := storagekey(msg.Payload) - old := consumer.delete(sk) + infinite := false + var sk string + if msg.Payload[0] == '~' { + sk = msg.Payload[1:] + infinite = true + } else { + sk = msg.Payload + infinite = false + } + old := consumer.delete(storagekey(sk)) if old != nil { + invsess := InvalidatedSession{ + Account: old.Account, + Infinite: infinite, + } + for _, f := range consumer.onSessionInvalidated { - f(old.Account) + f(invsess) } } } @@ -366,6 +370,6 @@ func (c *consumer_redis) IsRevoked(accid primitive.ObjectID) bool { return false } -func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(primitive.ObjectID)) { +func (c *consumer_redis) RegisterOnSessionInvalidated(cb func(InvalidatedSession)) { c.onSessionInvalidated = append(c.onSessionInvalidated, cb) } diff --git a/session/session_test.go b/session/session_test.go index 1f1562b..3db21bf 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -75,7 +75,7 @@ func TestExpTable(t *testing.T) { time.Sleep(2 * time.Second) time.Sleep(2 * time.Second) - pv.RevokeAll(au1.Account) + pv.RevokeAll(au1.Account, false) cs.Touch(sk1) time.Sleep(2 * time.Second) diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index d3fa385..492b1cd 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -334,8 +334,8 @@ func (ws *WebsocketHandler) LeaveRoom(room string, accid primitive.ObjectID) { } } -func (ws *WebsocketHandler) onSessionInvalidated(accid primitive.ObjectID) { - ws.forceCloseChan <- accid +func (ws *WebsocketHandler) onSessionInvalidated(invsess session.InvalidatedSession) { + ws.forceCloseChan <- invsess.Account } func (ws *WebsocketHandler) mainLoop(ctx context.Context) { diff --git a/wshandler/wshandler_peer.go b/wshandler/wshandler_peer.go index f1af5a7..0c6a1e9 100644 --- a/wshandler/wshandler_peer.go +++ b/wshandler/wshandler_peer.go @@ -176,9 +176,9 @@ func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux gocommon.ServerMuxI return nil } -func (ws *websocketPeerHandler[T]) onSessionInvalidated(accid primitive.ObjectID) { +func (ws *websocketPeerHandler[T]) onSessionInvalidated(invsess session.InvalidatedSession) { ws.peerDtorChannel <- peerDtorChannelValue{ - accid: accid, + accid: invsess.Account, } }