package wshandler

import (
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"net/http"
	"reflect"
	"strings"
	"time"

	"github.com/gorilla/websocket"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"repositories.action2quare.com/ayo/gocommon/logger"
	"repositories.action2quare.com/ayo/gocommon/session"
)

// WebsocketPeerHandler exposes websocket upgrade endpoints on an HTTP mux.
type WebsocketPeerHandler interface {
	RegisterHandlers(serveMux *http.ServeMux, prefix string) error
}

// connEstChannelValue announces a newly established connection to the
// session monitor.
type connEstChannelValue struct {
	accid primitive.ObjectID
	conn  *websocket.Conn
}

// connDisChannelValue tells the session monitor that a connection has ended
// (closed == true) or that it must be terminated (closed == false).
type connDisChannelValue struct {
	accid primitive.ObjectID
	// conn identifies which connection ended. It is nil when the sender does
	// not know the connection (e.g. session invalidation).
	conn   *websocket.Conn
	closed bool
}

// websocketPeerHandler upgrades HTTP requests to websocket connections and
// dispatches incoming binary messages to methods of a per-connection peer.
type websocketPeerHandler[T PeerInterface] struct {
	methods         map[string]peerApiFuncType[T]
	createPeer      func(primitive.ObjectID) T
	sessionConsumer session.Consumer
	connEstChannel  chan connEstChannelValue
	connDisChannel  chan connDisChannelValue
}

// PeerInterface is the set of lifecycle callbacks every peer type must
// implement.
type PeerInterface interface {
	ClientDisconnected(string)
	ClientConnected(*websocket.Conn)
}

// peerApiFuncType invokes one peer method, decoding its arguments from the
// reader as a JSON array.
type peerApiFuncType[T PeerInterface] func(T, io.Reader) (any, error)

// websocketPeerApiHandler is the reflection-built method table for a peer
// type.
type websocketPeerApiHandler[T PeerInterface] struct {
	methods              map[string]peerApiFuncType[T]
	originalReceiverName string
}

// call looks up funcname in the method table and invokes it on recv with the
// arguments read from r. A panic raised while invoking the method is logged
// and converted into the returned error.
func (ws *websocketPeerHandler[T]) call(recv T, funcname string, r io.Reader) (v any, e error) {
	defer func() {
		if rec := recover(); rec != nil {
			logger.Error(rec)
			e = fmt.Errorf("%v", rec)
		}
	}()

	if found := ws.methods[funcname]; found != nil {
		return found(recv, r)
	}

	return nil, fmt.Errorf("api is not found : %s", funcname)
}

