196 lines
4.3 KiB
Go
196 lines
4.3 KiB
Go
package rpc
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/gob"
|
|
"fmt"
|
|
"reflect"
|
|
|
|
"repositories.action2quare.com/ayo/gocommon/logger"
|
|
"repositories.action2quare.com/ayo/gocommon/wshandler"
|
|
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
)
|
|
|
|
// Everybody is the reserved alias that RpcCaller.Everybody places on a call
// context instead of a single receiver's ObjectID. Its bytes are the fixed
// sequence 0x00..0x0b.
// NOTE(review): the "deliver to all receivers" interpretation is applied by
// whatever consumes the published payload, which is not visible here — confirm.
var Everybody = primitive.ObjectID{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b}
|
|
|
|
func init() {
|
|
gob.Register(bson.M{})
|
|
gob.Register(primitive.ObjectID{})
|
|
gob.Register(primitive.Timestamp{})
|
|
}
|
|
|
|
// RpcCaller is the entry point for issuing RPC calls. It owns the transport
// function that every call context it hands out (via One or Everybody) uses
// to publish an encoded payload.
type RpcCaller struct {
	// publish delivers one encoded payload; its error is returned by
	// rpcCallContext.send.
	publish func([]byte) error
}
|
|
|
|
func NewRpcCaller(f func(bt []byte) error) RpcCaller {
|
|
return RpcCaller{
|
|
publish: f,
|
|
}
|
|
}
|
|
|
|
type rpcCallContext struct {
|
|
alias primitive.ObjectID
|
|
publish func(bt []byte) error
|
|
}
|
|
|
|
func (c *RpcCaller) One(alias primitive.ObjectID) rpcCallContext {
|
|
return rpcCallContext{
|
|
alias: alias,
|
|
publish: c.publish,
|
|
}
|
|
}
|
|
|
|
func (c *RpcCaller) Everybody() rpcCallContext {
|
|
return rpcCallContext{
|
|
alias: Everybody,
|
|
publish: c.publish,
|
|
}
|
|
}
|
|
|
|
func IsCallerCalleeMethodMatch[Callee any]() error {
|
|
var caller rpcCallContext
|
|
var callee Callee
|
|
|
|
callerType := reflect.TypeOf(caller)
|
|
calleeType := reflect.TypeOf(callee)
|
|
for i := 0; i < callerType.NumMethod(); i++ {
|
|
callerMethod := callerType.Method(i)
|
|
calleeMethod, ok := calleeType.MethodByName(callerMethod.Name)
|
|
if !ok {
|
|
return fmt.Errorf("method '%s' of '%s' is missing", callerMethod.Name, calleeType.Name())
|
|
}
|
|
|
|
if calleeMethod.Func.Type().NumIn() != callerMethod.Func.Type().NumIn() {
|
|
return fmt.Errorf("method '%s' argument num is not match", callerMethod.Name)
|
|
}
|
|
|
|
if calleeMethod.Func.Type().NumOut() != callerMethod.Func.Type().NumOut() {
|
|
return fmt.Errorf("method '%s' out num is not match", callerMethod.Name)
|
|
}
|
|
|
|
for i := 1; i < calleeMethod.Func.Type().NumIn(); i++ {
|
|
if calleeMethod.Func.Type().In(i) != callerMethod.Func.Type().In(i) {
|
|
return fmt.Errorf("method '%s' argument is not match. %s-%s", callerMethod.Name, calleeMethod.Func.Type().In(i).Name(), callerMethod.Func.Type().In(i).Name())
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// fnsig is the gob wire shape that decode expects: a function name followed
// by its positional arguments. gob ignores the bson tags.
// NOTE(review): nothing in this file BSON-encodes fnsig; the tags presumably
// serve an encoder elsewhere — confirm.
type fnsig struct {
	FunctionName string `bson:"fn"`   // name of the method to invoke
	Args         []any  `bson:"args"` // positional arguments, in call order
}
|
|
|
|
func Encode[T any](prefix T, fn string, args ...any) ([]byte, error) {
|
|
m := append([]any{
|
|
prefix,
|
|
fn,
|
|
}, args...)
|
|
|
|
buff := new(bytes.Buffer)
|
|
encoder := gob.NewEncoder(buff)
|
|
err := encoder.Encode(m)
|
|
if err != nil {
|
|
logger.Error("rpcCallContext.send err :", err)
|
|
return nil, err
|
|
}
|
|
|
|
return buff.Bytes(), nil
|
|
}
|
|
|
|
func Decode[T any](src []byte) (*T, string, []any, error) {
|
|
var m []any
|
|
decoder := gob.NewDecoder(bytes.NewReader(src))
|
|
if err := decoder.Decode(&m); err != nil {
|
|
logger.Error("RpcCallee.Call err :", err)
|
|
return nil, "", nil, err
|
|
}
|
|
|
|
prfix := m[0].(T)
|
|
fn := m[1].(string)
|
|
|
|
return &prfix, fn, m[2:], nil
|
|
}
|
|
|
|
func decode(src []byte) (string, []any, error) {
|
|
var sig fnsig
|
|
decoder := gob.NewDecoder(bytes.NewReader(src))
|
|
if err := decoder.Decode(&sig); err != nil {
|
|
logger.Error("RpcCallee.Call err :", err)
|
|
return "", nil, err
|
|
}
|
|
|
|
return sig.FunctionName, sig.Args, nil
|
|
}
|
|
|
|
func (c *rpcCallContext) send(fn string, args ...any) error {
|
|
bt, err := Encode(c.alias, fn, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return c.publish(bt)
|
|
}
|
|
|
|
type RpcCallee[T any] struct {
|
|
methods map[string]reflect.Method
|
|
create func(*wshandler.Richconn) *T
|
|
}
|
|
|
|
func NewRpcCallee[T any](createReceiverFunc func(*wshandler.Richconn) *T) RpcCallee[T] {
|
|
out := RpcCallee[T]{
|
|
methods: make(map[string]reflect.Method),
|
|
create: createReceiverFunc,
|
|
}
|
|
|
|
var tmp *T
|
|
|
|
tp := reflect.TypeOf(tmp)
|
|
for i := 0; i < tp.NumMethod(); i++ {
|
|
method := tp.Method(i)
|
|
out.methods[method.Name] = method
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (r RpcCallee[T]) Call(rc *wshandler.Richconn, src []byte) error {
|
|
defer func() {
|
|
s := recover()
|
|
if s != nil {
|
|
logger.Error(s)
|
|
}
|
|
}()
|
|
|
|
fn, params, err := decode(src)
|
|
if err != nil {
|
|
logger.Error("RpcCallee.Call err :", err)
|
|
return err
|
|
}
|
|
|
|
method, ok := r.methods[fn]
|
|
if !ok {
|
|
err := fmt.Errorf("method '%s' is missing", fn)
|
|
logger.Error("RpcCallee.Call err :", err)
|
|
return err
|
|
}
|
|
|
|
receiver := r.create(rc)
|
|
args := []reflect.Value{
|
|
reflect.ValueOf(receiver),
|
|
}
|
|
for _, arg := range params {
|
|
args = append(args, reflect.ValueOf(arg))
|
|
}
|
|
|
|
rets := method.Func.Call(args)
|
|
if len(rets) > 0 && rets[len(rets)-1].Interface() != nil {
|
|
return rets[len(rets)-1].Interface().(error)
|
|
}
|
|
return nil
|
|
}
|