package rpc

import (
	"bytes"
	"encoding/gob"
	"fmt"
	"reflect"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/bson/primitive"
	"repositories.action2quare.com/ayo/gocommon/logger"
	"repositories.action2quare.com/ayo/gocommon/wshandler"
)

// Everybody is the fixed alias that RpcCaller.Everybody stamps on outgoing
// calls.
// NOTE(review): presumably the consumer of the published bytes treats this
// alias as "all receivers" — confirm against the subscriber side.
var Everybody = primitive.ObjectID([12]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})

// init registers concrete types that may travel as interface values inside the
// gob-encoded []any payloads. Any other non-basic argument type must be
// registered with gob.Register elsewhere, or encoding/decoding will fail.
func init() {
	gob.Register(bson.M{})
	gob.Register(primitive.ObjectID{})
	gob.Register(primitive.Timestamp{})
}

// RpcCaller builds outgoing RPC calls and hands the encoded bytes to a
// caller-supplied publish function.
type RpcCaller struct {
	publish func(bt []byte) error
}

// NewRpcCaller returns an RpcCaller that sends every encoded call through f.
func NewRpcCaller(f func(bt []byte) error) RpcCaller {
	return RpcCaller{publish: f}
}

// rpcCallContext carries the target alias of one call together with the
// publish function used to deliver it.
// NOTE(review): no exported methods are visible in this file;
// IsCallerCalleeMethodMatch compares whatever exported (value-receiver)
// methods this type has against a callee type.
type rpcCallContext struct {
	alias   primitive.ObjectID
	publish func(bt []byte) error
}

// One returns a call context addressed to the given alias.
func (c *RpcCaller) One(alias primitive.ObjectID) rpcCallContext {
	return rpcCallContext{
		alias:   alias,
		publish: c.publish,
	}
}

// Everybody returns a call context addressed to the Everybody alias.
func (c *RpcCaller) Everybody() rpcCallContext {
	return rpcCallContext{
		alias:   Everybody,
		publish: c.publish,
	}
}

// IsCallerCalleeMethodMatch checks that every exported method of
// rpcCallContext has a same-named method on Callee with identical parameter
// types (receiver excluded) and identical result types. It returns nil when
// everything matches, otherwise an error describing the first mismatch.
//
// Callee may be a concrete type, a pointer type or an interface type. For a
// non-pointer concrete type only its value-receiver methods are visible, so
// pass *T when the methods use pointer receivers.
func IsCallerCalleeMethodMatch[Callee any]() error {
	callerType := reflect.TypeOf(rpcCallContext{})
	// Going through a pointer keeps this working when Callee is an interface:
	// reflect.TypeOf on a nil interface value would return nil.
	calleeType := reflect.TypeOf((*Callee)(nil)).Elem()

	// Method types of concrete types include the receiver as In(0); method
	// types obtained from an interface type do not.
	calleeOffset := 1
	if calleeType.Kind() == reflect.Interface {
		calleeOffset = 0
	}

	for i := 0; i < callerType.NumMethod(); i++ {
		callerMethod := callerType.Method(i)
		calleeMethod, ok := calleeType.MethodByName(callerMethod.Name)
		if !ok {
			return fmt.Errorf("method '%s' of '%s' is missing", callerMethod.Name, calleeType.String())
		}

		callerFn := callerMethod.Type // In(0) is the rpcCallContext receiver
		calleeFn := calleeMethod.Type

		if calleeFn.NumIn()-calleeOffset != callerFn.NumIn()-1 {
			return fmt.Errorf("method '%s' argument num is not match", callerMethod.Name)
		}
		if calleeFn.NumOut() != callerFn.NumOut() {
			return fmt.Errorf("method '%s' out num is not match", callerMethod.Name)
		}

		for j := 1; j < callerFn.NumIn(); j++ {
			calleeIn := calleeFn.In(j - 1 + calleeOffset)
			callerIn := callerFn.In(j)
			if calleeIn != callerIn {
				return fmt.Errorf("method '%s' argument is not match. %s-%s", callerMethod.Name, calleeIn.String(), callerIn.String())
			}
		}
		for j := 0; j < callerFn.NumOut(); j++ {
			calleeOut := calleeFn.Out(j)
			callerOut := callerFn.Out(j)
			if calleeOut != callerOut {
				return fmt.Errorf("method '%s' result is not match. %s-%s", callerMethod.Name, calleeOut.String(), callerOut.String())
			}
		}
	}
	return nil
}

// fnsig is the payload layout that decode expects: a function name plus its
// arguments. The bson tags are not used by gob, which encodes by field name.
type fnsig struct {
	FunctionName string `bson:"fn"`
	Args         []any  `bson:"args"`
}

// Encode gob-encodes the slice []any{prefix, fn, args...}. Every dynamic type
// placed in the slice (including T) must be known to gob (see gob.Register);
// on failure the error is logged and returned.
func Encode[T any](prefix T, fn string, args ...any) ([]byte, error) {
	m := make([]any, 0, len(args)+2)
	m = append(m, prefix, fn)
	m = append(m, args...)

	var buff bytes.Buffer
	if err := gob.NewEncoder(&buff).Encode(m); err != nil {
		logger.Error("rpcCallContext.send err :", err)
		return nil, err
	}
	return buff.Bytes(), nil
}

