diff --git a/wshandler/room.go b/wshandler/room.go new file mode 100644 index 0000000..a443e32 --- /dev/null +++ b/wshandler/room.go @@ -0,0 +1,82 @@ +package wshandler + +import ( + "context" + + "github.com/gorilla/websocket" + "repositories.action2quare.com/ayo/gocommon/logger" +) + +type room struct { + inChan chan *Richconn + outChan chan *Richconn + messageChan chan *UpstreamMessage + name string +} + +func makeRoom(name string) *room { + return &room{ + inChan: make(chan *Richconn, 10), + outChan: make(chan *Richconn, 10), + messageChan: make(chan *UpstreamMessage, 100), + name: name, + } +} + +func (r *room) broadcast(msg *UpstreamMessage) { + r.messageChan <- msg +} + +func (r *room) in(conn *Richconn) { + r.inChan <- conn +} + +func (r *room) out(conn *Richconn) { + r.outChan <- conn +} + +func (r *room) start(ctx context.Context) { + go func(ctx context.Context) { + conns := make(map[string]*Richconn) + normal := false + for !normal { + normal = r.loop(ctx, &conns) + } + }(ctx) +} + +func (r *room) loop(ctx context.Context, conns *map[string]*Richconn) (normalEnd bool) { + defer func() { + s := recover() + if s != nil { + logger.Error(s) + normalEnd = false + } + }() + + a, b, c := []byte(`{"alias":"`), []byte(`","body":"`), []byte(`"}`) + + for { + select { + case <-ctx.Done(): + return true + + case conn := <-r.inChan: + (*conns)[conn.alias] = conn + + case conn := <-r.outChan: + delete((*conns), conn.alias) + + case msg := <-r.messageChan: + for _, conn := range *conns { + writer, _ := conn.NextWriter(websocket.TextMessage) + writer.Write(a) + writer.Write([]byte(msg.Alias)) + writer.Write(b) + writer.Write(msg.Body) + writer.Write(c) + writer.Close() + } + } + } +} diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index 0ad7fb0..977ac58 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -7,100 +7,54 @@ import ( "fmt" "io" "net/http" - "os" "strings" "sync" - common "repositories.action2quare.com/ayo/gocommon" + "go.mongodb.org/mongo-driver/bson/primitive" + "repositories.action2quare.com/ayo/gocommon" "repositories.action2quare.com/ayo/gocommon/flagx" "repositories.action2quare.com/ayo/gocommon/logger" "github.com/go-redis/redis/v8" "github.com/gorilla/websocket" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" ) var noSessionFlag = flagx.Bool("nosession", false, "nosession=[true|false]") -const ( - connStateCachePrefix = "conn_state_" - connStateScript = ` - local hosts = redis.call('keys',KEYS[1]) - for index, key in ipairs(hosts) do - local ok = redis.call('hexists', key, KEYS[2]) - if ok == 1 then - return redis.call('hget', key, KEYS[2]) - end - end - return "" - ` -) - -var ConnStateCacheKey = func() string { - hn, _ := os.Hostname() - return connStateCachePrefix + hn -}() - type Richconn struct { *websocket.Conn - lock sync.Mutex - alias primitive.ObjectID - tags []string - onClose map[string]func() + closeFuncLock sync.Mutex + alias string + onClose map[string]func() } -func (rc *Richconn) AddTag(name, val string) { - rc.lock.Lock() - defer rc.lock.Unlock() - - prefix := name + "=" - for i, tag := range rc.tags { - if strings.HasPrefix(tag, prefix) { - rc.tags[i] = prefix + val - return - } - } - rc.tags = append(rc.tags, prefix+val) +type UpstreamMessage struct { + Alias string + Accid primitive.ObjectID + Target string + Body []byte } -func (rc *Richconn) GetTag(name string) string { - rc.lock.Lock() - defer rc.lock.Unlock() - - prefix := name + "=" - for _, tag := range rc.tags { - if strings.HasPrefix(tag, prefix) { - return tag[len(prefix):] - } - } - return "" +type DownstreamMessage struct { + Alias string + Body string } -func (rc *Richconn) RemoveTag(name string, val string) { - rc.lock.Lock() - defer rc.lock.Unlock() +type CommandType string - whole := fmt.Sprintf("%s=%s", name, val) - for i, tag := range rc.tags { - if tag == whole { - if i == 0 && len(rc.tags) == 1 { - rc.tags = nil - } else { - lastidx := len(rc.tags) - 1 - if i < lastidx { - rc.tags[i] = rc.tags[lastidx] - } - rc.tags = rc.tags[:lastidx] - } - return - } - } +const ( + CommandType_JoinChannel = CommandType("join_channel") + CommandType_LeaveChannel = CommandType("leave_channel") +) + +type CommandMessage struct { + Cmd CommandType + Args []string } func (rc *Richconn) RegistOnCloseFunc(name string, f func()) { - rc.lock.Lock() - defer rc.lock.Unlock() + rc.closeFuncLock.Lock() + defer rc.closeFuncLock.Unlock() if rc.onClose == nil { f() @@ -110,8 +64,8 @@ func (rc *Richconn) RegistOnCloseFunc(name string, f func()) { } func (rc *Richconn) HasOnCloseFunc(name string) bool { - rc.lock.Lock() - defer rc.lock.Unlock() + rc.closeFuncLock.Lock() + defer rc.closeFuncLock.Unlock() if rc.onClose == nil { return false @@ -122,8 +76,8 @@ func (rc *Richconn) HasOnCloseFunc(name string) bool { } func (rc *Richconn) UnregistOnCloseFunc(name string) (out func()) { - rc.lock.Lock() - defer rc.lock.Unlock() + rc.closeFuncLock.Lock() + defer rc.closeFuncLock.Unlock() if rc.onClose == nil { return @@ -133,52 +87,27 @@ func (rc *Richconn) UnregistOnCloseFunc(name string) (out func()) { return } +func (rc *Richconn) Closed() { + rc.closeFuncLock.Lock() + defer rc.closeFuncLock.Unlock() + for _, f := range rc.onClose { + f() + } +} + func (rc *Richconn) WriteBytes(data []byte) error { - rc.lock.Lock() - defer rc.lock.Unlock() return rc.WriteMessage(websocket.TextMessage, data) } -type DeliveryMessage struct { - Alias primitive.ObjectID - Body []byte - Command string - Conn *Richconn -} - -func (dm *DeliveryMessage) Parse(out any) error { - return json.Unmarshal(dm.Body, out) -} - -func (dm *DeliveryMessage) MarshalBinary() (data []byte, err error) { - return append(dm.Alias[:], dm.Body...), nil -} - -func (dm *DeliveryMessage) UnmarshalBinary(data []byte) error { - copy(dm.Alias[:], data[:12]) - dm.Body = data[12:] - return nil -} - -type tagconn struct { - rc *Richconn - state string -} -type tagconnsmap = map[primitive.ObjectID]*tagconn -type tagconns struct { - sync.Mutex - tagconnsmap -} - type subhandler struct { - sync.Mutex - authCache *common.AuthCollection - conns map[primitive.ObjectID]*Richconn - aliases map[primitive.ObjectID]primitive.ObjectID - tags map[primitive.ObjectID]*tagconns - deliveryChan chan DeliveryMessage - url string - redisSync *redis.Client + authCache *gocommon.AuthCollection + redisMsgChanName string + redisCmdChanName string + redisSync *redis.Client + connsLock sync.Mutex + connectedAlias map[string]bool + connInOutChan chan *Richconn + deliveryChan chan any } // WebsocketHandler : @@ -191,28 +120,30 @@ type wsConfig struct { SyncPipeline string `json:"ws_sync_pipeline"` } -func NewWebsocketHandler(authglobal common.AuthCollectionGlobal) (wsh *WebsocketHandler) { +func NewWebsocketHandler(authglobal gocommon.AuthCollectionGlobal) (wsh *WebsocketHandler) { + var config wsConfig + gocommon.LoadConfig(&config) + + redisSync, err := gocommon.NewRedisClient(config.SyncPipeline, 0) + if err != nil { + panic(err) + } + authCaches := make(map[string]*subhandler) for _, region := range authglobal.Regions() { sh := &subhandler{ - authCache: authglobal.Get(region), - conns: make(map[primitive.ObjectID]*Richconn), - aliases: make(map[primitive.ObjectID]primitive.ObjectID), - tags: make(map[primitive.ObjectID]*tagconns), - deliveryChan: make(chan DeliveryMessage, 1000), + authCache: authglobal.Get(region), + redisMsgChanName: fmt.Sprintf("_wsh_msg_%s", region), + redisCmdChanName: fmt.Sprintf("_wsh_cmd_%s", region), + redisSync: redisSync, + connectedAlias: make(map[string]bool), + connInOutChan: make(chan *Richconn), + deliveryChan: make(chan any, 1000), } authCaches[region] = sh } - var config wsConfig - common.LoadConfig(&config) - - redisSync, err := common.NewRedisClient(config.SyncPipeline, 0) - if err != nil { - panic(err) - } - return &WebsocketHandler{ authCaches: authCaches, RedisSync: redisSync, @@ -220,292 +151,190 @@ func NewWebsocketHandler(authglobal common.AuthCollectionGlobal) (wsh *Websocket } func (ws *WebsocketHandler) Destructor() { - if ws.RedisSync != nil { - ws.RedisSync.Del(context.Background(), ConnStateCacheKey) - } } -func (ws *WebsocketHandler) DeliveryChannel(region string) <-chan DeliveryMessage { - return ws.authCaches[region].deliveryChan -} - -func (ws *WebsocketHandler) Conn(region string, alias primitive.ObjectID) *Richconn { +func (ws *WebsocketHandler) IsConnected(region string, alias string) bool { if sh := ws.authCaches[region]; sh != nil { - return sh.conns[alias] + return sh.connected(alias) } - return nil + return false } -func (ws *WebsocketHandler) JoinTag(region string, tag primitive.ObjectID, tid primitive.ObjectID, rc *Richconn, hint string) error { - if sh := ws.authCaches[region]; sh != nil { - sh.joinTag(tag, tid, rc, hint) - } - return nil -} - -func (ws *WebsocketHandler) LeaveTag(region string, tag primitive.ObjectID, tid primitive.ObjectID) error { - if sh := ws.authCaches[region]; sh != nil { - sh.leaveTag(tag, tid) - } - return nil -} - -func (ws *WebsocketHandler) SetStateInTag(region string, tag primitive.ObjectID, tid primitive.ObjectID, state string, hint string) error { - if sh := ws.authCaches[region]; sh != nil { - sh.setStateInTag(tag, tid, state, hint) - } - return nil -} - -func (ws *WebsocketHandler) BroadcastRaw(region string, tag primitive.ObjectID, raw []byte) { - if sh := ws.authCaches[region]; sh != nil { - if cs := sh.cloneTag(tag); len(cs) > 0 { - go func(raw []byte) { - for _, c := range cs { - if c != nil { - c.WriteBytes(raw) - } - } - }(raw) - } - } -} - -func (ws *WebsocketHandler) Broadcast(region string, tag primitive.ObjectID, doc bson.M) { - raw, _ := json.Marshal(doc) - ws.BroadcastRaw(region, tag, raw) -} - -var onlineQueryScriptHash string - func (ws *WebsocketHandler) RegisterHandlers(ctx context.Context, serveMux *http.ServeMux, prefix string) error { - ws.RedisSync.Del(context.Background(), ConnStateCacheKey) - - scriptHash, err := ws.RedisSync.ScriptLoad(context.Background(), connStateScript).Result() - if err != nil { - return err - } - onlineQueryScriptHash = scriptHash for region, sh := range ws.authCaches { if region == "default" { region = "" } - sh.url = common.MakeHttpHandlerPattern(prefix, region, "ws") - sh.redisSync = ws.RedisSync + url := gocommon.MakeHttpHandlerPattern(prefix, region, "ws") if *noSessionFlag { - serveMux.HandleFunc(sh.url, sh.upgrade_nosession) + serveMux.HandleFunc(url, sh.upgrade_nosession) } else { - serveMux.HandleFunc(sh.url, sh.upgrade) + serveMux.HandleFunc(url, sh.upgrade) } + + go sh.mainLoop(ctx) } return nil } -func (sh *subhandler) cloneTag(tag primitive.ObjectID) (out []*Richconn) { - sh.Lock() - cs := sh.tags[tag] - sh.Unlock() +func (sh *subhandler) setConnected(alias string, connected bool) { + sh.connsLock.Lock() + defer sh.connsLock.Unlock() - if cs == nil { - return nil - } - - cs.Lock() - defer cs.Unlock() - - out = make([]*Richconn, 0, len(cs.tagconnsmap)) - for _, c := range cs.tagconnsmap { - out = append(out, c.rc) - } - return -} - -func (sh *subhandler) joinTag(tag primitive.ObjectID, tid primitive.ObjectID, rc *Richconn, hint string) { - sh.Lock() - cs := sh.tags[tag] - if cs == nil { - cs = &tagconns{ - tagconnsmap: make(map[primitive.ObjectID]*tagconn), - } - } - sh.Unlock() - - cs.Lock() - states := make([]bson.M, 0, len(cs.tagconnsmap)) - for tid, conn := range cs.tagconnsmap { - states = append(states, bson.M{ - "_id": tid, - "_hint": hint, - "state": conn.state, - }) - } - - cs.tagconnsmap[tid] = &tagconn{rc: rc} - cs.Unlock() - - sh.Lock() - sh.tags[tag] = cs - sh.Unlock() - - if len(states) > 0 { - s, _ := json.Marshal(states) - rc.WriteBytes(s) - } -} - -func (sh *subhandler) leaveTag(tag primitive.ObjectID, tid primitive.ObjectID) { - sh.Lock() - defer sh.Unlock() - - cs := sh.tags[tag] - if cs == nil { - return - } - - delete(cs.tagconnsmap, tid) - if len(cs.tagconnsmap) == 0 { - delete(sh.tags, tag) + if connected { + sh.connectedAlias[alias] = true } else { - sh.tags[tag] = cs + delete(sh.connectedAlias, alias) } } -func (sh *subhandler) setStateInTag(tag primitive.ObjectID, tid primitive.ObjectID, state string, hint string) { - sh.Lock() - cs := sh.tags[tag] - sh.Unlock() +func (sh *subhandler) connected(alias string) bool { + sh.connsLock.Lock() + defer sh.connsLock.Unlock() - if cs == nil { - return - } + _, ok := sh.connectedAlias[alias] + return ok +} - cs.Lock() - defer cs.Unlock() - - if tagconn := cs.tagconnsmap[tid]; tagconn != nil { - tagconn.state = state - - var clone []*Richconn - for _, c := range cs.tagconnsmap { - clone = append(clone, c.rc) +func (sh *subhandler) mainLoop(ctx context.Context) { + defer func() { + s := recover() + if s != nil { + logger.Error(s) } - raw, _ := json.Marshal(map[string]any{ - "_id": tid, - "_hint": hint, - "state": state, - }) - go func(raw []byte) { - for _, c := range clone { - c.WriteBytes(raw) - } - }(raw) - } -} + }() -func (wsh *WebsocketHandler) GetState(alias primitive.ObjectID) (string, error) { - state, err := wsh.RedisSync.EvalSha(context.Background(), onlineQueryScriptHash, []string{ - connStateCachePrefix + "*", alias.Hex(), - }).Result() - - if err != nil { - return "", err - } - - return state.(string), nil -} - -func (wsh *WebsocketHandler) IsOnline(alias primitive.ObjectID) (bool, error) { - state, err := wsh.GetState(alias) - if err != nil { - logger.Error("IsOnline failed. err :", err) - return false, err - } - return len(state) > 0, nil -} - -func (sh *subhandler) closeConn(accid primitive.ObjectID) { - sh.Lock() - defer sh.Unlock() - - if alias, ok := sh.aliases[accid]; ok { - if old := sh.conns[alias]; old != nil { - old.Close() - } - } -} - -func (sh *subhandler) addConn(conn *Richconn, accid primitive.ObjectID) { - sh.Lock() - defer sh.Unlock() - - sh.conns[conn.alias] = conn - sh.aliases[accid] = conn.alias -} - -func upgrade_core(sh *subhandler, conn *websocket.Conn, initState string, accid primitive.ObjectID, alias primitive.ObjectID) { - sh.closeConn(accid) - - newconn := sh.makeRichConn(alias, conn) - sh.addConn(newconn, accid) - sh.redisSync.HSet(context.Background(), ConnStateCacheKey, alias.Hex(), initState).Result() - - go func(c *Richconn, accid primitive.ObjectID, deliveryChan chan<- DeliveryMessage) { + go func() { + var pubsub *redis.PubSub for { - mt, p, err := c.ReadMessage() + if pubsub == nil { + pubsub = sh.redisSync.Subscribe(ctx, sh.redisMsgChanName, sh.redisCmdChanName) + } + + raw, err := pubsub.ReceiveMessage(ctx) + if err == nil { + if raw.Channel == sh.redisMsgChanName { + var msg UpstreamMessage + if err := json.Unmarshal([]byte(raw.Payload), &msg); err == nil { + sh.deliveryChan <- &msg + } else { + logger.Println("decode UpstreamMessage failed :", err) + } + } else if raw.Channel == sh.redisCmdChanName { + var cmd CommandMessage + if err := json.Unmarshal([]byte(raw.Payload), &cmd); err == nil { + sh.deliveryChan <- &cmd + } else { + logger.Println("decode UpstreamMessage failed :", err) + } + } + } else { + logger.Println("pubsub.ReceiveMessage failed :", err) + pubsub.Close() + pubsub = nil + + if ctx.Err() != nil { + break + } + } + } + }() + + entireConns := make(map[string]*Richconn) + rooms := make(map[string]*room) + findRoom := func(name string, create bool) *room { + room := rooms[name] + if room == nil && create { + room = makeRoom(name) + rooms[name] = room + room.start(ctx) + } + return room + } + + for { + select { + case usermsg := <-sh.deliveryChan: + switch usermsg := usermsg.(type) { + case *UpstreamMessage: + target := usermsg.Target + if target[0] == '#' { + // 룸에 브로드캐스팅 + roomName := target[1:] + if room := findRoom(roomName, false); room != nil { + room.broadcast(usermsg) + } + } else if target[0] == '@' { + // TODO : 특정 유저에게만 + } + + case *CommandMessage: + if usermsg.Cmd == CommandType_JoinChannel && len(usermsg.Args) == 2 { + alias := usermsg.Args[0] + roomName := usermsg.Args[1] + + conn := entireConns[alias] + if conn != nil { + findRoom(roomName, true).in(conn) + } + } else if usermsg.Cmd == CommandType_JoinChannel && len(usermsg.Args) == 2 { + alias := usermsg.Args[0] + roomName := usermsg.Args[1] + + conn := entireConns[alias] + if conn != nil { + if room := findRoom(roomName, false); room != nil { + room.out(conn) + } + } + } + + default: + logger.Println("usermsg is unknown type") + } + + case c := <-sh.connInOutChan: + if c.Conn == nil { + delete(entireConns, c.alias) + sh.setConnected(c.alias, false) + for _, room := range rooms { + room.out(c) + } + c.Closed() + } else { + sh.setConnected(c.alias, true) + entireConns[c.alias] = c + } + } + } +} + +func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID, alias string) { + newconn := sh.makeRichConn(alias, conn) + sh.connInOutChan <- newconn + + go func(c *Richconn, accid primitive.ObjectID, deliveryChan chan<- any) { + for { + messageType, r, err := c.NextReader() + + // 웹소켓에서 직접 메시지를 받지 않는다. + if r != nil { + io.Copy(io.Discard, r) + } if err != nil { c.Close() break } - switch mt { - case websocket.BinaryMessage: - msg := DeliveryMessage{ - Alias: c.alias, - Body: p, - Conn: c, - } - deliveryChan <- msg - - case websocket.TextMessage: - msg := string(p) - opcodes := strings.Split(msg, ";") - for _, opcode := range opcodes { - if strings.HasPrefix(opcode, "ps:") { - sh.redisSync.HSet(context.Background(), ConnStateCacheKey, alias.Hex(), opcode[3:]).Result() - } else if strings.HasPrefix(opcode, "cmd:") { - cmd := opcode[4:] - msg := DeliveryMessage{ - Alias: c.alias, - Command: cmd, - Conn: c, - } - deliveryChan <- msg - } - } - + if messageType == websocket.CloseMessage { + break } } - sh.redisSync.HDel(context.Background(), ConnStateCacheKey, c.alias.Hex()).Result() - - sh.Lock() - delete(sh.conns, c.alias) - delete(sh.aliases, accid) - sh.Unlock() - - var funcs []func() - c.lock.Lock() - for _, f := range c.onClose { - funcs = append(funcs, f) - } - c.onClose = nil - c.lock.Unlock() - - for _, f := range funcs { - f() - } + c.Conn = nil + sh.connInOutChan <- c }(newconn, accid, sh.deliveryChan) } @@ -551,17 +380,14 @@ func (sh *subhandler) upgrade_nosession(w http.ResponseWriter, r *http.Request) return } - alias := accid + var alias string if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 { - alias = common.ParseObjectID(v) + alias = v + } else { + alias = accid.Hex() } - initState := r.Header.Get("As-X-Tavern-InitialState") - if len(initState) == 0 { - initState = "online" - } - - upgrade_core(sh, conn, initState, accid, alias) + upgrade_core(sh, conn, accid, alias) } func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) { @@ -597,20 +423,17 @@ func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) { return } - alias := accid + var alias string if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 { - alias = common.ParseObjectID(v) + alias = v + } else { + alias = accid.Hex() } - initState := r.Header.Get("As-X-Tavern-InitialState") - if len(initState) == 0 { - initState = "online" - } - - upgrade_core(sh, conn, initState, accid, alias) + upgrade_core(sh, conn, accid, alias) } -func (sh *subhandler) makeRichConn(alias primitive.ObjectID, conn *websocket.Conn) *Richconn { +func (sh *subhandler) makeRichConn(alias string, conn *websocket.Conn) *Richconn { rc := Richconn{ Conn: conn, alias: alias,