// makeWebsocketPeerApiHandler builds the method table for T by reflection.
// Each method of T becomes an entry keyed by the method name; the entry
// decodes a JSON array into the method's parameters, calls the method and
// converts its results into (value, error).
func makeWebsocketPeerApiHandler[T PeerInterface]() websocketPeerApiHandler[T] {
	methods := make(map[string]peerApiFuncType[T])

	var archetype T
	tp := reflect.TypeOf(archetype)
	errType := reflect.TypeOf((*error)(nil)).Elem()

	for i := 0; i < tp.NumMethod(); i++ {
		method := tp.Method(i)
		if method.Type.In(0) != tp {
			continue
		}

		// Lifecycle callbacks are driven by the connection itself and must not
		// be callable by remote clients. ClientDisconnected is a name declared
		// elsewhere in this package (presumably the method-name constant).
		if method.Name == ClientDisconnected || method.Name == "ClientConnected" {
			continue
		}

		// Parameter types, excluding the receiver at index 0.
		intypes := make([]reflect.Type, 0, method.Type.NumIn()-1)
		for j := 1; j < method.Type.NumIn(); j++ {
			intypes = append(intypes, method.Type.In(j))
		}

		var outconv func([]reflect.Value) (any, error)
		numOut := method.Type.NumOut()
		switch {
		case numOut == 0:
			outconv = func([]reflect.Value) (any, error) { return nil, nil }
		case numOut == 1 && method.Type.Out(0).Implements(errType):
			outconv = func(out []reflect.Value) (any, error) {
				err, _ := out[0].Interface().(error)
				return nil, err
			}
		case numOut == 1:
			outconv = func(out []reflect.Value) (any, error) {
				return out[0].Interface(), nil
			}
		default:
			// Two or more results: the first is the payload and the last is
			// the error when it implements error. Any other signature used to
			// leave the converter nil, which panicked after the method ran.
			lastIsErr := method.Type.Out(numOut - 1).Implements(errType)
			outconv = func(out []reflect.Value) (any, error) {
				var err error
				if lastIsErr {
					err, _ = out[len(out)-1].Interface().(error)
				}
				return out[0].Interface(), err
			}
		}

		methods[method.Name] = func(recv T, r io.Reader) (any, error) {
			// Pre-populate the slice with pointers to zero values so the JSON
			// decoder fills each element with the matching parameter type.
			inargs := make([]any, len(intypes))
			for k, intype := range intypes {
				inargs[k] = reflect.New(intype).Interface()
			}

			if err := json.NewDecoder(r).Decode(&inargs); err != nil {
				return nil, err
			}
			// The decoder resizes the slice to the length of the JSON array.
			if len(inargs) != len(intypes) {
				return nil, fmt.Errorf("%s expects %d arguments, got %d", method.Name, len(intypes), len(inargs))
			}

			reflectargs := make([]reflect.Value, 0, len(inargs)+1)
			reflectargs = append(reflectargs, reflect.ValueOf(recv))
			for _, p := range inargs {
				reflectargs = append(reflectargs, reflect.ValueOf(p).Elem())
			}

			return outconv(method.Func.Call(reflectargs))
		}
	}

	// T is normally a pointer type; fall back to the type's own name otherwise
	// instead of panicking in Elem.
	receiverName := tp.Name()
	if tp.Kind() == reflect.Ptr {
		receiverName = tp.Elem().Name()
	}

	return websocketPeerApiHandler[T]{
		methods:              methods,
		originalReceiverName: receiverName,
	}
}

// NewWebsocketPeerHandler creates a handler whose API is the method set of T.
// creator is called once per accepted connection to build the peer, and
// consumer is used to resolve session keys and to receive session
// invalidation notifications.
func NewWebsocketPeerHandler[T PeerInterface](consumer session.Consumer, creator func(primitive.ObjectID) T) WebsocketPeerHandler {
	receiver := makeWebsocketPeerApiHandler[T]()

	methods := make(map[string]peerApiFuncType[T], len(receiver.methods))
	for k, v := range receiver.methods {
		logger.Printf("ws api registered : %s.%s\n", receiver.originalReceiverName, k)
		methods[k] = v
	}

	wsh := &websocketPeerHandler[T]{
		sessionConsumer: consumer,
		methods:         methods,
		createPeer:      creator,
		connEstChannel:  make(chan connEstChannelValue),
		connDisChannel:  make(chan connDisChannelValue),
	}

	consumer.RegisterOnSessionInvalidated(wsh.onSessionInvalidated)
	return wsh
}

// onSessionInvalidated asks the session monitor to close the connection that
// belongs to accid.
func (ws *websocketPeerHandler[T]) onSessionInvalidated(accid primitive.ObjectID) {
	ws.connDisChannel <- connDisChannelValue{
		accid:  accid,
		closed: false,
	}
}

// RegisterHandlers starts the session monitor and registers the upgrade
// endpoint under prefix. The session-less variant is used when noAuthFlag
// (declared elsewhere in this package) is set.
func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, prefix string) error {
	go ws.sessionMonitoring()

	if *noAuthFlag {
		serveMux.HandleFunc(prefix, ws.upgrade_nosession)
	} else {
		serveMux.HandleFunc(prefix, ws.upgrade)
	}

	return nil
}

