diff --git a/wshandler/room.go b/wshandler/room.go index f74fb45..8a1c046 100644 --- a/wshandler/room.go +++ b/wshandler/room.go @@ -62,10 +62,10 @@ func (r *room) loop(ctx context.Context, conns *map[string]*wsconn) (normalEnd b return true case conn := <-r.inChan: - (*conns)[conn.alias] = conn + (*conns)[conn.sender.Accid.Hex()] = conn case conn := <-r.outChan: - delete((*conns), conn.alias) + delete((*conns), conn.sender.Accid.Hex()) case msg := <-r.messageChan: for _, conn := range *conns { diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index 9bb7acb..d1bfc95 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -24,8 +24,7 @@ var noSessionFlag = flagx.Bool("nosession", false, "nosession=[true|false]") type wsconn struct { *websocket.Conn - alias string - accid primitive.ObjectID + sender *Sender } type UpstreamMessage struct { @@ -65,7 +64,13 @@ const ( Disconnected = WebSocketMessageType(101) ) -type WebSocketMessageReceiver func(accid primitive.ObjectID, alias string, messageType WebSocketMessageType, body io.Reader) +type Sender struct { + Region string + Accid primitive.ObjectID + Alias string +} + +type WebSocketMessageReceiver func(sender *Sender, messageType WebSocketMessageType, body io.Reader) type subhandler struct { authCache *gocommon.AuthCollection @@ -77,6 +82,7 @@ type subhandler struct { localDeliveryChan chan any callReceiver WebSocketMessageReceiver connWaitGroup sync.WaitGroup + region string } // WebsocketHandler : @@ -124,6 +130,7 @@ func NewWebsocketHandler(authglobal gocommon.AuthCollectionGlobal) (wsh *Websock connInOutChan: make(chan *wsconn), deliveryChan: make(chan any, 1000), localDeliveryChan: make(chan any, 100), + region: region, } authCaches[region] = sh @@ -144,13 +151,13 @@ func (ws *WebsocketHandler) Start(ctx context.Context) { for region, sh := range ws.authCaches { chain := ws.receiverChain[region] if len(chain) == 0 { - sh.callReceiver = func(accid primitive.ObjectID, alias string, messageType WebSocketMessageType, body io.Reader) {} + sh.callReceiver = func(sender *Sender, messageType WebSocketMessageType, body io.Reader) {} } else if len(chain) == 1 { sh.callReceiver = chain[0] } else { - sh.callReceiver = func(accid primitive.ObjectID, alias string, messageType WebSocketMessageType, body io.Reader) { + sh.callReceiver = func(sender *Sender, messageType WebSocketMessageType, body io.Reader) { for _, r := range chain { - r(accid, alias, messageType, body) + r(sender, messageType, body) } } } @@ -281,7 +288,8 @@ func (sh *subhandler) mainLoop(ctx context.Context) { case *UpstreamMessage: target := usermsg.Target if target[0] == '@' { - conn := entireConns[target[1:]] + accid := target[1:] + conn := entireConns[accid] if conn != nil { // 이 경우 아니면 publish 해야 함 conn.WriteMessage(websocket.TextMessage, usermsg.Body) @@ -294,19 +302,19 @@ func (sh *subhandler) mainLoop(ctx context.Context) { case *CommandMessage: if usermsg.Cmd == CommandType_JoinRoom && len(usermsg.Args) == 2 { - alias := usermsg.Args[0].(string) + accid := usermsg.Args[0].(string) roomName := usermsg.Args[1].(string) - conn := entireConns[alias] + conn := entireConns[accid] if conn != nil { findRoom(roomName, true).in(conn) break } } else if usermsg.Cmd == CommandType_LeaveRoom && len(usermsg.Args) == 2 { - alias := usermsg.Args[0].(string) + accid := usermsg.Args[0].(string) roomName := usermsg.Args[1].(string) - conn := entireConns[alias] + conn := entireConns[accid] if conn != nil { if room := findRoom(roomName, false); room != nil { room.out(conn) @@ -314,10 +322,11 @@ func (sh *subhandler) mainLoop(ctx context.Context) { } } } else if usermsg.Cmd == CommandType_WriteControl && len(usermsg.Args) == 2 { - alias := usermsg.Args[0].(string) - conn := entireConns[alias] + accid := usermsg.Args[0].(string) + conn := entireConns[accid] if conn != nil { conn.WriteControl(usermsg.Args[1].(int), usermsg.Args[2].([]byte), time.Time{}) + break } } @@ -338,7 +347,8 @@ func (sh *subhandler) mainLoop(ctx context.Context) { room.broadcast(usermsg) } } else if target[0] == '@' { - conn := entireConns[target[1:]] + accid := target[1:] + conn := entireConns[accid] if conn != nil { conn.WriteMessage(websocket.TextMessage, usermsg.Body) } @@ -346,18 +356,18 @@ func (sh *subhandler) mainLoop(ctx context.Context) { case *CommandMessage: if usermsg.Cmd == CommandType_JoinRoom && len(usermsg.Args) == 2 { - alias := usermsg.Args[0].(string) + accid := usermsg.Args[0].(string) roomName := usermsg.Args[1].(string) - conn := entireConns[alias] + conn := entireConns[accid] if conn != nil { findRoom(roomName, true).in(conn) } } else if usermsg.Cmd == CommandType_LeaveRoom && len(usermsg.Args) == 2 { - alias := usermsg.Args[0].(string) + accid := usermsg.Args[0].(string) roomName := usermsg.Args[1].(string) - conn := entireConns[alias] + conn := entireConns[accid] if conn != nil { if room := findRoom(roomName, false); room != nil { room.out(conn) @@ -371,13 +381,14 @@ func (sh *subhandler) mainLoop(ctx context.Context) { case c := <-sh.connInOutChan: if c.Conn == nil { - delete(entireConns, c.alias) + delete(entireConns, c.sender.Accid.Hex()) for _, room := range rooms { room.out(c) } - sh.callReceiver(c.accid, c.alias, Connected, nil) + sh.callReceiver(c.sender, Disconnected, nil) } else { - entireConns[c.alias] = c + entireConns[c.sender.Accid.Hex()] = c + sh.callReceiver(c.sender, Connected, nil) } } } @@ -385,9 +396,12 @@ func (sh *subhandler) mainLoop(ctx context.Context) { func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID, alias string) { newconn := &wsconn{ - Conn: conn, - alias: alias, - accid: accid, + Conn: conn, + sender: &Sender{ + Region: sh.region, + Alias: alias, + Accid: accid, + }, } sh.connInOutChan <- newconn @@ -402,15 +416,15 @@ func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID } if messageType == websocket.CloseMessage { - sh.callReceiver(accid, c.alias, CloseMessage, r) + sh.callReceiver(c.sender, CloseMessage, r) break } if messageType == websocket.TextMessage { // 유저가 직접 보낸 메시지 - sh.callReceiver(accid, c.alias, TextMessage, r) + sh.callReceiver(c.sender, TextMessage, r) } else if messageType == websocket.BinaryMessage { - sh.callReceiver(accid, c.alias, BinaryMessage, r) + sh.callReceiver(c.sender, BinaryMessage, r) } } sh.redisSync.Del(context.Background(), accid.Hex())