package gocommon import ( "bytes" "context" "encoding/gob" "encoding/json" "errors" "fmt" "io" "math" "net" "net/http" "net/url" "os" "os/signal" "path" "reflect" "runtime" "strconv" "strings" "sync" "sync/atomic" "syscall" "time" "unsafe" "repositories.action2quare.com/ayo/gocommon/flagx" "repositories.action2quare.com/ayo/gocommon/logger" "github.com/pires/go-proxyproto" "go.mongodb.org/mongo-driver/bson/primitive" ) func init() { gob.Register(map[string]any{}) gob.Register(primitive.A{}) gob.Register(primitive.M{}) gob.Register(primitive.D{}) gob.Register(primitive.ObjectID{}) gob.Register([]any{}) } type ServerMuxInterface interface { http.Handler HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) Handle(pattern string, handler http.Handler) } const ( // HTTPStatusReloginRequired : http status를 이걸 받으면 클라이언트는 로그아웃하고 로그인 화면으로 돌아가야 한다. HTTPStatusReloginRequired = 599 // HTTPStatusReloginRequiredDupID : HTTPStatusReloginRequiredDupID = 598 // HTTPStatusPlayerBlocked : HTTPStatusPlayerBlocked = 597 ) type ShutdownFlag int32 const ( ShutdownFlagRunning = ShutdownFlag(0) ShutdownFlagTerminating = ShutdownFlag(1) ShutdownFlagRestarting = ShutdownFlag(2) ShutdownFlagIdle = ShutdownFlag(3) ) type ErrorWithStatus struct { StatusCode int } func (e ErrorWithStatus) Error() string { return fmt.Sprintf("%d", e.StatusCode) } type RpcReturnTypeInterface interface { Value() any Error() error Serialize(http.ResponseWriter) error } type functionCallContext struct { Method string Args []interface{} } // Server : type Server struct { httpserver *http.Server interrupt chan os.Signal } var healthcheckcounter = int64(0) func healthCheckHandler(w http.ResponseWriter, r *http.Request) { defer func() { io.Copy(io.Discard, r.Body) r.Body.Close() }() // 한번이라도 들어오면 lb에 붙어있다는 뜻 if t := atomic.AddInt64(&healthcheckcounter, 1); t < 0 { logger.Println("healthCheckHandler return StatusServiceUnavailable :", t) w.WriteHeader(http.StatusServiceUnavailable) } } func welcomeHandler(w http.ResponseWriter, r *http.Request) { defer func() { io.Copy(io.Discard, r.Body) r.Body.Close() }() w.Write([]byte("welcome")) } var tlsflag = flagx.String("tls", "", "") var portptr = flagx.Int("port", 80, "") func isTlsEnabled(fileout ...*string) bool { if len(*tlsflag) == 0 { return false } if strings.HasSuffix(*tlsflag, "/") { return false } crtfile := *tlsflag + ".crt" if _, err := os.Stat(crtfile); os.IsNotExist(err) { return false } keyfile := *tlsflag + ".key" if _, err := os.Stat(keyfile); os.IsNotExist(err) { return false } if len(fileout) > 0 { *fileout[0] = crtfile } if len(fileout) > 1 { *fileout[1] = keyfile } return true } func registUnhandledPattern(serveMux ServerMuxInterface) { defer func() { recover() }() serveMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { logger.Println("page not found :", r.URL.Path) w.WriteHeader(http.StatusNotFound) }) } // NewHTTPServer : func NewHTTPServerWithPort(serveMux ServerMuxInterface, port int) *Server { if isTlsEnabled() && port == 80 { port = 443 } addr := fmt.Sprintf(":%d", port) serveMux.HandleFunc(MakeHttpHandlerPattern("welcome"), welcomeHandler) serveMux.HandleFunc(MakeHttpHandlerPattern("lb_health_chceck"), healthCheckHandler) serveMux.HandleFunc(MakeHttpHandlerPattern("lb_health_check"), healthCheckHandler) registUnhandledPattern(serveMux) server := &Server{ httpserver: &http.Server{ Addr: addr, Handler: serveMux, MaxHeaderBytes: 2 << 20, // 2 MB }, } server.httpserver.SetKeepAlivesEnabled(true) return server } func NewHTTPServer(serveMux ServerMuxInterface) *Server { // 시작시 자동으로 enable됨 if isTlsEnabled() && *portptr == 80 { *portptr = 443 } return NewHTTPServerWithPort(serveMux, *portptr) } // Shutdown : func (server *Server) shutdown() { logger.Println("server.shutdown()") signal.Stop(server.interrupt) if t := atomic.LoadInt64(&healthcheckcounter); t > 0 { logger.Println("http server shutdown. healthcheckcounter :", t) atomic.StoreInt64(&healthcheckcounter, math.MinInt64) for cnt := 0; cnt < 100; { next := atomic.LoadInt64(&healthcheckcounter) if next == t { cnt++ } else { t = next cnt = 0 } time.Sleep(100 * time.Millisecond) } logger.Println("http server shutdown. healthcheck completed") } else { logger.Println("http server shutdown. no lb") } if server.httpserver != nil { server.httpserver.Shutdown(context.Background()) server.httpserver = nil } } func (server *Server) Stop() { if server.interrupt != nil { server.interrupt <- os.Interrupt } else { server.shutdown() } } // Start : func (server *Server) Start(name ...string) error { if len(name) == 0 { exepath, _ := os.Executable() name = []string{path.Base(exepath)} } if server.httpserver != nil { ln, r := net.Listen("tcp", server.httpserver.Addr) if r != nil { return r } server.interrupt = make(chan os.Signal, 1) signal.Notify(server.interrupt, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) go func() { c := <-server.interrupt logger.Println("interrupt!!!!!!!! :", c.String()) server.shutdown() }() proxyListener := &proxyproto.Listener{ Listener: ln, ReadHeaderTimeout: 10 * time.Second, } defer proxyListener.Close() var err error var crtfile string var keyfile string if isTlsEnabled(&crtfile, &keyfile) { logger.Println("tls enabled :", crtfile, keyfile) err = server.httpserver.ServeTLS(proxyListener, crtfile, keyfile) } else { logger.Println("tls disabled") logger.Println(strings.Join(name, ", "), "started") err = server.httpserver.Serve(proxyListener) } if err != nil { if errors.Is(err, http.ErrServerClosed) { logger.Println("server.httpserver is closed normally") return nil } logger.Error("server.httpserver.Serve err :", err) return err } } return nil } // ConvertInterface : func ConvertInterface(from interface{}, toType reflect.Type) reflect.Value { fromrv := reflect.ValueOf(from) fromtype := reflect.TypeOf(from) if fromtype == toType { return fromrv } switch fromtype.Kind() { case reflect.Map: frommap := from.(map[string]interface{}) if frommap == nil { // map[string]interface{}가 아니고 뭐란말인가 return reflect.Zero(toType) } needAddr := false if toType.Kind() == reflect.Pointer { toType = toType.Elem() needAddr = true } out := reflect.New(toType) switch toType.Kind() { case reflect.Struct: for i := 0; i < toType.NumField(); i++ { field := toType.FieldByIndex([]int{i}) if field.Anonymous { fieldval := out.Elem().FieldByIndex([]int{i}) fieldval.Set(ConvertInterface(from, field.Type)) } else if v, ok := frommap[strings.ToLower(field.Name)]; ok { fieldval := out.Elem().FieldByIndex([]int{i}) fieldval.Set(ConvertInterface(v, fieldval.Type())) } } case reflect.Map: out.Elem().Set(reflect.MakeMap(toType)) if toType.Key().Kind() != reflect.String { for k, v := range frommap { tok := reflect.ValueOf(k).Convert(toType.Key()) tov := ConvertInterface(v, toType.Elem()) out.Elem().SetMapIndex(tok, tov) } } else { for k, v := range frommap { tov := ConvertInterface(v, toType.Elem()) out.Elem().SetMapIndex(reflect.ValueOf(k), tov) } } default: logger.Println("ConvertInterface failed : map ->", toType.Kind().String()) } if needAddr { return out } return out.Elem() case reflect.Slice: convslice := reflect.MakeSlice(toType, 0, fromrv.Len()) for i := 0; i < fromrv.Len(); i++ { convslice = reflect.Append(convslice, ConvertInterface(fromrv.Index(i).Interface(), toType.Elem())) } return convslice case reflect.Bool: if fromstr, ok := from.(string); ok { val, _ := strconv.ParseBool(fromstr) return reflect.ValueOf(val) } else if frombool, ok := from.(bool); ok { return reflect.ValueOf(frombool) } return reflect.ValueOf(false) case reflect.String: if toType == reflect.TypeOf(primitive.ObjectID{}) { objid, _ := primitive.ObjectIDFromHex(from.(string)) return reflect.ValueOf(objid) } } return fromrv.Convert(toType) } // ErrUnmarshalRequestFailed : var ErrUnmarshalRequestFailed = errors.New("unmarshal failed at rpc handler") var ErrWithStatusCode = errors.New("custom error with status code") var ErrSecondReturnShouldBeErrorType = errors.New("second return type should be error") type methodCacheType struct { sync.Mutex toIndex map[reflect.Type]map[string]int } var methodCache = methodCacheType{ toIndex: make(map[reflect.Type]map[string]int), } func (mc *methodCacheType) method(receiver any, method string) (reflect.Value, bool) { mc.Lock() defer mc.Unlock() t := reflect.TypeOf(receiver) table, ok := mc.toIndex[t] if !ok { table = make(map[string]int) mc.toIndex[t] = table } idx, ok := table[method] if !ok { m, ok := t.MethodByName(method) if !ok { return reflect.Value{}, false } table[method] = m.Index idx = m.Index } return reflect.ValueOf(receiver).Method(idx), true } // CallMethodInternal : func CallMethodInternal(receiver interface{}, context functionCallContext) (result RpcReturnTypeInterface, err error) { fn, ok := methodCache.method(receiver, context.Method) if !ok { return nil, errors.New(fmt.Sprint("method is missing :", receiver, context.Method)) } if len(context.Args) != fn.Type().NumIn() { return nil, errors.New(fmt.Sprint("argument is not matching :", context.Method)) } var argv []reflect.Value for i := 0; i < len(context.Args); i++ { argv = append(argv, ConvertInterface(context.Args[i], fn.Type().In(i))) } returnVals := fn.Call(argv) if len(returnVals) == 0 { return nil, ErrSecondReturnShouldBeErrorType } result = nil err = nil if result, ok = returnVals[0].Interface().(RpcReturnTypeInterface); !ok { return nil, ErrSecondReturnShouldBeErrorType } if len(returnVals) > 1 { if err, ok = returnVals[1].Interface().(error); !ok { return nil, ErrSecondReturnShouldBeErrorType } } return } // CallMethod : func CallMethod(receiver interface{}, context []byte) (RpcReturnTypeInterface, error) { defer func() { r := recover() if r != nil { logger.Error(r) } }() var meta functionCallContext if err := json.Unmarshal(context, &meta); err != nil { return nil, ErrUnmarshalRequestFailed } return CallMethodInternal(receiver, meta) } func ReadIntegerFormValue(r url.Values, key string) (int64, bool) { strval := r.Get(key) if len(strval) == 0 { return 0, false } temp, err := strconv.ParseInt(strval, 10, 0) if err != nil { logger.Error("common.ReadFloatFormValue failed :", key, err) return 0, false } return temp, true } func ReadFloatFormValue(r url.Values, key string) (float64, bool) { var inst float64 strval := r.Get(key) if len(strval) == 0 { return inst, false } temp, err := strconv.ParseFloat(strval, 64) if err != nil { logger.Error("common.ReadFloatFormValue failed :", key, err) return inst, false } return temp, true } func ReadObjectIDFormValue(r url.Values, key string) (primitive.ObjectID, bool) { strval := r.Get(key) if len(strval) == 0 { return primitive.NilObjectID, false } id, err := primitive.ObjectIDFromHex(strval) if err != nil { logger.Error("common.ReadObjectIDFormValue failed :", key, err) return primitive.NilObjectID, false } return id, true } func ReadBoolFormValue(r url.Values, key string) (bool, bool) { strval := r.Get(key) if len(strval) == 0 { return false, false } ret, err := strconv.ParseBool(strval) if err != nil { logger.Error("common.ReadBoolFormValue failed :", key, err) return false, false } return ret, true } func ReadStringFormValue(r url.Values, key string) (string, bool) { strval := r.Get(key) return strval, len(strval) > 0 } func ReadStringsFormValue(r url.Values, key string) ([]string, bool) { if r.Has(key) { return (map[string][]string)(r)[key], true } return nil, false } type encoder interface { Encode(any) error } type nilEncoder struct{} func (ne *nilEncoder) Encode(any) error { return nil } type decoder interface { Decode(any) error } type nilDecoder struct{} func (nd *nilDecoder) Decode(any) error { return nil } func MakeDecoder(r *http.Request) decoder { ct := r.Header.Get("Content-Type") if ct == "application/gob" { return gob.NewDecoder(r.Body) } else if ct == "application/json" { return json.NewDecoder(r.Body) } logger.Error("Content-Type is not supported :", ct) return &nilDecoder{} } func MakeEncoder(w http.ResponseWriter, r *http.Request) encoder { ct := r.Header.Get("Content-Type") if ct == "application/gob" { return gob.NewEncoder(w) } else if ct == "application/json" { return json.NewEncoder(w) } logger.Error("Content-Type is not supported :", ct) return &nilEncoder{} } func DotStringToTimestamp(tv string) primitive.Timestamp { if len(tv) == 0 { return primitive.Timestamp{T: 0, I: 0} } ti := strings.Split(tv, ".") t, _ := strconv.ParseUint(ti[0], 10, 0) if len(ti) > 1 { i, _ := strconv.ParseUint(ti[1], 10, 0) return primitive.Timestamp{T: uint32(t), I: uint32(i)} } return primitive.Timestamp{T: uint32(t), I: 0} } type RpcReturnSimple struct { value any } func (rt *RpcReturnSimple) Value() any { return rt.value } func (rt *RpcReturnSimple) Error() error { return nil } func (rt *RpcReturnSimple) Serialize(w http.ResponseWriter) error { err := SerializeInterface(w, rt.value) if err == nil { w.Write([]byte{0, 0}) } return err } type RpcReturnError struct { err error code int h map[string]any } func (rt *RpcReturnError) Value() any { return nil } var errDefaultError = errors.New("unknown error") func (rt *RpcReturnError) Error() error { if rt.err != nil { return rt.err } if rt.code != 0 { if rt.h == nil { return fmt.Errorf("http status code %d error", rt.code) } return fmt.Errorf("http status code %d error with header %v", rt.code, rt.h) } return errDefaultError } func (rt *RpcReturnError) Serialize(w http.ResponseWriter) error { if rt.h != nil { bt, _ := json.Marshal(rt.h) w.Header().Add("As-X-Err", string(bt)) } w.WriteHeader(rt.code) if rt.err != nil { logger.Error(rt.err) } if rt.code >= http.StatusInternalServerError { logger.Println("rpc return error :", rt.code, rt.h) } return nil } func (rt *RpcReturnError) WithCode(code int) *RpcReturnError { rt.code = code return rt } func (rt *RpcReturnError) WithError(err error) *RpcReturnError { if err2, ok := err.(*ErrorWithStatus); ok { rt.WithCode(err2.StatusCode) } else { rt.err = err } return rt } func (rt *RpcReturnError) WithHeader(k string, v any) *RpcReturnError { if rt.h == nil { rt.h = make(map[string]any) } rt.h[k] = v return rt } func (rt *RpcReturnError) WithHeaders(h map[string]any) *RpcReturnError { if rt.h == nil { rt.h = h } else { for k, v := range h { rt.h[k] = v } } return rt } // MakeRPCReturn : func MakeRPCReturn(value interface{}) *RpcReturnSimple { return &RpcReturnSimple{ value: value, } } func MakeRPCFail() *RpcReturnError { return &RpcReturnError{ err: nil, code: http.StatusInternalServerError, h: nil, } } func MakeRPCError() *RpcReturnError { pc, _, _, ok := runtime.Caller(1) if ok { frames := runtime.CallersFrames([]uintptr{pc}) frame, _ := frames.Next() logger.Printf("rpc error. func=%s, file=%s, line=%d", frame.Function, frame.File, frame.Line) } return &RpcReturnError{ err: nil, code: http.StatusInternalServerError, h: nil, } } type indirectBody struct { inner io.ReadCloser dump []byte closer func() } func (ib *indirectBody) Read(p []byte) (n int, err error) { n, err = ib.inner.Read(p) if n > 0 { ib.dump = append(ib.dump, p...) } return } func (ib *indirectBody) Close() error { if ib.closer != nil { ib.closer() ib.closer = nil } return ib.inner.Close() } func MakeHttpRequestForLogging(r *http.Request) *http.Request { ib := &indirectBody{ inner: r.Body, } closer := func() { var uv url.Values if r.Form != nil { uv = r.Form } else { uv = r.URL.Query() } logger.Println("request :") logger.Println(" header :", r.Header) logger.Println(" values :", uv) logger.Println(" body :", string(ib.dump)) } ib.closer = closer r.Body = ib return r } type apiFuncType func(http.ResponseWriter, *http.Request) type HttpApiHandler struct { methods map[string]apiFuncType originalReceiverName string } func MakeHttpApiHandler[T any](receiver *T, receiverName string) HttpApiHandler { methods := make(map[string]apiFuncType) tp := reflect.TypeOf(receiver) if len(receiverName) == 0 { receiverName = tp.Elem().Name() } writerType := reflect.TypeOf((*http.ResponseWriter)(nil)).Elem() for i := 0; i < tp.NumMethod(); i++ { method := tp.Method(i) if method.Type.NumIn() != 3 { continue } if method.Type.In(0) != tp { continue } if !method.Type.In(1).Implements(writerType) { continue } var r http.Request if method.Type.In(2) != reflect.TypeOf(&r) { continue } if method.Name == "ServeHTTP" { continue } funcptr := method.Func.Pointer() p1 := unsafe.Pointer(&funcptr) p2 := unsafe.Pointer(&p1) testfunc := (*func(*T, http.ResponseWriter, *http.Request))(p2) methods[receiverName+"."+method.Name] = func(w http.ResponseWriter, r *http.Request) { (*testfunc)(receiver, w, r) } } return HttpApiHandler{ methods: methods, originalReceiverName: tp.Elem().Name(), } } type HttpApiBroker struct { methods map[string]apiFuncType methods_dup map[string][]apiFuncType } type bufferReadCloser struct { *bytes.Reader } func (buff *bufferReadCloser) Close() error { return nil } type readOnlyResponseWriter struct { inner http.ResponseWriter statusCode int } func (w *readOnlyResponseWriter) Header() http.Header { return w.inner.Header() } func (w *readOnlyResponseWriter) Write(in []byte) (int, error) { logger.Println("readOnlyResponseWriter cannot write") return len(in), nil } func (w *readOnlyResponseWriter) WriteHeader(statusCode int) { w.statusCode = statusCode } func (hc *HttpApiBroker) AddHandler(receiver HttpApiHandler) { if hc.methods == nil { hc.methods = make(map[string]apiFuncType) hc.methods_dup = make(map[string][]apiFuncType) } for k, v := range receiver.methods { ab := strings.Split(k, ".") logger.Printf("http api registered : %s.%s -> %s\n", receiver.originalReceiverName, ab[1], k) hc.methods_dup[k] = append(hc.methods_dup[k], v) if len(hc.methods_dup[k]) > 1 { chain := hc.methods_dup[k] hc.methods[k] = func(w http.ResponseWriter, r *http.Request) { body, _ := io.ReadAll(r.Body) defer r.Body.Close() wrap := &readOnlyResponseWriter{inner: w, statusCode: 200} for _, f := range chain { r.Body = &bufferReadCloser{bytes.NewReader(body)} f(wrap, r) } if wrap.statusCode != 200 { w.WriteHeader(wrap.statusCode) } } } else { hc.methods[k] = v } } } func (hc *HttpApiBroker) AllMethodNames() (out []string) { out = make([]string, 0, len(hc.methods)) for name := range hc.methods { out = append(out, name) } return } func (hc *HttpApiBroker) CallByHeader(w http.ResponseWriter, r *http.Request) { funcname := r.Header.Get("AS-X-CALL") if len(funcname) == 0 { logger.Println("as-x-call header is missing") w.WriteHeader(http.StatusBadRequest) return } hc.call(funcname, w, r) } func (hc *HttpApiBroker) Call(w http.ResponseWriter, r *http.Request) { funcname := r.URL.Query().Get("call") if len(funcname) == 0 { logger.Println("query param 'call' is missing") w.WriteHeader(http.StatusBadRequest) return } hc.call(funcname, w, r) } func (hc *HttpApiBroker) call(funcname string, w http.ResponseWriter, r *http.Request) { if found := hc.methods[funcname]; found != nil { found(w, r) } else { logger.Println("api is not found :", funcname) } } func CallInternalServiceAPI[T any](url string, apitoken string, method string, data T, headers ...string) error { tempHeader := make(http.Header) tempHeader.Set("MG-X-API-TOKEN", apitoken) tempHeader.Set("Content-Type", "application/gob") for i := 1; i < len(headers); i += 2 { tempHeader.Set(headers[i-1], headers[i]) } buff := new(bytes.Buffer) ct := tempHeader.Get("Content-Type") if ct == "application/gob" { enc := gob.NewEncoder(buff) enc.Encode(data) } else if ct == "application/json" { enc := json.NewEncoder(buff) enc.Encode(data) } reqURL := fmt.Sprintf("%s/api?call=%s", url, method) req, err := http.NewRequest("POST", reqURL, buff) if err != nil { return err } req.Header = tempHeader resp, err := http.DefaultClient.Do(req) defer func() { if resp != nil && resp.Body != nil { resp.Body.Close() } }() if err == nil { if resp.StatusCode != http.StatusOK { return &ErrorWithStatus{StatusCode: resp.StatusCode} } } return err } func CallInternalServiceAPIAs[Tin any, Tout any](url string, apitoken string, method string, data Tin, out *Tout, headers ...string) error { tempHeader := make(http.Header) tempHeader.Set("MG-X-API-TOKEN", apitoken) tempHeader.Set("Content-Type", "application/gob") for i := 1; i < len(headers); i += 2 { tempHeader.Set(headers[i-1], headers[i]) } buff := new(bytes.Buffer) ct := tempHeader.Get("Content-Type") if ct == "application/gob" { enc := gob.NewEncoder(buff) enc.Encode(data) } else if ct == "application/json" { enc := json.NewEncoder(buff) enc.Encode(data) } reqURL := fmt.Sprintf("%s/api?call=%s", url, method) req, err := http.NewRequest("POST", reqURL, buff) if err != nil { return err } req.Header = tempHeader resp, err := http.DefaultClient.Do(req) if err != nil { return err } if resp.StatusCode != http.StatusOK { return &ErrorWithStatus{StatusCode: resp.StatusCode} } defer func() { if resp != nil && resp.Body != nil { resp.Body.Close() } }() if out != nil && resp.Body != nil { dec := gob.NewDecoder(resp.Body) return dec.Decode(out) } return nil }