close func를 밖으로 옮김

This commit is contained in:
2023-07-05 22:26:57 +09:00
parent edd2f7aab5
commit 822681bf74
2 changed files with 53 additions and 85 deletions

View File

@ -8,16 +8,16 @@ import (
) )
type room struct { type room struct {
inChan chan *Richconn inChan chan *wsconn
outChan chan *Richconn outChan chan *wsconn
messageChan chan *UpstreamMessage messageChan chan *UpstreamMessage
name string name string
} }
func makeRoom(name string) *room { func makeRoom(name string) *room {
return &room{ return &room{
inChan: make(chan *Richconn, 10), inChan: make(chan *wsconn, 10),
outChan: make(chan *Richconn, 10), outChan: make(chan *wsconn, 10),
messageChan: make(chan *UpstreamMessage, 100), messageChan: make(chan *UpstreamMessage, 100),
name: name, name: name,
} }
@ -27,17 +27,17 @@ func (r *room) broadcast(msg *UpstreamMessage) {
r.messageChan <- msg r.messageChan <- msg
} }
func (r *room) in(conn *Richconn) { func (r *room) in(conn *wsconn) {
r.inChan <- conn r.inChan <- conn
} }
func (r *room) out(conn *Richconn) { func (r *room) out(conn *wsconn) {
r.outChan <- conn r.outChan <- conn
} }
func (r *room) start(ctx context.Context) { func (r *room) start(ctx context.Context) {
go func(ctx context.Context) { go func(ctx context.Context) {
conns := make(map[string]*Richconn) conns := make(map[string]*wsconn)
normal := false normal := false
for !normal { for !normal {
normal = r.loop(ctx, &conns) normal = r.loop(ctx, &conns)
@ -45,7 +45,7 @@ func (r *room) start(ctx context.Context) {
}(ctx) }(ctx)
} }
func (r *room) loop(ctx context.Context, conns *map[string]*Richconn) (normalEnd bool) { func (r *room) loop(ctx context.Context, conns *map[string]*wsconn) (normalEnd bool) {
defer func() { defer func() {
s := recover() s := recover()
if s != nil { if s != nil {

View File

@ -21,11 +21,10 @@ import (
var noSessionFlag = flagx.Bool("nosession", false, "nosession=[true|false]") var noSessionFlag = flagx.Bool("nosession", false, "nosession=[true|false]")
type Richconn struct { type wsconn struct {
*websocket.Conn *websocket.Conn
closeFuncLock sync.Mutex
alias string alias string
onClose map[string]func() accid primitive.ObjectID
} }
type UpstreamMessage struct { type UpstreamMessage struct {
@ -43,8 +42,8 @@ type DownstreamMessage struct {
type CommandType string type CommandType string
const ( const (
CommandType_JoinChannel = CommandType("join_channel") CommandType_JoinRoom = CommandType("join_room")
CommandType_LeaveChannel = CommandType("leave_channel") CommandType_LeaveRoom = CommandType("leave_room")
) )
type CommandMessage struct { type CommandMessage struct {
@ -52,53 +51,6 @@ type CommandMessage struct {
Args []string Args []string
} }
func (rc *Richconn) RegistOnCloseFunc(name string, f func()) {
rc.closeFuncLock.Lock()
defer rc.closeFuncLock.Unlock()
if rc.onClose == nil {
f()
return
}
rc.onClose[name] = f
}
func (rc *Richconn) HasOnCloseFunc(name string) bool {
rc.closeFuncLock.Lock()
defer rc.closeFuncLock.Unlock()
if rc.onClose == nil {
return false
}
_, ok := rc.onClose[name]
return ok
}
func (rc *Richconn) UnregistOnCloseFunc(name string) (out func()) {
rc.closeFuncLock.Lock()
defer rc.closeFuncLock.Unlock()
if rc.onClose == nil {
return
}
out = rc.onClose[name]
delete(rc.onClose, name)
return
}
func (rc *Richconn) Closed() {
rc.closeFuncLock.Lock()
defer rc.closeFuncLock.Unlock()
for _, f := range rc.onClose {
f()
}
}
func (rc *Richconn) WriteBytes(data []byte) error {
return rc.WriteMessage(websocket.TextMessage, data)
}
type subhandler struct { type subhandler struct {
authCache *gocommon.AuthCollection authCache *gocommon.AuthCollection
redisMsgChanName string redisMsgChanName string
@ -106,8 +58,10 @@ type subhandler struct {
redisSync *redis.Client redisSync *redis.Client
connsLock sync.Mutex connsLock sync.Mutex
connectedAlias map[string]bool connectedAlias map[string]bool
connInOutChan chan *Richconn connInOutChan chan *wsconn
deliveryChan chan any deliveryChan chan any
callReceiver func(primitive.ObjectID, string, io.Reader)
} }
// WebsocketHandler : // WebsocketHandler :
@ -120,7 +74,7 @@ type wsConfig struct {
SyncPipeline string `json:"ws_sync_pipeline"` SyncPipeline string `json:"ws_sync_pipeline"`
} }
func NewWebsocketHandler(authglobal gocommon.AuthCollectionGlobal) (wsh *WebsocketHandler) { func NewWebsocketHandler[T any](authglobal gocommon.AuthCollectionGlobal, receiver func(primitive.ObjectID, string, *T)) (wsh *WebsocketHandler) {
var config wsConfig var config wsConfig
gocommon.LoadConfig(&config) gocommon.LoadConfig(&config)
@ -129,6 +83,21 @@ func NewWebsocketHandler(authglobal gocommon.AuthCollectionGlobal) (wsh *Websock
panic(err) panic(err)
} }
decoder := func(r io.Reader) *T {
if r == nil {
// 접속이 끊겼을 때.
return nil
}
var m T
dec := json.NewDecoder(r)
if err := dec.Decode(&m); err != nil {
logger.Println(err)
}
// decoding 실패하더라도 빈 *T를 내보냄
return &m
}
authCaches := make(map[string]*subhandler) authCaches := make(map[string]*subhandler)
for _, region := range authglobal.Regions() { for _, region := range authglobal.Regions() {
sh := &subhandler{ sh := &subhandler{
@ -137,8 +106,11 @@ func NewWebsocketHandler(authglobal gocommon.AuthCollectionGlobal) (wsh *Websock
redisCmdChanName: fmt.Sprintf("_wsh_cmd_%s", region), redisCmdChanName: fmt.Sprintf("_wsh_cmd_%s", region),
redisSync: redisSync, redisSync: redisSync,
connectedAlias: make(map[string]bool), connectedAlias: make(map[string]bool),
connInOutChan: make(chan *Richconn), connInOutChan: make(chan *wsconn),
deliveryChan: make(chan any, 1000), deliveryChan: make(chan any, 1000),
callReceiver: func(accid primitive.ObjectID, alias string, r io.Reader) {
receiver(accid, alias, decoder(r))
},
} }
authCaches[region] = sh authCaches[region] = sh
@ -205,6 +177,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) {
} }
}() }()
// redis channel에서 유저가 보낸 메시지를 읽는 go rountine
go func() { go func() {
var pubsub *redis.PubSub var pubsub *redis.PubSub
for { for {
@ -241,7 +214,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) {
} }
}() }()
entireConns := make(map[string]*Richconn) entireConns := make(map[string]*wsconn)
rooms := make(map[string]*room) rooms := make(map[string]*room)
findRoom := func(name string, create bool) *room { findRoom := func(name string, create bool) *room {
room := rooms[name] room := rooms[name]
@ -253,6 +226,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) {
return room return room
} }
// 유저에게서 온 메세지, 소켓 연결/해체 처리
for { for {
select { select {
case usermsg := <-sh.deliveryChan: case usermsg := <-sh.deliveryChan:
@ -270,7 +244,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) {
} }
case *CommandMessage: case *CommandMessage:
if usermsg.Cmd == CommandType_JoinChannel && len(usermsg.Args) == 2 { if usermsg.Cmd == CommandType_JoinRoom && len(usermsg.Args) == 2 {
alias := usermsg.Args[0] alias := usermsg.Args[0]
roomName := usermsg.Args[1] roomName := usermsg.Args[1]
@ -278,7 +252,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) {
if conn != nil { if conn != nil {
findRoom(roomName, true).in(conn) findRoom(roomName, true).in(conn)
} }
} else if usermsg.Cmd == CommandType_JoinChannel && len(usermsg.Args) == 2 { } else if usermsg.Cmd == CommandType_JoinRoom && len(usermsg.Args) == 2 {
alias := usermsg.Args[0] alias := usermsg.Args[0]
roomName := usermsg.Args[1] roomName := usermsg.Args[1]
@ -301,7 +275,7 @@ func (sh *subhandler) mainLoop(ctx context.Context) {
for _, room := range rooms { for _, room := range rooms {
room.out(c) room.out(c)
} }
c.Closed() sh.callReceiver(c.accid, c.alias, nil)
} else { } else {
sh.setConnected(c.alias, true) sh.setConnected(c.alias, true)
entireConns[c.alias] = c entireConns[c.alias] = c
@ -311,18 +285,16 @@ func (sh *subhandler) mainLoop(ctx context.Context) {
} }
func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID, alias string) { func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID, alias string) {
newconn := sh.makeRichConn(alias, conn) newconn := &wsconn{
Conn: conn,
alias: alias,
accid: accid,
}
sh.connInOutChan <- newconn sh.connInOutChan <- newconn
go func(c *Richconn, accid primitive.ObjectID, deliveryChan chan<- any) { go func(c *wsconn, accid primitive.ObjectID, deliveryChan chan<- any) {
for { for {
messageType, r, err := c.NextReader() messageType, r, err := c.NextReader()
// 웹소켓에서 직접 메시지를 받지 않는다.
if r != nil {
io.Copy(io.Discard, r)
}
if err != nil { if err != nil {
c.Close() c.Close()
break break
@ -331,6 +303,11 @@ func upgrade_core(sh *subhandler, conn *websocket.Conn, accid primitive.ObjectID
if messageType == websocket.CloseMessage { if messageType == websocket.CloseMessage {
break break
} }
if messageType == websocket.TextMessage {
// 유저가 직접 보낸 메시지
sh.callReceiver(accid, c.alias, r)
}
} }
c.Conn = nil c.Conn = nil
@ -432,12 +409,3 @@ func (sh *subhandler) upgrade(w http.ResponseWriter, r *http.Request) {
upgrade_core(sh, conn, accid, alias) upgrade_core(sh, conn, accid, alias)
} }
func (sh *subhandler) makeRichConn(alias string, conn *websocket.Conn) *Richconn {
rc := Richconn{
Conn: conn,
alias: alias,
onClose: make(map[string]func()),
}
return &rc
}