// sessionMonitoring owns the account -> connection map. It records established
// connections, forgets connections that have ended, and closes the connection
// of an account whose session was invalidated. It runs forever.
func (ws *websocketPeerHandler[T]) sessionMonitoring() {
	all := make(map[primitive.ObjectID]*websocket.Conn)
	for {
		select {
		case estVal := <-ws.connEstChannel:
			all[estVal.accid] = estVal.conn

		case disVal := <-ws.connDisChannel:
			cur, ok := all[disVal.accid]
			if !ok {
				continue
			}

			if disVal.closed {
				// Only forget the entry when it still refers to the connection
				// that ended; a stale notification from an older connection
				// must not remove a newer one registered for the same account.
				if disVal.conn == nil || disVal.conn == cur {
					delete(all, disVal.accid)
				}
				continue
			}

			if cur != nil {
				cur.Close()
			}
			delete(all, disVal.accid)
		}
	}
}

// readPeerCommandName reads exactly size bytes from r and returns them as the
// API name.
func readPeerCommandName(r io.Reader, size byte) (string, error) {
	name := make([]byte, size)
	if _, err := io.ReadFull(r, name); err != nil {
		return "", err
	}
	return string(name), nil
}

// appendPeerResult appends the wire representation of result to dst: strings
// and byte slices verbatim, sized integers as decimal text, floats with %f,
// everything else as JSON. On error dst is returned unchanged.
func appendPeerResult(dst []byte, result any) ([]byte, error) {
	switch result := result.(type) {
	case string:
		return append(dst, result...), nil
	case int8, int16, int32, int64, uint8, uint16, uint32, uint64:
		return append(dst, fmt.Sprintf("%d", result)...), nil
	case float32, float64:
		return append(dst, fmt.Sprintf("%f", result)...), nil
	case []byte:
		return append(dst, result...), nil
	default:
		j, err := json.Marshal(result)
		if err != nil {
			return dst, err
		}
		return append(dst, j...), nil
	}
}

// upgrade_core starts the goroutine that serves one websocket connection. It
// creates the peer for accid, notifies it of the connection, and then reads
// binary messages until the connection ends.
//
// Message layout, as parsed below:
//
//	0xff | 4-byte nonce | 1-byte name length | name | JSON args  -> reply sent
//	1-byte name length | name | JSON args                        -> no reply
//
// A reply is one status byte (6 = ack, 21 = negative ack), the echoed nonce,
// and the result or error text. nonce is currently unused.
func (ws *websocketPeerHandler[T]) upgrade_core(conn *websocket.Conn, accid primitive.ObjectID, nonce uint32) {
	go func(c *websocket.Conn, accid primitive.ObjectID) {
		peer := ws.createPeer(accid)
		var closeReason string

		peer.ClientConnected(c)
		ws.connEstChannel <- connEstChannelValue{accid: accid, conn: c}

		defer func() {
			ws.connDisChannel <- connDisChannelValue{accid: accid, conn: c, closed: true}
			peer.ClientDisconnected(closeReason)
		}()

		response := make([]byte, 255)
		for {
			// Byte 0 is the status and bytes 1..4 the echoed nonce.
			response = response[:5]

			messageType, r, err := c.NextReader()
			if err != nil {
				if ce, ok := err.(*websocket.CloseError); ok {
					closeReason = ce.Text
				}
				c.Close()
				break
			}

			if messageType == websocket.CloseMessage {
				closeMsg, _ := io.ReadAll(r)
				closeReason = string(closeMsg)
				break
			}

			if messageType != websocket.BinaryMessage {
				continue
			}

			// A message reader may return fewer bytes than requested, so every
			// fixed-size field is read with io.ReadFull.
			var flag [1]byte
			if _, err := io.ReadFull(r, flag[:]); err != nil {
				continue // empty message
			}

			if flag[0] != 0xff {
				// Fire-and-forget call: the flag byte is the name length.
				cmd, err := readPeerCommandName(r, flag[0])
				if err != nil {
					logger.Error("ws malformed message :", err)
					continue
				}
				// No reply channel exists for this message form, so the result
				// and error are intentionally dropped.
				_, _ = ws.call(peer, cmd, r)
				continue
			}

			// Request expecting a reply: echo the nonce back.
			if _, err := io.ReadFull(r, response[1:5]); err != nil {
				logger.Error("ws malformed message :", err)
				continue
			}

			var size [1]byte
			var cmd string
			var result any
			_, err = io.ReadFull(r, size[:])
			if err == nil {
				cmd, err = readPeerCommandName(r, size[0])
			}
			if err == nil {
				result, err = ws.call(peer, cmd, r)
			}
			if err == nil {
				response[0] = 6 // 6 : Acknowledgement
				response, err = appendPeerResult(response, result)
			}
			if err != nil {
				response[0] = 21 // 21 : Negative Ack
				response = append(response[:5], err.Error()...)
			}

			if err := c.WriteMessage(websocket.BinaryMessage, response); err != nil {
				logger.Error("ws write failed :", err)
			}
		}
	}(conn, accid)
}

