diff --git a/wshandler/api_handler.go b/wshandler/api_handler.go index 0de56f6..b9a2768 100644 --- a/wshandler/api_handler.go +++ b/wshandler/api_handler.go @@ -7,7 +7,6 @@ import ( "strings" "unsafe" - "github.com/gorilla/websocket" "repositories.action2quare.com/ayo/gocommon/logger" ) @@ -17,7 +16,7 @@ const ( ) type apiFuncType func(ApiCallContext) -type connFuncType func(*websocket.Conn, *Sender) +type connFuncType func(*Conn, *Sender) type disconnFuncType func(string, *Sender) type WebsocketApiHandler struct { @@ -53,7 +52,7 @@ func MakeWebsocketApiHandler[T any](receiver *T, receiverName string) WebsocketA if method.Type.NumIn() != 3 { continue } - if method.Type.In(1) != reflect.TypeOf((*websocket.Conn)(nil)) { + if method.Type.In(1) != reflect.TypeOf((*Conn)(nil)) { continue } if method.Type.In(2) != reflect.TypeOf((*Sender)(nil)) { @@ -62,9 +61,9 @@ func MakeWebsocketApiHandler[T any](receiver *T, receiverName string) WebsocketA funcptr := method.Func.Pointer() p1 := unsafe.Pointer(&funcptr) p2 := unsafe.Pointer(&p1) - connfuncptr := (*func(*T, *websocket.Conn, *Sender))(p2) + connfuncptr := (*func(*T, *Conn, *Sender))(p2) - connfunc = func(c *websocket.Conn, s *Sender) { + connfunc = func(c *Conn, s *Sender) { (*connfuncptr)(receiver, c, s) } } else if method.Name == ClientDisconnected { diff --git a/wshandler/wshandler.go b/wshandler/wshandler.go index 30e9080..b4c3e95 100644 --- a/wshandler/wshandler.go +++ b/wshandler/wshandler.go @@ -10,9 +10,11 @@ import ( "errors" "fmt" "io" + "net" "net/http" "strings" "sync" + "sync/atomic" "time" "go.mongodb.org/mongo-driver/bson/primitive" @@ -27,12 +29,94 @@ import ( var noAuthFlag = flagx.Bool("noauth", false, "") +type Conn struct { + innerConn *websocket.Conn + spinLock int32 +} + +func makeConn(conn *websocket.Conn) *Conn { + return &Conn{ + innerConn: conn, + spinLock: 0, + } +} + type wsconn struct { - *websocket.Conn + *Conn sender *Sender closeMessage string } +type noCopy struct{} +type websocketWriter struct { + _ noCopy + innerConn *websocket.Conn + spinLock *int32 + fingerprint int32 +} + +func (c websocketWriter) writeImpl(vf func() error) error { + defer atomic.StoreInt32(c.spinLock, 0) + + for i := int64(0); ; i++ { + if atomic.CompareAndSwapInt32(c.spinLock, 0, c.fingerprint) { + return vf() + } + + time.Sleep(time.Microsecond) + if i >= int64(time.Second/time.Microsecond) && i&int64(time.Second/time.Microsecond) == 0 { + // 1초동안 락 실패 + logger.Println("websocket write lock failed : ", i/int64(time.Second/time.Microsecond)) + } + } +} + +func (c websocketWriter) WriteJSON(v interface{}) error { + return c.writeImpl(func() error { return c.innerConn.WriteJSON(v) }) +} + +func (c websocketWriter) WriteMessage(messageType int, data []byte) error { + return c.writeImpl(func() error { return c.innerConn.WriteMessage(messageType, data) }) +} + +func (c websocketWriter) WritePreparedMessage(pm *websocket.PreparedMessage) error { + return c.writeImpl(func() error { return c.innerConn.WritePreparedMessage(pm) }) +} + +func (c websocketWriter) WriteControl(messageType int, data []byte, deadline time.Time) error { + return c.writeImpl(func() error { return c.innerConn.WriteControl(messageType, data, deadline) }) +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.innerConn.SetReadDeadline(t) +} + +func (c *Conn) Close() error { + return c.innerConn.Close() +} + +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + return c.innerConn.ReadMessage() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.innerConn.RemoteAddr() +} + +func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + return c.innerConn.NextReader() +} + +var websocketWriterSeq = int32(1) + +func (c *Conn) MakeWriter() websocketWriter { + return websocketWriter{ + innerConn: c.innerConn, + spinLock: &c.spinLock, + fingerprint: atomic.AddInt32(&websocketWriterSeq, 1), + } +} + type UpstreamMessage struct { Alias string Accid primitive.ObjectID @@ -82,7 +166,7 @@ type EventReceiver interface { } type send_msg_queue_elem struct { - to *websocket.Conn + to *Conn pmsg *websocket.PreparedMessage //msg []byte } @@ -148,7 +232,7 @@ func NewWebsocketHandler(consumer session.Consumer, redisUrl string) (*Websocket return } - elem.to.WritePreparedMessage(elem.pmsg) + elem.to.MakeWriter().WritePreparedMessage(elem.pmsg) } for elem := range sendchan { @@ -197,7 +281,7 @@ func (ws *WebsocketHandler) SendUpstreamMessage(msg *UpstreamMessage) { ws.localDeliveryChan <- msg } -func (ws *WebsocketHandler) WriteDirectMessage(c *websocket.Conn, messageType int, data []byte) { +func (ws *WebsocketHandler) WriteDirectMessage(c *Conn, messageType int, data []byte) { pmsg, _ := websocket.NewPreparedMessage(messageType, data) ws.sendMsgChan <- send_msg_queue_elem{ to: c, @@ -479,13 +563,13 @@ func (ws *WebsocketHandler) mainLoop(ctx context.Context) { case accid := <-ws.forceCloseChan: if conn := entireConns[accid.Hex()]; conn != nil { - conn.WriteControl(websocket.CloseMessage, unauthdata, time.Time{}) + conn.MakeWriter().WriteControl(websocket.CloseMessage, unauthdata, time.Time{}) } } } } -func upgrade_core(ws *WebsocketHandler, conn *websocket.Conn, accid primitive.ObjectID, alias string) { +func upgrade_core(ws *WebsocketHandler, conn *Conn, accid primitive.ObjectID, alias string) { newconn := &wsconn{ Conn: conn, sender: &Sender{ @@ -503,7 +587,7 @@ func upgrade_core(ws *WebsocketHandler, conn *websocket.Conn, accid primitive.Ob }() for { - messageType, r, err := c.NextReader() + messageType, r, err := c.innerConn.NextReader() if err != nil { if ce, ok := err.(*websocket.CloseError); ok { c.closeMessage = ce.Text @@ -596,7 +680,7 @@ func (ws *WebsocketHandler) upgrade_nosession(w http.ResponseWriter, r *http.Req alias = accid.Hex() } - upgrade_core(ws, conn, accid, alias) + upgrade_core(ws, makeConn(conn), accid, alias) } func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) { @@ -643,5 +727,5 @@ func (ws *WebsocketHandler) upgrade(w http.ResponseWriter, r *http.Request) { alias = authinfo.Account.Hex() } - upgrade_core(ws, conn, authinfo.Account, alias) + upgrade_core(ws, makeConn(conn), authinfo.Account, alias) } diff --git a/wshandler/wshandler_peer.go b/wshandler/wshandler_peer.go index bbb5f98..937c3ba 100644 --- a/wshandler/wshandler_peer.go +++ b/wshandler/wshandler_peer.go @@ -23,7 +23,7 @@ type WebsocketPeerHandler interface { type peerCtorChannelValue struct { accid primitive.ObjectID - conn *websocket.Conn + conn *Conn } type peerDtorChannelValue struct { @@ -42,7 +42,7 @@ type websocketPeerHandler[T PeerInterface] struct { type PeerInterface interface { ClientDisconnected(string) - ClientConnected(*websocket.Conn) + ClientConnected(*Conn) } type peerApiFuncType[T PeerInterface] func(T, io.Reader) (any, error) @@ -182,7 +182,7 @@ func (ws *websocketPeerHandler[T]) onSessionInvalidated(accid primitive.ObjectID } func (ws *websocketPeerHandler[T]) sessionMonitoring() { - all := make(map[primitive.ObjectID]*websocket.Conn) + all := make(map[primitive.ObjectID]*Conn) unauthdata := []byte{0x03, 0xec} unauthdata = append(unauthdata, []byte("unauthorized")...) for { @@ -191,7 +191,7 @@ func (ws *websocketPeerHandler[T]) sessionMonitoring() { all[estVal.accid] = estVal.conn case disVal := <-ws.peerDtorChannel: if c := all[disVal.accid]; c != nil { - c.WriteControl(websocket.CloseMessage, unauthdata, time.Time{}) + c.MakeWriter().WriteControl(websocket.CloseMessage, unauthdata, time.Time{}) delete(all, disVal.accid) } @@ -203,8 +203,8 @@ func (ws *websocketPeerHandler[T]) sessionMonitoring() { } } -func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid primitive.ObjectID, sk string) { - go func(c *websocket.Conn, accid primitive.ObjectID, sk string) { +func (ws *websocketPeerHandler[T]) upgrade_core(conn *Conn, accid primitive.ObjectID, sk string) { + go func(c *Conn, accid primitive.ObjectID, sk string) { peer := ws.createPeer(accid) var closeReason string @@ -217,6 +217,7 @@ func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid prim }() response := make([]byte, 255) + writer := c.MakeWriter() for { response = response[:5] messageType, r, err := c.NextReader() @@ -277,7 +278,7 @@ func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid prim if err != nil { logger.Println("websocket.NewPreparedMessage failed :", err) } else { - c.WritePreparedMessage(pmsg) + writer.WritePreparedMessage(pmsg) } } else { cmd := make([]byte, flag[0]) @@ -346,7 +347,7 @@ func (ws *websocketPeerHandler[T]) upgrade_noauth(w http.ResponseWriter, r *http // alias = accid.Hex() // } - ws.upgrade_core(conn, accid, sk) + ws.upgrade_core(&Conn{innerConn: conn}, accid, sk) } func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Request) { @@ -387,5 +388,5 @@ func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Reques // } else { // alias = authinfo.Account.Hex() // } - ws.upgrade_core(conn, authinfo.Account, sk) + ws.upgrade_core(makeConn(conn), authinfo.Account, sk) }