// Decode reverses Encode: it gob-decodes src as a []any and returns the
// leading prefix (asserted to T), the function name and the remaining
// arguments. It returns an error instead of panicking when the payload has
// fewer than two elements or its leading elements have unexpected types.
func Decode[T any](src []byte) (*T, string, []any, error) {
	var m []any
	if err := gob.NewDecoder(bytes.NewReader(src)).Decode(&m); err != nil {
		logger.Error("RpcCallee.Call err :", err)
		return nil, "", nil, err
	}
	if len(m) < 2 {
		return nil, "", nil, fmt.Errorf("rpc payload has %d elements, want at least 2", len(m))
	}

	prefix, ok := m[0].(T)
	if !ok {
		return nil, "", nil, fmt.Errorf("rpc payload prefix has type %T, want %T", m[0], prefix)
	}
	fn, ok := m[1].(string)
	if !ok {
		return nil, "", nil, fmt.Errorf("rpc payload function name has type %T, want string", m[1])
	}
	return &prefix, fn, m[2:], nil
}

// decode gob-decodes src as an fnsig and returns its function name and
// arguments; decoding errors are logged and returned.
// NOTE(review): send/Encode produce a []any{alias, fn, args...} payload, not an
// fnsig. Gob rejects decoding one into the other, so whatever feeds
// RpcCallee.Call must re-encode as fnsig — confirm against the subscriber.
func decode(src []byte) (string, []any, error) {
	var sig fnsig
	if err := gob.NewDecoder(bytes.NewReader(src)).Decode(&sig); err != nil {
		logger.Error("RpcCallee.Call err :", err)
		return "", nil, err
	}
	return sig.FunctionName, sig.Args, nil
}

// send encodes the call with the context's alias as prefix and publishes the
// resulting bytes.
func (c *rpcCallContext) send(fn string, args ...any) error {
	if c.publish == nil {
		// A zero-value RpcCaller has no publish function; fail instead of
		// panicking on a nil func call.
		return fmt.Errorf("rpc: publish function is not set")
	}
	bt, err := Encode(c.alias, fn, args...)
	if err != nil {
		return err
	}
	return c.publish(bt)
}

// RpcCallee dispatches incoming calls by name to the exported methods of *T.
// The receiver for each call is obtained by calling create with the
// connection passed to Call.
type RpcCallee[T any] struct {
	methods map[string]reflect.Method
	create  func(*wshandler.Richconn) *T
}

// NewRpcCallee indexes the exported method set of *T by name and remembers
// createReceiverFunc for building receivers in Call.
func NewRpcCallee[T any](createReceiverFunc func(*wshandler.Richconn) *T) RpcCallee[T] {
	tp := reflect.TypeOf((*T)(nil))
	out := RpcCallee[T]{
		methods: make(map[string]reflect.Method, tp.NumMethod()),
		create:  createReceiverFunc,
	}
	for i := 0; i < tp.NumMethod(); i++ {
		method := tp.Method(i)
		out.methods[method.Name] = method
	}
	return out
}

// Call decodes src with decode, looks up the named method on *T, builds a
// receiver with r.create(rc) and invokes the method with the decoded
// arguments. When the method's last result is a non-nil error it is returned.
//
// Errors are logged and returned when src cannot be decoded, the method does
// not exist, or the arguments do not fit the method's parameters. A panic
// during dispatch is recovered, logged and reported as the returned error
// rather than as success.
func (r RpcCallee[T]) Call(rc *wshandler.Richconn, src []byte) (err error) {
	defer func() {
		if s := recover(); s != nil {
			logger.Error(s)
			err = fmt.Errorf("rpc call panicked: %v", s)
		}
	}()

	fn, params, err := decode(src)
	if err != nil {
		logger.Error("RpcCallee.Call err :", err)
		return err
	}

	method, ok := r.methods[fn]
	if !ok {
		err = fmt.Errorf("method '%s' is missing", fn)
		logger.Error("RpcCallee.Call err :", err)
		return err
	}

	args, err := rpcCallArgs(method, reflect.ValueOf(r.create(rc)), params)
	if err != nil {
		logger.Error("RpcCallee.Call err :", err)
		return err
	}

	rets := method.Func.Call(args)
	if len(rets) == 0 {
		return nil
	}
	if callErr, ok := rets[len(rets)-1].Interface().(error); ok {
		return callErr
	}
	return nil
}

// rpcCallArgs builds the []reflect.Value passed to method.Func.Call, with
// receiver first. It validates the argument count (variadic-aware) and that
// each argument is assignable to its parameter type; a nil argument becomes
// the zero value of a nilable parameter type and is rejected otherwise.
func rpcCallArgs(method reflect.Method, receiver reflect.Value, params []any) ([]reflect.Value, error) {
	mt := method.Type // In(0) is the receiver
	variadic := mt.IsVariadic()
	fixed := mt.NumIn() - 1
	if variadic {
		fixed--
	}

	switch {
	case variadic && len(params) < fixed:
		return nil, fmt.Errorf("method '%s' expects at least %d arguments, got %d", method.Name, fixed, len(params))
	case !variadic && len(params) != fixed:
		return nil, fmt.Errorf("method '%s' expects %d arguments, got %d", method.Name, fixed, len(params))
	}

	args := make([]reflect.Value, 0, len(params)+1)
	args = append(args, receiver)
	for i, p := range params {
		var want reflect.Type
		if variadic && i >= fixed {
			// reflect.Value.Call packs trailing values into the variadic slice.
			want = mt.In(mt.NumIn() - 1).Elem()
		} else {
			want = mt.In(i + 1)
		}

		if p == nil {
			// reflect.ValueOf(nil) is an invalid Value and would make Call panic.
			switch want.Kind() {
			case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
				args = append(args, reflect.Zero(want))
				continue
			}
			return nil, fmt.Errorf("method '%s' argument %d: nil is not assignable to %s", method.Name, i, want.String())
		}

		v := reflect.ValueOf(p)
		if !v.Type().AssignableTo(want) {
			return nil, fmt.Errorf("method '%s' argument %d: %s is not assignable to %s", method.Name, i, v.Type().String(), want.String())
		}
		args = append(args, v)
	}
	return args, nil
}