// serveUpgraded upgrades the request to a websocket connection and starts
// serving it for accid.
func (ws *websocketPeerHandler[T]) serveUpgraded(w http.ResponseWriter, r *http.Request, accid primitive.ObjectID) {
	var upgrader = websocket.Upgrader{} // use default options
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error, so
		// writing another header here would be superfluous.
		logger.Error("websocket upgrade failed :", err)
		return
	}

	nonce := rand.New(rand.NewSource(time.Now().UnixNano())).Uint32()
	ws.upgrade_core(conn, accid, nonce)
}

// upgrade_nosession accepts a client without consulting the session store.
// The account id is taken from the second token of the Authorization header,
// which must be the hex encoding of an ObjectID.
func (ws *websocketPeerHandler[T]) upgrade_nosession(w http.ResponseWriter, r *http.Request) {
	// client connecting
	defer func() {
		if s := recover(); s != nil {
			logger.Error(s)
		}
		io.Copy(io.Discard, r.Body)
		r.Body.Close()
	}()

	auth := strings.Split(r.Header.Get("Authorization"), " ")
	if len(auth) != 2 {
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	temp, err := hex.DecodeString(auth[1])
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	if len(temp) != len(primitive.NilObjectID) {
		w.WriteHeader(http.StatusBadRequest)
		return
	}

	var accid primitive.ObjectID
	copy(accid[:], temp)

	// var alias string
	// if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 {
	// 	vt, _ := base64.StdEncoding.DecodeString(v)
	// 	alias = string(vt)
	// } else {
	// 	alias = accid.Hex()
	// }

	ws.serveUpgraded(w, r, accid)
}

// upgrade accepts a client whose session key is carried in the AS-X-SESSION
// header. The key is resolved through the session consumer and the resulting
// account id is used for the connection.
func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Request) {
	// client connecting
	defer func() {
		if s := recover(); s != nil {
			logger.Error(s)
		}
		io.Copy(io.Discard, r.Body)
		r.Body.Close()
	}()

	sk := r.Header.Get("AS-X-SESSION")

	// NOTE(review): this writes the raw session key to the log — confirm that
	// is acceptable for the deployment.
	logger.Println("WebsocketHandler.upgrade sk :", sk)

	// NOTE(review): what Query returns for an unknown or expired key is not
	// visible here — confirm it reports that case as an error.
	authinfo, err := ws.sessionConsumer.Query(sk)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		logger.Error("authorize query failed :", err)
		return
	}

	// var alias string
	// if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 {
	// 	vt, _ := base64.StdEncoding.DecodeString(v)
	// 	alias = string(vt)
	// } else {
	// 	alias = authinfo.Account.Hex()
	// }

	ws.serveUpgraded(w, r, authinfo.Account)
}