diff --git a/core/api.go b/core/api.go index ac9de2e..da14872 100644 --- a/core/api.go +++ b/core/api.go @@ -165,44 +165,67 @@ func (caller apiCaller) uploadAPI(w http.ResponseWriter, r *http.Request) error return nil } -func (caller apiCaller) whitelistAPI(w http.ResponseWriter, r *http.Request) error { +func (caller apiCaller) blockAPI(w http.ResponseWriter, r *http.Request) error { mg := caller.mg if r.Method == "GET" { - // if !caller.isAdminOrValidToken() { - // logger.Println("whitelistAPI failed. not vaild user :", r.Method, caller.userinfo) - // w.WriteHeader(http.StatusUnauthorized) - // return nil - // } + enc := json.NewEncoder(w) + enc.Encode(mg.bl.all()) + } else if r.Method == "PUT" { + body, _ := io.ReadAll(r.Body) + var bi blockinfo + if err := json.Unmarshal(body, &bi); err != nil { + return err + } - all, err := mg.mongoClient.All(CollectionWhitelist) + _, _, err := mg.mongoClient.Update(CollectionBlock, bson.M{ + "_id": primitive.NewObjectID(), + }, bson.M{ + "$set": &bi, + }, options.Update().SetUpsert(true)) + + if err != nil { + return err + } + } else if r.Method == "DELETE" { + id := r.URL.Query().Get("id") + + if len(id) == 0 { + return errors.New("id param is missing") + } + idobj, err := primitive.ObjectIDFromHex(id) if err != nil { return err } - if len(all) > 0 { - var notexp []primitive.M - for _, v := range all { - if _, exp := v["_ts"]; !exp { - notexp = append(notexp, v) - } - } - allraw, _ := json.Marshal(notexp) - w.Write(allraw) + _, _, err = mg.mongoClient.Update(CollectionBlock, bson.M{ + "_id": idobj, + }, bson.M{ + "$currentDate": bson.M{ + "_ts": bson.M{"$type": "date"}, + }, + }, options.Update().SetUpsert(false)) + + if err != nil { + return err } + + mg.mongoClient.Delete(CollectionAuth, bson.M{"_id": idobj}) + } + return nil +} + +func (caller apiCaller) whitelistAPI(w http.ResponseWriter, r *http.Request) error { + mg := caller.mg + if r.Method == "GET" { + enc := json.NewEncoder(w) + enc.Encode(mg.wl.all()) } else if r.Method == "PUT" { body, _ := io.ReadAll(r.Body) var member whitelistmember if err := json.Unmarshal(body, &member); err != nil { return err } - - // if !caller.isAdminOrValidToken() { - // logger.Println("whitelistAPI failed. not vaild user :", r.Method, caller.userinfo) - // w.WriteHeader(http.StatusUnauthorized) - // return nil - // } - - member.Expired = 0 + member.ExpiredAt = 0 _, _, err := mg.mongoClient.Update(CollectionWhitelist, bson.M{ "_id": primitive.NewObjectID(), @@ -432,6 +455,8 @@ func (mg *Maingate) api(w http.ResponseWriter, r *http.Request) { err = caller.maintenanceAPI(w, r) } else if strings.HasSuffix(r.URL.Path, "/files") { err = caller.filesAPI(w, r) + } else if strings.HasSuffix(r.URL.Path, "/block") { + err = caller.blockAPI(w, r) } if err != nil { diff --git a/core/maingate.go b/core/maingate.go index 613ad12..88dd1f5 100644 --- a/core/maingate.go +++ b/core/maingate.go @@ -169,7 +169,8 @@ type Maingate struct { //services servicelist serviceptr unsafe.Pointer admins unsafe.Pointer - wl whitelist + wl memberContainerPtr[string, *whitelistmember] + bl memberContainerPtr[primitive.ObjectID, *blockinfo] tokenEndpoints map[string]string authorizationEndpoints map[string]string @@ -409,27 +410,25 @@ func (mg *Maingate) prepare(context context.Context) (err error) { } } - var whites []whitelistmember + var whites []*whitelistmember if err := mg.mongoClient.AllAs(CollectionWhitelist, &whites, options.Find().SetReturnKey(false)); err != nil { return err } mg.wl.init(whites) + var blocks []*blockinfo + if err := mg.mongoClient.AllAs(CollectionBlock, &blocks, options.Find().SetReturnKey(false)); err != nil { + return err + } + mg.bl.init(blocks) + go watchAuthCollection(context, mg.auths, mg.mongoClient) - go mg.watchWhitelistCollection(context) + go mg.wl.watchCollection(context, CollectionWhitelist, mg.mongoClient) + go mg.bl.watchCollection(context, CollectionBlock, mg.mongoClient) return nil } -func whitelistKey(email string, platform string) string { - if strings.HasPrefix(email, "*@") { - // 도메인 전체 허용 - return email[2:] - } - - return email -} - func (mg *Maingate) RegisterHandlers(ctx context.Context, serveMux *http.ServeMux, prefix string) error { var allServices []*serviceDescription if err := mg.mongoClient.AllAs(CollectionService, &allServices, options.Find().SetReturnKey(false)); err != nil { diff --git a/core/member_container.go b/core/member_container.go new file mode 100644 index 0000000..2c4eaa4 --- /dev/null +++ b/core/member_container.go @@ -0,0 +1,169 @@ +package core + +import ( + "context" + "sync/atomic" + "time" + "unsafe" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "repositories.action2quare.com/ayo/gocommon" + "repositories.action2quare.com/ayo/gocommon/logger" +) + +type memberContraints[K comparable] interface { + Key() K + Expired() bool +} + +type memberContainerPtr[K comparable, T memberContraints[K]] struct { + ptr unsafe.Pointer +} + +func (p *memberContainerPtr[K, T]) init(ms []T) { + next := map[K]T{} + for _, m := range ms { + next[m.Key()] = m + } + atomic.StorePointer(&p.ptr, unsafe.Pointer(&next)) +} + +func (p *memberContainerPtr[K, T]) add(m T) { + ptr := atomic.LoadPointer(&p.ptr) + src := (*map[K]T)(ptr) + + next := map[K]T{} + for k, v := range *src { + next[k] = v + } + next[m.Key()] = m + atomic.StorePointer(&p.ptr, unsafe.Pointer(&next)) +} + +func (p *memberContainerPtr[K, T]) remove(key K) { + ptr := atomic.LoadPointer(&p.ptr) + src := (*map[K]T)(ptr) + + next := map[K]T{} + for k, v := range *src { + next[k] = v + } + delete(next, key) + atomic.StorePointer(&p.ptr, unsafe.Pointer(&next)) +} + +type memberPipelineDocument[K comparable, T memberContraints[K]] struct { + OperationType string `bson:"operationType"` + DocumentKey struct { + Id primitive.ObjectID `bson:"_id"` + } `bson:"documentKey"` + Member T `bson:"fullDocument"` +} + +func (p *memberContainerPtr[K, T]) all() []T { + ptr := atomic.LoadPointer(&p.ptr) + src := (*map[K]T)(ptr) + + out := make([]T, 0, len(*src)) + for _, m := range *src { + if m.Expired() { + continue + } + out = append(out, m) + } + return out +} + +func (p *memberContainerPtr[K, T]) contains(key K, out *T) bool { + ptr := atomic.LoadPointer(&p.ptr) + src := (*map[K]T)(ptr) + + found, exists := (*src)[key] + if exists { + if found.Expired() { + p.remove(key) + return false + } + if out != nil { + out = &found + } + return true + } + return false +} + +func (p *memberContainerPtr[K, T]) watchCollection(parentctx context.Context, coll gocommon.CollectionName, mc gocommon.MongoClient) { + defer func() { + s := recover() + if s != nil { + logger.Error(s) + } + }() + + matchStage := bson.D{ + { + Key: "$match", Value: bson.D{ + {Key: "operationType", Value: bson.D{ + {Key: "$in", Value: bson.A{ + "update", + "insert", + }}, + }}, + }, + }} + projectStage := bson.D{ + { + Key: "$project", Value: bson.D{ + {Key: "documentKey", Value: 1}, + {Key: "fullDocument", Value: 1}, + }, + }, + } + + var stream *mongo.ChangeStream + var err error + var ctx context.Context + + for { + if stream == nil { + stream, err = mc.Watch(coll, mongo.Pipeline{matchStage, projectStage}) + if err != nil { + logger.Error("watchCollection watch failed :", err) + time.Sleep(time.Minute) + continue + } + ctx = context.TODO() + } + + changed := stream.TryNext(ctx) + if ctx.Err() != nil { + logger.Error("watchCollection stream.TryNext failed. process should be restarted! :", ctx.Err().Error()) + break + } + + if changed { + var data memberPipelineDocument[K, T] + if err := stream.Decode(&data); err == nil { + p.add(data.Member) + } else { + logger.Error("watchCollection stream.Decode failed :", err) + } + } else if stream.Err() != nil || stream.ID() == 0 { + select { + case <-ctx.Done(): + logger.Println("watchCollection is done") + stream.Close(ctx) + return + + case <-time.After(time.Second): + logger.Error("watchCollection stream error :", stream.Err()) + stream.Close(ctx) + stream = nil + } + } else { + time.Sleep(time.Second) + } + } +} diff --git a/core/service.go b/core/service.go index 5e32a08..c4b3543 100644 --- a/core/service.go +++ b/core/service.go @@ -8,10 +8,9 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" - "sync/atomic" "time" - "unsafe" "repositories.action2quare.com/ayo/gocommon" "repositories.action2quare.com/ayo/gocommon/logger" @@ -22,20 +21,38 @@ import ( ) type blockinfo struct { + Accid primitive.ObjectID `bson:"_id" json:"_id"` Start primitive.DateTime `bson:"start" json:"start"` End primitive.DateTime `bson:"_ts"` Reason string `bson:"reason" json:"reason"` } type whitelistmember struct { - Email string `bson:"email" json:"email"` - Platform string `bson:"platform" json:"platform"` - Desc string `bson:"desc" json:"desc"` - Expired primitive.DateTime `bson:"_ts,omitempty" json:"_ts,omitempty"` + Email string `bson:"email" json:"email"` + Platform string `bson:"platform" json:"platform"` + Desc string `bson:"desc" json:"desc"` + ExpiredAt primitive.DateTime `bson:"_ts,omitempty" json:"_ts,omitempty"` } -type whitelist struct { - emailptr unsafe.Pointer +func (wh *whitelistmember) Key() string { + if strings.HasPrefix(wh.Email, "*@") { + // 도메인 전체 허용 + return wh.Email[2:] + } + return wh.Email +} + +func (wh *whitelistmember) Expired() bool { + // 얘는 Expired가 있기만 하면 제거된 상태 + return wh.ExpiredAt != 0 +} + +func (bi *blockinfo) Key() primitive.ObjectID { + return bi.Accid +} + +func (bi *blockinfo) Expired() bool { + return bi.End.Time().Before(time.Now().UTC()) } type usertokeninfo struct { @@ -48,54 +65,6 @@ type usertokeninfo struct { accesstoken_expire_time int64 // microsoft only } -func (wl *whitelist) init(total []whitelistmember) { - all := make(map[string]*whitelistmember) - for _, member := range total { - all[whitelistKey(member.Email, member.Platform)] = &member - } - atomic.StorePointer(&wl.emailptr, unsafe.Pointer(&all)) -} - -func addToUnsafePointer(to *unsafe.Pointer, m *whitelistmember) { - ptr := atomic.LoadPointer(to) - src := (*map[string]*whitelistmember)(ptr) - - next := map[string]*whitelistmember{} - for k, v := range *src { - next[k] = v - } - next[whitelistKey(m.Email, m.Platform)] = m - atomic.StorePointer(to, unsafe.Pointer(&next)) -} - -func removeFromUnsafePointer(from *unsafe.Pointer, email string, platform string) { - ptr := atomic.LoadPointer(from) - src := (*map[string]*whitelistmember)(ptr) - - next := make(map[string]*whitelistmember) - for k, v := range *src { - next[k] = v - } - delete(next, whitelistKey(email, platform)) - atomic.StorePointer(from, unsafe.Pointer(&next)) -} - -func (wl *whitelist) add(m *whitelistmember) { - addToUnsafePointer(&wl.emailptr, m) -} - -func (wl *whitelist) remove(email string, platform string) { - removeFromUnsafePointer(&wl.emailptr, email, platform) -} - -func (wl *whitelist) isMember(email string, platform string) bool { - ptr := atomic.LoadPointer(&wl.emailptr) - src := *(*map[string]*whitelistmember)(ptr) - - _, exists := src[whitelistKey(email, platform)] - return exists -} - type DivisionStateName string const ( @@ -135,7 +104,8 @@ type serviceDescription struct { VersionSplits map[string]string `bson:"version_splits" json:"version_splits"` auths *gocommon.AuthCollection - wl *whitelist + wl memberContainerPtr[string, *whitelistmember] + bl memberContainerPtr[primitive.ObjectID, *blockinfo] mongoClient gocommon.MongoClient sessionTTL time.Duration @@ -281,7 +251,8 @@ func (sh *serviceDescription) prepare(mg *Maingate) error { sh.updateUserinfo = mg.updateUserinfo sh.getProviderInfo = mg.getProviderInfo - sh.wl = &mg.wl + sh.wl = mg.wl + sh.bl = mg.bl sh.serviceSummarySerialized, _ = json.Marshal(sh.ServiceDescriptionSummary) logger.Println("service is ready :", sh.ServiceCode, string(sh.divisionsSerialized)) @@ -657,28 +628,16 @@ func (sh *serviceDescription) authorize(w http.ResponseWriter, r *http.Request) oldcreate := account["create"].(primitive.DateTime) newaccount := oldcreate == createtime - var bi blockinfo - if err := sh.mongoClient.FindOneAs(CollectionBlock, bson.M{ - "code": sh.ServiceCode, - "accid": accid, - }, &bi); err != nil { - logger.Error("authorize failed. find blockinfo in CollectionBlock err:", err) - w.WriteHeader(http.StatusInternalServerError) + var bi *blockinfo + if sh.bl.contains(accid, &bi) { + // 블럭된 계정. 블락 정보를 알려준다. + w.Header().Add("MG-ACCOUNTBLOCK-START", strconv.FormatInt(bi.Start.Time().Unix(), 10)) + w.Header().Add("MG-ACCOUNTBLOCK-END", strconv.FormatInt(bi.End.Time().Unix(), 10)) + w.Header().Add("MG-ACCOUNTBLOCK-REASON", bi.Reason) + w.WriteHeader(http.StatusUnauthorized) return } - if !bi.Start.Time().IsZero() { - now := time.Now().UTC() - if bi.Start.Time().Before(now) && bi.End.Time().After(now) { - // block됐네? - // status는 정상이고 reason을 넘겨주자 - json.NewEncoder(w).Encode(map[string]any{ - "blocked": bi, - }) - return - } - } - newsession := primitive.NewObjectID() expired := primitive.NewDateTimeFromTime(time.Now().UTC().Add(sh.sessionTTL)) newauth := gocommon.Authinfo{ @@ -839,7 +798,8 @@ func (sh *serviceDescription) serveHTTP(w http.ResponseWriter, r *http.Request) w.WriteHeader(http.StatusBadRequest) return } - if sh.wl.isMember(cell.ToAuthinfo().Email, cell.ToAuthinfo().Platform) { + wm := &whitelistmember{Email: cell.ToAuthinfo().Email, Platform: cell.ToAuthinfo().Platform} + if sh.wl.contains(wm.Key(), nil) { // qa 권한이면 입장 가능 w.Write([]byte(fmt.Sprintf(`{"service":"%s"}`, div.Url))) } else if div.Maintenance != nil { diff --git a/core/watch.go b/core/watch.go index 04e3b8c..ba190f6 100644 --- a/core/watch.go +++ b/core/watch.go @@ -43,102 +43,6 @@ type filePipelineDocument struct { File *FileDocumentDesc `bson:"fullDocument"` } -type whilelistPipelineDocument struct { - OperationType string `bson:"operationType"` - DocumentKey struct { - Id primitive.ObjectID `bson:"_id"` - } `bson:"documentKey"` - Member *whitelistmember `bson:"fullDocument"` -} - -func (mg *Maingate) watchWhitelistCollection(parentctx context.Context) { - defer func() { - s := recover() - if s != nil { - logger.Error(s) - } - }() - - matchStage := bson.D{ - { - Key: "$match", Value: bson.D{ - {Key: "operationType", Value: bson.D{ - {Key: "$in", Value: bson.A{ - "update", - "insert", - }}, - }}, - }, - }} - projectStage := bson.D{ - { - Key: "$project", Value: bson.D{ - {Key: "documentKey", Value: 1}, - {Key: "operationType", Value: 1}, - {Key: "fullDocument", Value: 1}, - }, - }, - } - - var stream *mongo.ChangeStream - var err error - var ctx context.Context - - for { - if stream == nil { - stream, err = mg.mongoClient.Watch(CollectionWhitelist, mongo.Pipeline{matchStage, projectStage}) - if err != nil { - logger.Error("watchWhitelistCollection watch failed :", err) - time.Sleep(time.Minute) - continue - } - ctx = context.TODO() - } - - changed := stream.TryNext(ctx) - if ctx.Err() != nil { - logger.Error("watchWhitelistCollection stream.TryNext failed. process should be restarted! :", ctx.Err().Error()) - break - } - - if changed { - var data whilelistPipelineDocument - if err := stream.Decode(&data); err == nil { - ot := data.OperationType - switch ot { - case "insert": - // 새 화이트리스트 멤버 - mg.service().wl.add(data.Member) - case "update": - if data.Member.Expired != 0 { - logger.Println("whitelist member is removed :", *data.Member) - mg.service().wl.remove(data.Member.Email, data.Member.Platform) - } else { - logger.Println("whitelist member is updated :", *data.Member) - mg.service().wl.add(data.Member) - } - } - } else { - logger.Error("watchWhitelistCollection stream.Decode failed :", err) - } - } else if stream.Err() != nil || stream.ID() == 0 { - select { - case <-ctx.Done(): - logger.Println("watchWhitelistCollection is done") - stream.Close(ctx) - return - - case <-time.After(time.Second): - logger.Error("watchWhitelistCollection stream error :", stream.Err()) - stream.Close(ctx) - stream = nil - } - } else { - time.Sleep(time.Second) - } - } -} - func (mg *Maingate) watchFileCollection(parentctx context.Context, serveMux *http.ServeMux, prefix string) { defer func() { s := recover()