package gocommon import ( "context" "encoding/json" "fmt" "io" "net/http" "sync" "sync/atomic" "time" "unsafe" "repositories.action2quare.com/ayo/gocommon/logger" "github.com/go-redis/redis/v8" "go.mongodb.org/mongo-driver/bson/primitive" ) type Authinfo struct { Accid primitive.ObjectID `bson:"_id,omitempty" json:"_id,omitempty"` Platform string Uid string Email string Sk primitive.ObjectID RefreshToken string `bson:"refresh_token,omitempty" json:"refresh_token,omitempty"` Expired primitive.DateTime `bson:"_ts" json:"_ts"` } const ( sessionSyncChannelNamePrefix = "session-sync-channel2" ) type AuthinfoCell interface { ToBytes() []byte ToAuthinfo() *Authinfo } type AuthCollection struct { lock sync.Mutex // key : session auths map[string]*Authinfo expiring map[string]*Authinfo // accid -> session (auths) reverseOn map[primitive.ObjectID]string // accid -> session (expiring) reverseOff map[primitive.ObjectID]string nextTrim time.Time ttl time.Duration SessionRemoved func(string) SessionAdded func(AuthinfoCell) QuerySession func(string, string) AuthinfoCell Stop func() } func MakeAuthCollection(sessionTTL time.Duration) *AuthCollection { return &AuthCollection{ auths: make(map[string]*Authinfo), expiring: make(map[string]*Authinfo), reverseOn: make(map[primitive.ObjectID]string), reverseOff: make(map[primitive.ObjectID]string), nextTrim: time.Now().Add(time.Hour * 1000000), ttl: sessionTTL, SessionRemoved: func(string) {}, SessionAdded: func(AuthinfoCell) {}, QuerySession: func(string, string) AuthinfoCell { return nil }, } } type authCollectionConfig struct { RegionStorageConfig Maingate string `json:"maingate_service_url"` } type redisAuthCell struct { raw []byte } func (ac *redisAuthCell) ToAuthinfo() *Authinfo { var out Authinfo err := json.Unmarshal(ac.raw, &out) if err != nil { logger.Error("redisAuthCell ToAuthinfo failed :", string(ac.raw), err) return nil } return &out } func (ac *redisAuthCell) ToBytes() []byte { return ac.raw } func newAuthCollectionWithRedis(redisClient *redis.Client, subctx context.Context, maingateURL string, apiToken string) *AuthCollection { sessionTTL := int64(3600) ac := MakeAuthCollection(time.Duration(sessionTTL * int64(time.Second))) sessionSyncChannelName := fmt.Sprintf("%s-%d", sessionSyncChannelNamePrefix, redisClient.Options().DB) pubsub := redisClient.Subscribe(subctx, sessionSyncChannelName) ctx, cancel := context.WithCancel(context.TODO()) go func(ctx context.Context, sub *redis.PubSub, authCache *AuthCollection) { for { select { case <-ctx.Done(): return case msg := <-sub.Channel(): if msg == nil { return } if len(msg.Payload) == 0 { continue } if msg.Payload[0] == '-' { authCache.RemoveBySessionKey(msg.Payload[1:], false) } else { authCache.AddRaw(&redisAuthCell{ raw: []byte(msg.Payload), }) } } } }(ctx, pubsub, ac) ac.Stop = cancel ac.QuerySession = func(key string, token string) AuthinfoCell { req, _ := http.NewRequest("GET", fmt.Sprintf("%s/query?sk=%s", maingateURL, key), nil) req.Header.Add("Authorization", "Bearer "+token) req.Header.Add("MG-X-API-TOKEN", apiToken) client := http.Client{} resp, err := client.Do(req) if err != nil { logger.Error("authorize query failed :", err) return nil } defer resp.Body.Close() raw, _ := io.ReadAll(resp.Body) if len(raw) == 0 { // 세션키가 없네? 클라이언트한테 재로그인하라고 알려줘야 함 return nil } return &redisAuthCell{ raw: raw, } } ac.SessionAdded = func(cell AuthinfoCell) { redisClient.Publish(context.Background(), sessionSyncChannelName, cell.ToBytes()) } ac.SessionRemoved = func(sk string) { redisClient.Publish(context.Background(), sessionSyncChannelName, "-"+sk) } return ac } type AuthCollectionGlobal struct { apiToken string ptr unsafe.Pointer // map[string]*AuthCollection } func (acg *AuthCollectionGlobal) Get(region string) *AuthCollection { ptr := atomic.LoadPointer(&acg.ptr) oldval := *(*map[string]*AuthCollection)(ptr) return oldval[region] } func (acg *AuthCollectionGlobal) Regions() (out []string) { ptr := atomic.LoadPointer(&acg.ptr) oldval := *(*map[string]*AuthCollection)(ptr) for k := range oldval { out = append(out, k) } return } func (acg *AuthCollectionGlobal) Reload(context context.Context) error { ptr := atomic.LoadPointer(&acg.ptr) oldval := *(*map[string]*AuthCollection)(ptr) var config authCollectionConfig if err := LoadConfig(&config); err != nil { return err } newval := make(map[string]*AuthCollection) for r, c := range oldval { if _, ok := config.RegionStorage[r]; !ok { // 없어졌네? 닫음 c.Stop() } else { newval[r] = c } } for r, url := range config.RegionStorage { if _, ok := oldval[r]; !ok { // 새로 생겼네 redisClient, err := NewRedisClient(url.Redis["session"]) if err != nil { return err } if authCache := newAuthCollectionWithRedis(redisClient, context, config.Maingate, acg.apiToken); authCache != nil { newval[r] = authCache } } } atomic.StorePointer(&acg.ptr, unsafe.Pointer(&newval)) return nil } func NewAuthCollectionGlobal(context context.Context, apiToken string) (AuthCollectionGlobal, error) { var config authCollectionConfig if err := LoadConfig(&config); err != nil { return AuthCollectionGlobal{}, err } output := make(map[string]*AuthCollection) for region, url := range config.RegionStorage { redisClient, err := NewRedisClient(url.Redis["session"]) if err != nil { return AuthCollectionGlobal{}, err } if authCache := newAuthCollectionWithRedis(redisClient, context, config.Maingate, apiToken); authCache != nil { output[region] = authCache } } return AuthCollectionGlobal{ apiToken: apiToken, ptr: unsafe.Pointer(&output), }, nil } func (sc *AuthCollection) AddRaw(cell AuthinfoCell) { sc.lock.Lock() defer sc.lock.Unlock() if time.Now().After(sc.nextTrim) { sc.expiring, sc.auths = sc.auths, sc.expiring sc.reverseOff, sc.reverseOn = sc.reverseOn, sc.reverseOff // maps 패키지는 아직 0.0.0 상태;; https://pkg.go.dev/golang.org/x/exp/maps?tab=versions // maps.Clear(sc.auths) sc.auths = make(map[string]*Authinfo) sc.reverseOn = make(map[primitive.ObjectID]string) } newauth := cell.ToAuthinfo() if newauth == nil { logger.Println("AuthCollection.AddRaw failed. cell.ToAuthinfo returns nil") return } sk := newauth.Sk.Hex() if oldsk, exists := sc.reverseOn[newauth.Accid]; exists { delete(sc.auths, oldsk) delete(sc.reverseOn, newauth.Accid) } else if oldsk, exists = sc.reverseOff[newauth.Accid]; exists { delete(sc.expiring, oldsk) delete(sc.reverseOff, newauth.Accid) } delete(sc.auths, sk) delete(sc.expiring, sk) sc.auths[sk] = newauth sc.reverseOn[newauth.Accid] = sk if len(sc.auths) == 1 { sc.nextTrim = time.Now().Add(sc.ttl) } } func (sc *AuthCollection) Find(sk string) *Authinfo { sc.lock.Lock() defer sc.lock.Unlock() if found, ok := sc.auths[sk]; ok { return found } return sc.expiring[sk] } func (sc *AuthCollection) RemoveByAccId(accid primitive.ObjectID) { sc.lock.Lock() defer sc.lock.Unlock() var sk string if on, ok := sc.reverseOn[accid]; ok { sk = on } else if off, ok := sc.reverseOff[accid]; ok { sk = off } if len(sk) > 0 { old := sc.auths[sk] if old != nil { accid = old.Accid delete(sc.auths, sk) delete(sc.reverseOn, accid) } else if old = sc.expiring[sk]; old != nil { accid = old.Accid delete(sc.expiring, sk) delete(sc.reverseOff, accid) } } } func (sc *AuthCollection) RemoveBySessionKey(sk string, publish bool) (accid primitive.ObjectID) { sc.lock.Lock() defer sc.lock.Unlock() if publish { // 나한테 있든 없든 무조건 publish해야 함 sc.SessionRemoved(sk) } old := sc.auths[sk] if old != nil { accid = old.Accid delete(sc.auths, sk) delete(sc.reverseOn, accid) } else if old = sc.expiring[sk]; old != nil { accid = old.Accid delete(sc.expiring, sk) delete(sc.reverseOff, accid) } else { accid = primitive.NilObjectID } return } func (sc *AuthCollection) IsValid(sk string, token string) (accid primitive.ObjectID, success bool) { exists := sc.Find(sk) if exists != nil { now := int64(primitive.NewDateTimeFromTime(time.Now().UTC())) //if int64(exists.Expired) > now && exists.Token == token { if int64(exists.Expired) > now { //-- accesstoken은 사실상 쓰지 않는다. return exists.Accid, true } if exists.Expired == 0 { // 이미 maingate db까지 가서 만료된 것으로 확인된 키다. return primitive.NilObjectID, false } } cell := sc.QuerySession(sk, token) if cell == nil { // maingate db까지 가서 만료된 것으로 확인된 키다. Expired를 0으로 저장해 놓고 쿼리를 더 이상 보내지 않도록 sc.lock.Lock() defer sc.lock.Unlock() sc.auths[sk] = &Authinfo{Expired: 0} logger.Println("session is invalid. cell is nil :", sk) return primitive.NilObjectID, false } newauth := cell.ToAuthinfo() if newauth == nil { logger.Println("session is invalid. ToAuthinfo() returns nil :", sk) return primitive.NilObjectID, false } sc.AddRaw(cell) sc.SessionAdded(cell) return newauth.Accid, true }