package rpc import ( "bytes" "context" "encoding/gob" "errors" "fmt" "reflect" "runtime" "strings" "time" "github.com/go-redis/redis/v8" "go.mongodb.org/mongo-driver/bson/primitive" "repositories.action2quare.com/ayo/gocommon/logger" ) type Receiver interface { TargetExists(primitive.ObjectID) bool } type receiverManifest struct { r Receiver methods map[string]reflect.Method } type rpcEngine struct { receivers map[string]receiverManifest publish func([]byte) error } var engine = rpcEngine{ receivers: make(map[string]receiverManifest), } func RegistReceiver(ptr Receiver) { rname := reflect.TypeOf(ptr).Elem().Name() rname = fmt.Sprintf("(*%s)", rname) methods := make(map[string]reflect.Method) for i := 0; i < reflect.TypeOf(ptr).NumMethod(); i++ { method := reflect.TypeOf(ptr).Method(i) methods[method.Name] = method } engine.receivers[rname] = receiverManifest{ r: ptr, methods: methods, } } func Start(ctx context.Context, redisClient *redis.Client) { if engine.publish != nil { return } pubsubName := primitive.NewObjectID().Hex()[6:] engine.publish = func(s []byte) error { _, err := redisClient.Publish(ctx, pubsubName, s).Result() return err } go engine.loop(ctx, redisClient, pubsubName) } func (re *rpcEngine) callFromMessage(msg *redis.Message) { defer func() { r := recover() if r != nil { logger.Error(r) } }() encoded := []byte(msg.Payload) var target primitive.ObjectID copy(target[:], encoded[:12]) encoded = encoded[12:] for i, c := range encoded { if c == ')' { if manifest, ok := re.receivers[string(encoded[:i+1])]; ok { // 리시버 찾음 if manifest.r.TargetExists(target) { // 이 리시버가 타겟을 가지고 있음 encoded = encoded[i+1:] decoder := gob.NewDecoder(bytes.NewBuffer(encoded)) var params []any if decoder.Decode(¶ms) == nil { method := manifest.methods[params[0].(string)] args := []reflect.Value{ reflect.ValueOf(manifest.r), } for _, arg := range params[1:] { args = append(args, reflect.ValueOf(arg)) } method.Func.Call(args) } } } } } } func (re *rpcEngine) loop(ctx context.Context, redisClient *redis.Client, chanName string) { defer func() { r := recover() if r != nil { logger.Error(r) } }() pubsub := redisClient.Subscribe(ctx, chanName) for { if ctx.Err() != nil { return } if pubsub == nil { pubsub = redisClient.Subscribe(ctx, chanName) } msg, err := pubsub.ReceiveMessage(ctx) if err != nil { if err == redis.ErrClosed { time.Sleep(time.Second) } pubsub = nil } else { re.callFromMessage(msg) } } } var errNoReceiver = errors.New("no receiver") type RpcCallContext struct { r Receiver t primitive.ObjectID } var ErrCanExecuteHere = errors.New("go ahead") func MakeCallContext(r Receiver) RpcCallContext { return RpcCallContext{r: r} } func (c *RpcCallContext) Target(t primitive.ObjectID) { c.t = t } func (c *RpcCallContext) Call(args ...any) error { if c.r.TargetExists(c.t) { // 여기 있네? return ErrCanExecuteHere } pc := make([]uintptr, 1) n := runtime.Callers(2, pc[:]) if n < 1 { return errNoReceiver } frame, _ := runtime.CallersFrames(pc).Next() prf := strings.Split(frame.Function, ".") rname := prf[1] funcname := prf[2] serialized, err := encode(c.t, rname, funcname, args...) if err != nil { return err } return engine.publish(serialized) } func encode(target primitive.ObjectID, receiver string, funcname string, args ...any) ([]byte, error) { buff := new(bytes.Buffer) // 타겟을 가장 먼저 기록 buff.Write(target[:]) // receiver buff.Write([]byte(receiver)) // 다음 call context 기록 m := append([]any{funcname}, args...) encoder := gob.NewEncoder(buff) err := encoder.Encode(m) if err != nil { logger.Error("rpcCallContext.send err :", err) return nil, err } return buff.Bytes(), nil }