// Files
// tavern/core/rpc/connrpc.go
// 2023-05-24 16:10:00 +09:00
//
// 196 lines
// 4.3 KiB
// Go

package rpc
import (
"bytes"
"encoding/gob"
"fmt"
"reflect"
"repositories.action2quare.com/ayo/gocommon/logger"
"repositories.action2quare.com/ayo/gocommon/wshandler"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// Everybody is the fixed ObjectID that RpcCaller.Everybody stamps onto a call
// context as its alias. How the receiving side interprets this alias is
// decided outside this file.
var Everybody = primitive.ObjectID{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}
// init registers with gob the concrete types listed below. Encode sends its
// payload as a []any, and gob can only transmit a value held in an interface
// when its concrete type has been registered.
func init() {
	samples := []any{
		bson.M{},
		primitive.ObjectID{},
		primitive.Timestamp{},
	}
	for _, sample := range samples {
		gob.Register(sample)
	}
}
// RpcCaller is the calling side of the RPC layer. It holds the publish
// function that the call contexts it creates (One, Everybody) hand their
// encoded bytes to.
type RpcCaller struct {
// publish receives the encoded bytes of one call; it is set by NewRpcCaller.
publish func(bt []byte) error
}
// NewRpcCaller returns an RpcCaller whose call contexts deliver their encoded
// bytes through f. f is stored as given; a nil f is not rejected here.
func NewRpcCaller(f func(bt []byte) error) RpcCaller {
	var caller RpcCaller
	caller.publish = f
	return caller
}
// rpcCallContext carries what one outgoing call needs: the alias that send
// passes to Encode as the payload prefix, and the publish function that
// receives the encoded bytes.
type rpcCallContext struct {
// alias is encoded as the first element of every payload sent through this context.
alias primitive.ObjectID
// publish is copied from the RpcCaller that created this context.
publish func(bt []byte) error
}
// One returns a call context whose alias is the given ObjectID and whose
// publish function is the caller's.
func (c *RpcCaller) One(alias primitive.ObjectID) rpcCallContext {
	call := rpcCallContext{publish: c.publish}
	call.alias = alias
	return call
}
// Everybody returns a call context whose alias is the package-level Everybody
// ObjectID and whose publish function is the caller's.
func (c *RpcCaller) Everybody() rpcCallContext {
	call := rpcCallContext{publish: c.publish}
	call.alias = Everybody
	return call
}
// IsCallerCalleeMethodMatch reports whether every method in the method set of
// rpcCallContext (the calling side) has a counterpart on Callee (the
// receiving side).
//
// For each caller method, Callee must have a method of the same name with the
// same number of parameters, the same number of results, and identical
// parameter types. The receiver (parameter 0) is skipped, and results are
// compared by count only. The first mismatch found is returned as an error;
// nil means every caller method is matched.
//
// Reflection only sees exported methods, and Callee is inspected exactly as
// given: pass the pointer type when the callee methods use pointer receivers.
// If Callee is an interface type there is no concrete type to inspect and an
// error is returned.
func IsCallerCalleeMethodMatch[Callee any]() error {
	var caller rpcCallContext
	var callee Callee

	callerType := reflect.TypeOf(caller)
	calleeType := reflect.TypeOf(callee)
	if calleeType == nil {
		// reflect.TypeOf yields nil for the zero value of an interface type;
		// calling MethodByName on that nil Type would panic.
		return fmt.Errorf("callee type '%s' is an interface, a concrete type is required",
			reflect.TypeOf((*Callee)(nil)).Elem().String())
	}

	for i := 0; i < callerType.NumMethod(); i++ {
		callerMethod := callerType.Method(i)
		calleeMethod, ok := calleeType.MethodByName(callerMethod.Name)
		if !ok {
			// String() rather than Name(): Name() is empty for pointer types.
			return fmt.Errorf("method '%s' of '%s' is missing", callerMethod.Name, calleeType.String())
		}

		callerSig := callerMethod.Func.Type()
		calleeSig := calleeMethod.Func.Type()
		if calleeSig.NumIn() != callerSig.NumIn() {
			return fmt.Errorf("method '%s' argument num is not match", callerMethod.Name)
		}
		if calleeSig.NumOut() != callerSig.NumOut() {
			return fmt.Errorf("method '%s' out num is not match", callerMethod.Name)
		}

		// Index 0 is the receiver, which necessarily differs between the two sides.
		for j := 1; j < calleeSig.NumIn(); j++ {
			if calleeSig.In(j) != callerSig.In(j) {
				// String() rather than Name(): Name() is empty for slices, maps and pointers.
				return fmt.Errorf("method '%s' argument is not match. %s-%s",
					callerMethod.Name, calleeSig.In(j).String(), callerSig.In(j).String())
			}
		}
	}
	return nil
}
// fnsig is the structure decode gob-decodes a call into: the name of the
// method to invoke and its positional arguments.
// NOTE(review): the fields carry bson tags, but the only use visible in this
// file is gob decoding — confirm whether bson (de)serialization happens elsewhere.
type fnsig struct {
// FunctionName is the key RpcCallee.Call looks up in its method table.
FunctionName string `bson:"fn"`
// Args are passed, in order, as the method arguments after the receiver.
Args []any `bson:"args"`
}
// Encode gob-encodes one call as a single []any laid out as
// [prefix, fn, args...] and returns the resulting bytes.
//
// Because every element travels as an interface value, the concrete types of
// prefix and of each argument must be known to gob (basic types are; others
// need gob.Register). An encoding failure is logged and returned with a nil
// byte slice.
func Encode[T any](prefix T, fn string, args ...any) ([]byte, error) {
	payload := make([]any, 0, len(args)+2)
	payload = append(payload, prefix, fn)
	payload = append(payload, args...)

	var out bytes.Buffer
	if err := gob.NewEncoder(&out).Encode(payload); err != nil {
		logger.Error("rpcCallContext.send err :", err)
		return nil, err
	}
	return out.Bytes(), nil
}
// Decode is the inverse of Encode: it gob-decodes src as a []any laid out as
// [prefix, fn, args...] and returns a pointer to the prefix, the function
// name and the remaining arguments.
//
// An error is returned, and all other results are zero, when src is not a
// valid gob stream, when the payload has fewer than two elements, when the
// first element is not a T, or when the second element is not a string.
// Payloads are treated as untrusted: a malformed one yields an error rather
// than a panic. Every failure is logged before it is returned.
func Decode[T any](src []byte) (*T, string, []any, error) {
	var m []any
	decoder := gob.NewDecoder(bytes.NewReader(src))
	if err := decoder.Decode(&m); err != nil {
		logger.Error("RpcCallee.Call err :", err)
		return nil, "", nil, err
	}

	// Indexing m[0] and m[1] is only safe once the length is known.
	if len(m) < 2 {
		err := fmt.Errorf("rpc payload has %d element(s), need at least a prefix and a function name", len(m))
		logger.Error("RpcCallee.Call err :", err)
		return nil, "", nil, err
	}

	prefix, ok := m[0].(T)
	if !ok {
		err := fmt.Errorf("rpc payload prefix has unexpected type %T", m[0])
		logger.Error("RpcCallee.Call err :", err)
		return nil, "", nil, err
	}

	fn, ok := m[1].(string)
	if !ok {
		err := fmt.Errorf("rpc payload function name has unexpected type %T", m[1])
		logger.Error("RpcCallee.Call err :", err)
		return nil, "", nil, err
	}

	return &prefix, fn, m[2:], nil
}
// decode gob-decodes src into an fnsig and returns its function name and
// argument list. A decoding failure is logged and returned together with an
// empty name and nil arguments.
//
// NOTE(review): Encode in this file writes a []any, not an fnsig — confirm
// that the bytes handed to decode come from a matching fnsig encoder.
func decode(src []byte) (string, []any, error) {
	sig := fnsig{}
	err := gob.NewDecoder(bytes.NewReader(src)).Decode(&sig)
	if err != nil {
		logger.Error("RpcCallee.Call err :", err)
		return "", nil, err
	}
	return sig.FunctionName, sig.Args, nil
}
// send encodes the call (the context's alias as prefix, then fn and args)
// with Encode and hands the bytes to publish. It returns the encoding error
// if there is one, otherwise whatever publish returns.
func (c *rpcCallContext) send(fn string, args ...any) error {
	payload, encodeErr := Encode(c.alias, fn, args...)
	if encodeErr == nil {
		return c.publish(payload)
	}
	return encodeErr
}
// RpcCallee is the receiving side of the RPC layer: it dispatches a decoded
// call, by method name, to a method of *T.
type RpcCallee[T any] struct {
// methods maps a method name to the reflected method of *T; filled by NewRpcCallee.
methods map[string]reflect.Method
// create builds the receiver for one call from the connection passed to Call.
create func(*wshandler.Richconn) *T
}
// NewRpcCallee builds an RpcCallee for T. It indexes, by name, every method
// that reflection reports for *T (the exported methods of the pointer type)
// and stores createReceiverFunc, which Call uses to obtain the receiver of
// each invocation.
func NewRpcCallee[T any](createReceiverFunc func(*wshandler.Richconn) *T) RpcCallee[T] {
	receiverType := reflect.TypeOf((*T)(nil))
	count := receiverType.NumMethod()

	table := make(map[string]reflect.Method, count)
	for idx := 0; idx < count; idx++ {
		m := receiverType.Method(idx)
		table[m.Name] = m
	}

	return RpcCallee[T]{
		methods: table,
		create:  createReceiverFunc,
	}
}
// Call decodes src into a method name and arguments, creates a receiver for
// rc with the factory given to NewRpcCallee, and invokes the named method on
// it through reflection.
//
// The returned error is, in order of precedence:
//   - the decoding error, if src cannot be decoded;
//   - an error naming the method, if *T has no method of that name;
//   - an error describing the mismatch, if the decoded arguments do not fit
//     the method's parameters (wrong count, non-assignable type, or nil for a
//     parameter that cannot hold nil);
//   - the method's last result, if that result is a non-nil error;
//   - an error wrapping the panic value, if the invocation panics.
//
// Otherwise nil is returned. Panics are still logged, but are now reported to
// the caller instead of being turned into a nil (successful) result.
func (r RpcCallee[T]) Call(rc *wshandler.Richconn, src []byte) (err error) {
	defer func() {
		if s := recover(); s != nil {
			logger.Error(s)
			// The named result lets the deferred function report the failure.
			err = fmt.Errorf("rpc call panicked: %v", s)
		}
	}()

	// fail logs an error produced by Call itself and returns it.
	fail := func(e error) error {
		logger.Error("RpcCallee.Call err :", e)
		return e
	}

	fn, params, err := decode(src)
	if err != nil {
		// decode has already logged this error; do not log it a second time.
		return err
	}

	method, ok := r.methods[fn]
	if !ok {
		return fail(fmt.Errorf("method '%s' is missing", fn))
	}

	sig := method.Type
	declared := sig.NumIn() - 1 // parameter 0 is the receiver
	switch {
	case sig.IsVariadic():
		// The trailing variadic parameter may receive zero arguments.
		if len(params) < declared-1 {
			return fail(fmt.Errorf("method '%s' needs at least %d argument(s), got %d", fn, declared-1, len(params)))
		}
	case len(params) != declared:
		return fail(fmt.Errorf("method '%s' needs %d argument(s), got %d", fn, declared, len(params)))
	}

	// paramType returns the type an argument at signature index i must be
	// assignable to, unrolling the variadic slice into its element type.
	paramType := func(i int) reflect.Type {
		if last := sig.NumIn() - 1; sig.IsVariadic() && i >= last {
			return sig.In(last).Elem()
		}
		return sig.In(i)
	}

	args := make([]reflect.Value, 0, len(params)+1)
	args = append(args, reflect.ValueOf(r.create(rc)))
	for i, arg := range params {
		want := paramType(i + 1)
		if arg == nil {
			// reflect.ValueOf(nil) is the invalid Value, which Func.Call
			// rejects; pass a typed nil when the parameter can hold one.
			switch want.Kind() {
			case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan:
				args = append(args, reflect.Zero(want))
				continue
			}
			return fail(fmt.Errorf("method '%s' argument %d is nil but parameter type is %s", fn, i, want.String()))
		}
		v := reflect.ValueOf(arg)
		if !v.Type().AssignableTo(want) {
			return fail(fmt.Errorf("method '%s' argument %d has type %s, want %s", fn, i, v.Type().String(), want.String()))
		}
		args = append(args, v)
	}

	rets := method.Func.Call(args)
	if len(rets) == 0 {
		return nil
	}
	// Only a last result that actually is an error is reported; any other
	// trailing result is ignored rather than triggering a failed assertion.
	if callErr, isErr := rets[len(rets)-1].Interface().(error); isErr {
		return callErr
	}
	return nil
}