393 lines
9.5 KiB
Go
393 lines
9.5 KiB
Go
package wshandler
|
|
|
|
import (
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
"repositories.action2quare.com/ayo/gocommon/logger"
|
|
"repositories.action2quare.com/ayo/gocommon/session"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// WebsocketPeerHandler is the handle returned by NewWebsocketPeerHandler.
// RegisterHandlers installs the websocket upgrade endpoint on serveMux under
// the given prefix.
type WebsocketPeerHandler interface {
	RegisterHandlers(serveMux *http.ServeMux, prefix string) error
}
|
|
|
|
type peerCtorChannelValue struct {
|
|
accid primitive.ObjectID
|
|
conn *Conn
|
|
}
|
|
|
|
type peerDtorChannelValue struct {
|
|
accid primitive.ObjectID
|
|
sk string
|
|
}
|
|
|
|
type websocketPeerHandler[T PeerInterface] struct {
|
|
methods map[string]peerApiFuncType[T]
|
|
createPeer func(primitive.ObjectID) T
|
|
sessionConsumer session.Consumer
|
|
|
|
peerCtorChannel chan peerCtorChannelValue
|
|
peerDtorChannel chan peerDtorChannelValue
|
|
}
|
|
|
|
type PeerInterface interface {
|
|
ClientDisconnected(string)
|
|
ClientConnected(*Conn)
|
|
}
|
|
// peerApiFuncType is the invoker stored for each API name. It runs one API on
// the given peer, reading the call's arguments from args, and returns the
// API's result together with its error.
type peerApiFuncType[T PeerInterface] func(peer T, args io.Reader) (result any, err error)
|
|
|
|
type websocketPeerApiHandler[T PeerInterface] struct {
|
|
methods map[string]peerApiFuncType[T]
|
|
originalReceiverName string
|
|
}
|
|
|
|
func (hc *websocketPeerHandler[T]) call(recv T, funcname string, r io.Reader) (v any, e error) {
|
|
defer func() {
|
|
r := recover()
|
|
if r != nil {
|
|
logger.Error(r)
|
|
e = fmt.Errorf("%v", r)
|
|
}
|
|
}()
|
|
|
|
if found := hc.methods[funcname]; found != nil {
|
|
return found(recv, r)
|
|
}
|
|
|
|
return nil, fmt.Errorf("api is not found : %s", funcname)
|
|
}
|
|
|
|
func makeWebsocketPeerApiHandler[T PeerInterface]() websocketPeerApiHandler[T] {
|
|
methods := make(map[string]peerApiFuncType[T])
|
|
|
|
var archetype T
|
|
tp := reflect.TypeOf(archetype)
|
|
|
|
for i := 0; i < tp.NumMethod(); i++ {
|
|
method := tp.Method(i)
|
|
if method.Type.In(0) != tp {
|
|
continue
|
|
}
|
|
|
|
if method.Name == ClientDisconnected {
|
|
continue
|
|
}
|
|
|
|
var intypes []reflect.Type
|
|
for i := 1; i < method.Type.NumIn(); i++ {
|
|
intypes = append(intypes, method.Type.In(i))
|
|
}
|
|
|
|
var outconv func([]reflect.Value) (any, error)
|
|
if method.Type.NumOut() == 0 {
|
|
outconv = func([]reflect.Value) (any, error) { return nil, nil }
|
|
} else if method.Type.NumOut() == 1 {
|
|
if method.Type.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
|
outconv = func(out []reflect.Value) (any, error) {
|
|
if out[0].Interface() == nil {
|
|
return nil, nil
|
|
}
|
|
return nil, out[0].Interface().(error)
|
|
}
|
|
} else {
|
|
outconv = func(out []reflect.Value) (any, error) {
|
|
return out[0].Interface(), nil
|
|
}
|
|
}
|
|
} else if method.Type.NumOut() == 2 && method.Type.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
|
|
outconv = func(out []reflect.Value) (any, error) {
|
|
if out[1].Interface() == nil {
|
|
return out[0].Interface(), nil
|
|
}
|
|
return out[0].Interface(), out[1].Interface().(error)
|
|
}
|
|
}
|
|
|
|
methods[method.Name] = func(recv T, r io.Reader) (any, error) {
|
|
decoder := json.NewDecoder(r)
|
|
inargs := make([]any, len(intypes))
|
|
|
|
for i, intype := range intypes {
|
|
zerovalueptr := reflect.New(intype)
|
|
inargs[i] = zerovalueptr.Interface()
|
|
}
|
|
|
|
err := decoder.Decode(&inargs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reflectargs := make([]reflect.Value, 0, len(inargs)+1)
|
|
reflectargs = append(reflectargs, reflect.ValueOf(recv))
|
|
for _, p := range inargs {
|
|
reflectargs = append(reflectargs, reflect.ValueOf(p).Elem())
|
|
}
|
|
|
|
return outconv(method.Func.Call(reflectargs))
|
|
}
|
|
}
|
|
|
|
return websocketPeerApiHandler[T]{
|
|
methods: methods,
|
|
originalReceiverName: tp.Elem().Name(),
|
|
}
|
|
}
|
|
|
|
func NewWebsocketPeerHandler[T PeerInterface](consumer session.Consumer, creator func(primitive.ObjectID) T) WebsocketPeerHandler {
|
|
methods := make(map[string]peerApiFuncType[T])
|
|
receiver := makeWebsocketPeerApiHandler[T]()
|
|
|
|
for k, v := range receiver.methods {
|
|
logger.Printf("ws api registered : %s.%s\n", receiver.originalReceiverName, k)
|
|
methods[k] = v
|
|
}
|
|
|
|
wsh := &websocketPeerHandler[T]{
|
|
sessionConsumer: consumer,
|
|
methods: methods,
|
|
createPeer: creator,
|
|
peerCtorChannel: make(chan peerCtorChannelValue),
|
|
peerDtorChannel: make(chan peerDtorChannelValue),
|
|
}
|
|
|
|
consumer.RegisterOnSessionInvalidated(wsh.onSessionInvalidated)
|
|
return wsh
|
|
}
|
|
|
|
func (ws *websocketPeerHandler[T]) RegisterHandlers(serveMux *http.ServeMux, prefix string) error {
|
|
if *noAuthFlag {
|
|
serveMux.HandleFunc(prefix, ws.upgrade_noauth)
|
|
} else {
|
|
serveMux.HandleFunc(prefix, ws.upgrade)
|
|
}
|
|
go ws.sessionMonitoring()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ws *websocketPeerHandler[T]) onSessionInvalidated(accid primitive.ObjectID) {
|
|
ws.peerDtorChannel <- peerDtorChannelValue{
|
|
accid: accid,
|
|
}
|
|
}
|
|
|
|
func (ws *websocketPeerHandler[T]) sessionMonitoring() {
|
|
all := make(map[primitive.ObjectID]*Conn)
|
|
unauthdata := []byte{0x03, 0xec}
|
|
unauthdata = append(unauthdata, []byte("unauthorized")...)
|
|
for {
|
|
select {
|
|
case estVal := <-ws.peerCtorChannel:
|
|
all[estVal.accid] = estVal.conn
|
|
case disVal := <-ws.peerDtorChannel:
|
|
if c := all[disVal.accid]; c != nil {
|
|
c.MakeWriter().WriteControl(websocket.CloseMessage, unauthdata, time.Time{})
|
|
delete(all, disVal.accid)
|
|
}
|
|
|
|
if len(disVal.sk) > 0 {
|
|
ws.sessionConsumer.Revoke(disVal.sk)
|
|
delete(all, disVal.accid)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ws *websocketPeerHandler[T]) upgrade_core(conn *Conn, accid primitive.ObjectID, sk string) {
|
|
go func(c *Conn, accid primitive.ObjectID, sk string) {
|
|
peer := ws.createPeer(accid)
|
|
var closeReason string
|
|
|
|
peer.ClientConnected(conn)
|
|
ws.peerCtorChannel <- peerCtorChannelValue{accid: accid, conn: conn}
|
|
|
|
defer func() {
|
|
ws.peerDtorChannel <- peerDtorChannelValue{accid: accid, sk: sk}
|
|
peer.ClientDisconnected(closeReason)
|
|
}()
|
|
|
|
response := make([]byte, 255)
|
|
writer := c.MakeWriter()
|
|
for {
|
|
response = response[:5]
|
|
messageType, r, err := c.NextReader()
|
|
if err != nil {
|
|
if ce, ok := err.(*websocket.CloseError); ok {
|
|
closeReason = ce.Text
|
|
}
|
|
c.Close()
|
|
break
|
|
}
|
|
|
|
if messageType == websocket.CloseMessage {
|
|
closeMsg, _ := io.ReadAll(r)
|
|
closeReason = string(closeMsg)
|
|
break
|
|
}
|
|
|
|
if messageType == websocket.BinaryMessage {
|
|
var flag [1]byte
|
|
r.Read(flag[:])
|
|
if flag[0] == 0xff {
|
|
// nonce
|
|
r.Read(response[1:5])
|
|
|
|
var size [1]byte
|
|
r.Read(size[:])
|
|
|
|
cmd := make([]byte, size[0])
|
|
r.Read(cmd)
|
|
|
|
result, err := ws.call(peer, string(cmd), r)
|
|
|
|
if err != nil {
|
|
response[0] = 21 // 21 : Negative Ack
|
|
response = append(response, []byte(err.Error())...)
|
|
} else {
|
|
response[0] = 6 // 6 : Acknowledgement
|
|
|
|
switch result := result.(type) {
|
|
case string:
|
|
response = append(response, []byte(result)...)
|
|
|
|
case int8, int16, int32, int64, uint8, uint16, uint32, uint64:
|
|
response = append(response, []byte(fmt.Sprintf("%d", result))...)
|
|
|
|
case float32, float64:
|
|
response = append(response, []byte(fmt.Sprintf("%f", result))...)
|
|
|
|
case []byte:
|
|
response = append(response, result...)
|
|
|
|
default:
|
|
j, _ := json.Marshal(result)
|
|
response = append(response, j...)
|
|
}
|
|
}
|
|
pmsg, err := websocket.NewPreparedMessage(websocket.BinaryMessage, response)
|
|
if err != nil {
|
|
logger.Println("websocket.NewPreparedMessage failed :", err)
|
|
} else {
|
|
writer.WritePreparedMessage(pmsg)
|
|
}
|
|
} else {
|
|
cmd := make([]byte, flag[0])
|
|
r.Read(cmd)
|
|
ws.call(peer, string(cmd), r)
|
|
}
|
|
}
|
|
}
|
|
}(conn, accid, sk)
|
|
}
|
|
|
|
func (ws *websocketPeerHandler[T]) upgrade_noauth(w http.ResponseWriter, r *http.Request) {
|
|
// 클라이언트 접속
|
|
defer func() {
|
|
s := recover()
|
|
if s != nil {
|
|
logger.Error(s)
|
|
}
|
|
io.Copy(io.Discard, r.Body)
|
|
r.Body.Close()
|
|
}()
|
|
|
|
sk := r.Header.Get("AS-X-SESSION")
|
|
var accid primitive.ObjectID
|
|
if len(sk) > 0 {
|
|
authinfo, err := ws.sessionConsumer.Query(sk)
|
|
if err == nil {
|
|
accid = authinfo.Account
|
|
}
|
|
}
|
|
|
|
if accid.IsZero() {
|
|
auth := strings.Split(r.Header.Get("Authorization"), " ")
|
|
if len(auth) != 2 {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
temp, err := hex.DecodeString(auth[1])
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if len(temp) != len(primitive.NilObjectID) {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
raw := (*[12]byte)(temp)
|
|
accid = primitive.ObjectID(*raw)
|
|
}
|
|
|
|
var upgrader = websocket.Upgrader{} // use default options
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// var alias string
|
|
// if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 {
|
|
// vt, _ := base64.StdEncoding.DecodeString(v)
|
|
// alias = string(vt)
|
|
// } else {
|
|
// alias = accid.Hex()
|
|
// }
|
|
|
|
ws.upgrade_core(&Conn{innerConn: conn}, accid, sk)
|
|
}
|
|
|
|
func (ws *websocketPeerHandler[T]) upgrade(w http.ResponseWriter, r *http.Request) {
|
|
// 클라이언트 접속
|
|
defer func() {
|
|
s := recover()
|
|
if s != nil {
|
|
logger.Error(s)
|
|
}
|
|
io.Copy(io.Discard, r.Body)
|
|
r.Body.Close()
|
|
}()
|
|
|
|
sk := r.Header.Get("AS-X-SESSION")
|
|
authinfo, err := ws.sessionConsumer.Query(sk)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
logger.Error("authorize query failed :", err)
|
|
return
|
|
}
|
|
|
|
if authinfo.Account.IsZero() || authinfo.Invalidated() {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
var upgrader = websocket.Upgrader{} // use default options
|
|
conn, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// var alias string
|
|
// if v := r.Header.Get("AS-X-ALIAS"); len(v) > 0 {
|
|
// vt, _ := base64.StdEncoding.DecodeString(v)
|
|
// alias = string(vt)
|
|
// } else {
|
|
// alias = authinfo.Account.Hex()
|
|
// }
|
|
ws.upgrade_core(makeConn(conn), authinfo.Account, sk)
|
|
}
|