From e95efa06a6d458d4b9229f5d84e815188ead9f49 Mon Sep 17 00:00:00 2001 From: mountain Date: Wed, 24 May 2023 15:10:15 +0900 Subject: [PATCH] =?UTF-8?q?go-ayo/common=EC=9D=84=20gocommon=EC=9C=BC?= =?UTF-8?q?=EB=A1=9C=20=EB=B6=84=EB=A6=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 + authcollection.go | 416 ++++++++++++++++++++++++++ azure/bot.go | 456 +++++++++++++++++++++++++++++ azure/func.go | 199 +++++++++++++ azure/graph.go | 356 ++++++++++++++++++++++ azure/misc.go | 261 +++++++++++++++++ azure/monitor.go | 202 +++++++++++++ document/document.go | 22 ++ document/document_filebase.go | 179 ++++++++++++ flag/flag_helper.go | 56 ++++ go.mod | 28 ++ go.sum | 73 +++++ locker_redis.go | 43 +++ logger/logger.go | 106 +++++++ misc.go | 104 +++++++ mongo.go | 485 ++++++++++++++++++++++++++++++ redis.go | 70 +++++ reflect_config.go | 75 +++++ s3/func.go | 392 +++++++++++++++++++++++++ server.go | 535 ++++++++++++++++++++++++++++++++++ xboxlive/xboxlive.go | 328 +++++++++++++++++++++ 21 files changed, 4389 insertions(+) create mode 100644 .gitignore create mode 100644 authcollection.go create mode 100644 azure/bot.go create mode 100644 azure/func.go create mode 100644 azure/graph.go create mode 100644 azure/misc.go create mode 100644 azure/monitor.go create mode 100644 document/document.go create mode 100644 document/document_filebase.go create mode 100644 flag/flag_helper.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 locker_redis.go create mode 100644 logger/logger.go create mode 100644 misc.go create mode 100644 mongo.go create mode 100644 redis.go create mode 100644 reflect_config.go create mode 100644 s3/func.go create mode 100644 server.go create mode 100644 xboxlive/xboxlive.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7da7304 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.vscode/ +*.log +*.exe diff --git a/authcollection.go b/authcollection.go new file mode 100644 index 0000000..0dc2c1c --- /dev/null +++ b/authcollection.go @@ -0,0 +1,416 @@ +package common + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + "sync/atomic" + "time" + "unsafe" + + "repositories.action2quare.com/ayo/gocommon/logger" + + "github.com/go-redis/redis/v8" + + "go.mongodb.org/mongo-driver/bson/primitive" +) + +const ( + sessionSyncChannelName = "session-sync-channel2" +) + +type AuthinfoCell interface { + ToBytes() []byte + ToAuthinfo() *Authinfo +} + +type AuthCollection struct { + lock sync.Mutex + + // key : session + auths map[string]*Authinfo + expiring map[string]*Authinfo + + // accid -> session (auths) + reverseOn map[primitive.ObjectID]string + // accid -> session (expiring) + reverseOff map[primitive.ObjectID]string + + nextTrim time.Time + ttl time.Duration + + SessionRemoved func(string) + SessionAdded func(AuthinfoCell) + QuerySession func(string, string) AuthinfoCell + Stop func() +} + +func MakeAuthCollection(sessionTTL time.Duration) *AuthCollection { + return &AuthCollection{ + auths: make(map[string]*Authinfo), + expiring: make(map[string]*Authinfo), + reverseOn: make(map[primitive.ObjectID]string), + reverseOff: make(map[primitive.ObjectID]string), + nextTrim: time.Now().Add(time.Hour * 1000000), + ttl: sessionTTL, + SessionRemoved: func(string) {}, + SessionAdded: func(AuthinfoCell) {}, + QuerySession: func(string, string) AuthinfoCell { return nil }, + } +} + +type authCollectionConfig struct { + RegionStorageConfig + Maingate string `json:"maingate_service_url"` +} + +type redisAuthCell struct { + raw []byte +} + +func (ac *redisAuthCell) ToAuthinfo() *Authinfo { + var out Authinfo + err := json.Unmarshal(ac.raw, &out) + if err != nil { + logger.Error("redisAuthCell ToAuthinfo failed :", string(ac.raw), err) + return nil + } + return &out +} + +func (ac *redisAuthCell) ToBytes() []byte { + return ac.raw +} + +func newAuthCollectionWithRedis(redisClient *redis.Client, subctx context.Context, maingateURL string, apiToken string) *AuthCollection { + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/config", maingateURL), nil) + req.Header.Add("MG-X-API-TOKEN", apiToken) + + sessionTTL := int64(3600) + client := http.Client{ + Timeout: 3 * time.Second, + } + resp, err := client.Do(req) + if resp != nil { + defer resp.Body.Close() + } + + if err != nil { + if !*Devflag { + logger.Error("get maingate config failed :", err) + return nil + } + } else if resp.StatusCode == http.StatusOK { + raw, _ := io.ReadAll(resp.Body) + if len(raw) == 0 { + logger.Error("get maingate config failed :", err) + return nil + } + + var config map[string]any + err = json.Unmarshal(raw, &config) + if err != nil { + logger.Error("get maingate config failed :", err) + return nil + } + + if ttl, ok := config["maingate_session_ttl"].(float64); ok { + sessionTTL = int64(ttl) + } + } else if !*Devflag { + logger.Error("get maingate config failed :", err) + return nil + } + + ac := MakeAuthCollection(time.Duration(sessionTTL * int64(time.Second))) + pubsub := redisClient.Subscribe(subctx, sessionSyncChannelName) + ctx, cancel := context.WithCancel(context.TODO()) + go func(ctx context.Context, sub *redis.PubSub, authCache *AuthCollection) { + for { + select { + case <-ctx.Done(): + return + + case msg := <-sub.Channel(): + if msg == nil { + return + } + + if len(msg.Payload) == 0 { + continue + } + + if msg.Payload[0] == '-' { + authCache.RemoveBySessionKey(msg.Payload[1:], false) + } else { + authCache.AddRaw(&redisAuthCell{ + raw: []byte(msg.Payload), + }) + } + } + } + }(ctx, pubsub, ac) + + ac.Stop = cancel + ac.QuerySession = func(key string, token string) AuthinfoCell { + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/query?sk=%s", maingateURL, key), nil) + req.Header.Add("Authorization", "Bearer "+token) + req.Header.Add("MG-X-API-TOKEN", apiToken) + + client := http.Client{} + resp, err := client.Do(req) + if err != nil { + logger.Error("authorize query failed :", err) + return nil + } + defer resp.Body.Close() + + raw, _ := io.ReadAll(resp.Body) + if len(raw) == 0 { + // 세션키가 없네? 클라이언트한테 재로그인하라고 알려줘야 함 + return nil + } + + return &redisAuthCell{ + raw: raw, + } + } + ac.SessionAdded = func(cell AuthinfoCell) { + redisClient.Publish(context.Background(), sessionSyncChannelName, cell.ToBytes()) + } + ac.SessionRemoved = func(sk string) { + redisClient.Publish(context.Background(), sessionSyncChannelName, "-"+sk) + } + + return ac +} + +type AuthCollectionGlobal struct { + apiToken string + ptr unsafe.Pointer // map[string]*AuthCollection +} + +func (acg *AuthCollectionGlobal) Get(region string) *AuthCollection { + ptr := atomic.LoadPointer(&acg.ptr) + oldval := *(*map[string]*AuthCollection)(ptr) + + return oldval[region] +} + +func (acg *AuthCollectionGlobal) Regions() (out []string) { + ptr := atomic.LoadPointer(&acg.ptr) + oldval := *(*map[string]*AuthCollection)(ptr) + + for k := range oldval { + out = append(out, k) + } + return +} + +func (acg *AuthCollectionGlobal) Reload(context context.Context) error { + ptr := atomic.LoadPointer(&acg.ptr) + oldval := *(*map[string]*AuthCollection)(ptr) + + var config authCollectionConfig + if err := LoadConfig(&config); err != nil { + return err + } + + newval := make(map[string]*AuthCollection) + for r, c := range oldval { + if _, ok := config.RegionStorage[r]; !ok { + // 없어졌네? 닫음 + c.Stop() + } else { + newval[r] = c + } + } + + for r, url := range config.RegionStorage { + if _, ok := oldval[r]; !ok { + // 새로 생겼네 + redisClient, err := NewRedisClient(url.Redis.URL, url.Redis.Offset["session"]) + if err != nil { + return err + } + + if authCache := newAuthCollectionWithRedis(redisClient, context, config.Maingate, acg.apiToken); authCache != nil { + newval[r] = authCache + } + } + } + + atomic.StorePointer(&acg.ptr, unsafe.Pointer(&newval)) + return nil +} + +func NewAuthCollectionGlobal(context context.Context, apiToken string) (AuthCollectionGlobal, error) { + var config authCollectionConfig + if err := LoadConfig(&config); err != nil { + return AuthCollectionGlobal{}, err + } + + output := make(map[string]*AuthCollection) + for region, url := range config.RegionStorage { + redisClient, err := NewRedisClient(url.Redis.URL, url.Redis.Offset["session"]) + if err != nil { + return AuthCollectionGlobal{}, err + } + + if authCache := newAuthCollectionWithRedis(redisClient, context, config.Maingate, apiToken); authCache != nil { + output[region] = authCache + } + } + + return AuthCollectionGlobal{ + apiToken: apiToken, + ptr: unsafe.Pointer(&output), + }, nil +} + +func (sc *AuthCollection) AddRaw(cell AuthinfoCell) { + sc.lock.Lock() + defer sc.lock.Unlock() + + if time.Now().After(sc.nextTrim) { + sc.expiring, sc.auths = sc.auths, sc.expiring + sc.reverseOff, sc.reverseOn = sc.reverseOn, sc.reverseOff + + // maps 패키지는 아직 0.0.0 상태;; https://pkg.go.dev/golang.org/x/exp/maps?tab=versions + // maps.Clear(sc.auths) + sc.auths = make(map[string]*Authinfo) + sc.reverseOn = make(map[primitive.ObjectID]string) + } + + newauth := cell.ToAuthinfo() + if newauth == nil { + logger.Println("AuthCollection.AddRaw failed. cell.ToAuthinfo returns nil") + return + } + + sk := newauth.Sk.Hex() + if oldsk, exists := sc.reverseOn[newauth.Accid]; exists { + delete(sc.auths, oldsk) + delete(sc.reverseOn, newauth.Accid) + } else if oldsk, exists = sc.reverseOff[newauth.Accid]; exists { + delete(sc.expiring, oldsk) + delete(sc.reverseOff, newauth.Accid) + } + + delete(sc.auths, sk) + delete(sc.expiring, sk) + + sc.auths[sk] = newauth + sc.reverseOn[newauth.Accid] = sk + + if len(sc.auths) == 1 { + sc.nextTrim = time.Now().Add(sc.ttl) + } +} + +func (sc *AuthCollection) Find(sk string) *Authinfo { + sc.lock.Lock() + defer sc.lock.Unlock() + + if found, ok := sc.auths[sk]; ok { + return found + } + + return sc.expiring[sk] +} + +func (sc *AuthCollection) RemoveByAccId(accid primitive.ObjectID) { + sc.lock.Lock() + defer sc.lock.Unlock() + + logger.Println("AuthCollection.RemoveByAccId :", accid.Hex()) + + var sk string + if on, ok := sc.reverseOn[accid]; ok { + sk = on + } else if off, ok := sc.reverseOff[accid]; ok { + sk = off + } + + if len(sk) > 0 { + old := sc.auths[sk] + if old != nil { + accid = old.Accid + delete(sc.auths, sk) + delete(sc.reverseOn, accid) + } else if old = sc.expiring[sk]; old != nil { + accid = old.Accid + delete(sc.expiring, sk) + delete(sc.reverseOff, accid) + } + } +} + +func (sc *AuthCollection) RemoveBySessionKey(sk string, publish bool) (accid primitive.ObjectID) { + sc.lock.Lock() + defer sc.lock.Unlock() + + logger.Println("AuthCollection.RemoveBySessionKey :", sk, publish) + + if publish { + // 나한테 있든 없든 무조건 publish해야 함 + sc.SessionRemoved(sk) + } + + old := sc.auths[sk] + if old != nil { + accid = old.Accid + delete(sc.auths, sk) + delete(sc.reverseOn, accid) + } else if old = sc.expiring[sk]; old != nil { + accid = old.Accid + delete(sc.expiring, sk) + delete(sc.reverseOff, accid) + } else { + accid = primitive.NilObjectID + } + + return +} + +func (sc *AuthCollection) IsValid(sk string, token string) (accid primitive.ObjectID, success bool) { + exists := sc.Find(sk) + if exists != nil { + now := int64(primitive.NewDateTimeFromTime(time.Now().UTC())) + //if int64(exists.Expired) > now && exists.Token == token { + if int64(exists.Expired) > now { //-- accesstoken은 사실상 쓰지 않는다. + return exists.Accid, true + } + + if exists.Expired == 0 { + // 이미 maingate db까지 가서 만료된 것으로 확인된 키다. + return primitive.NilObjectID, false + } + } + + cell := sc.QuerySession(sk, token) + if cell == nil { + // maingate db까지 가서 만료된 것으로 확인된 키다. Expired를 0으로 저장해 놓고 쿼리를 더 이상 보내지 않도록 + sc.lock.Lock() + defer sc.lock.Unlock() + + sc.auths[sk] = &Authinfo{Expired: 0} + logger.Println("session is invalid. cell is nil :", sk) + + return primitive.NilObjectID, false + } + + newauth := cell.ToAuthinfo() + if newauth == nil { + logger.Println("session is invalid. ToAuthinfo() returns nil :", sk) + return primitive.NilObjectID, false + } + + sc.AddRaw(cell) + sc.SessionAdded(cell) + + return newauth.Accid, true +} diff --git a/azure/bot.go b/azure/bot.go new file mode 100644 index 0000000..265113c --- /dev/null +++ b/azure/bot.go @@ -0,0 +1,456 @@ +package azure + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strconv" + "strings" + "time" +) + +type ChannelAccount struct { + Id string `json:"id"` + Name string `json:"name,omitempty"` + AadObjectId string `json:"aadObjectId,omitempty"` + Role string `json:"role,omitempty"` +} + +type AttachmentData struct { + Name string `json:"name,omitempty"` + OriginalBase64 string `json:"originalBase64,omitempty"` + ThumbnailBase64 string `json:"thumbnailBase64,omitempty"` + Type string `json:"type,omitempty"` +} + +type MessageActivity struct { + Type string `json:"type,omitempty"` + Timestamp time.Time `json:"timestamp,omitempty"` + LocalTimestamp time.Time `json:"localTimestamp,omitempty"` + Id string `json:"id"` + ChannelId string `json:"channelId,omitempty"` + ServiceUrl string `json:"serviceUrl,omitempty"` + Message string + Mentions []string + From ChannelAccount `json:"from,omitempty"` + Conversation *struct { + ConversationType string `json:"conversationType,omitempty"` + TenantId string `json:"tenantId,omitempty"` + Id string `json:"id"` + Name string `json:"name"` + } `json:"conversation,omitempty"` + Recipient *struct { + Name string `json:"name,omitempty"` + Id string `json:"id"` + } `json:"recipient,omitempty"` + Entities []interface{} `json:"entities,omitempty"` + ChannelData *struct { + TeamsChannelId string `json:"teamsChannelId,omitempty"` + TeamsTeamId string `json:"teamsTeamId,omitempty"` + Channel *struct { + Id string `json:"id"` + } `json:"channel,omitempty"` + Team *struct { + Id string `json:"id"` + } `json:"team,omitempty"` + Tenant *struct { + Id string `json:"id,omitempty"` + } `json:"tenant,omitempty"` + } `json:"channelData,omitempty"` + ReplyToId string `json:"replyToId,omitempty"` + Value map[string]string `json:"value,omitempty"` + Locale string `json:"Locale,omitempty"` + LocalTimezone string `json:"LocalTimezone,omitempty"` +} + +type AdaptiveCard struct { + Body []interface{} + Actions []interface{} +} + +type tempMessageActivity struct { + MessageActivity + RawText string `json:"text,omitempty"` +} + +var botFrameworkAuth = accessToken{} + +func ParseMessageActivity(src []byte) (MessageActivity, error) { + var message tempMessageActivity + err := json.Unmarshal(src, &message) + if err != nil { + return MessageActivity{}, err + } + + for len(message.RawText) > 0 { + s := strings.Index(message.RawText, "") + if s < 0 { + break + } + message.Message += message.RawText[:s] + message.RawText = message.RawText[s+4:] + e := strings.Index(message.RawText, "") + if e < 0 { + break + } + mention := message.RawText[:e] + message.RawText = message.RawText[e+5:] + message.Mentions = append(message.Mentions, mention) + } + message.Message += strings.Trim(message.RawText, " ") + message.Message = strings.TrimSpace(message.Message) + if !strings.HasSuffix(message.ServiceUrl, "/") { + message.ServiceUrl += "/" + } + return message.MessageActivity, nil +} + +func sendRequireMessage(requrl string, method string, stream io.Reader) (string, error) { + req, err := http.NewRequest(method, requrl, stream) + if err != nil { + return "", err + } + + auth, err := botFrameworkAuth.getAuthoizationToken() + if err != nil { + return "", err + } + req.Header.Set("Authorization", auth) + req.Header.Set("Content-type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + bt, _ := io.ReadAll(resp.Body) + var resultdoc map[string]interface{} + json.Unmarshal(bt, &resultdoc) + if len(resultdoc) > 0 { + if idraw, ok := resultdoc["id"]; ok { + return idraw.(string), nil + } + } + + return "", err +} + +func putMessage(requrl string, stream io.Reader) (string, error) { + return sendRequireMessage(requrl, "PUT", stream) +} + +func postMessage(requrl string, stream io.Reader) (string, error) { + return sendRequireMessage(requrl, "POST", stream) +} + +func unescapeUnicodeCharactersInJSON(src []byte) (string, error) { + str, err := strconv.Unquote(strings.Replace(strconv.Quote(string(src)), `\\u`, `\u`, -1)) + if err != nil { + return "", err + } + return str, nil +} + +func (m *MessageActivity) SendStreamOnChannel(reader io.Reader) (string, error) { + serviceUrl := m.ServiceUrl + + // mention 없이 여기 왔으면 private 채팅이다. reply를 해야함 + var conversationId string + if m.ChannelData.Channel == nil { + conversationId = m.Conversation.Id + } else { + conversationId = m.ChannelData.Channel.Id + } + + requrl := fmt.Sprintf("%sv3/conversations/%s/activities", serviceUrl, conversationId) + + return postMessage(requrl, reader) +} + +func (m *MessageActivity) UploadFile(attachment AttachmentData) (string, error) { + // POST /v3/conversations/{conversationId}/attachments + // Request body An AttachmentData object. + // Returns A ResourceResponse object. The id property specifies the attachment ID that can be used with the Get Attachment Info operation and the Get Attachment operation. + serviceUrl := m.ServiceUrl + conversationId := m.Conversation.Id + bt, _ := json.Marshal(attachment) + + requrl := fmt.Sprintf("%sv3/conversations/%s/attachments", serviceUrl, conversationId) + return postMessage(requrl, bytes.NewReader([]byte(bt))) +} + +func makeAttachmentDocs(attachments []AdaptiveCard) []interface{} { + var attachconv []interface{} + for _, attachment := range attachments { + attachconv = append(attachconv, map[string]interface{}{ + "contentType": "application/vnd.microsoft.card.adaptive", + "content": map[string]interface{}{ + "$schema": "http://adaptivecards.io/schemas/adaptive-card.json", + "type": "AdaptiveCard", + "version": "1.3", + "body": attachment.Body, + }, + }) + } + return attachconv +} + +// func (m *MessageActivity) update(activityId string, message string, attachments ...AdaptiveCard) (string, error) { +// serviceUrl := m.ServiceUrl +// conversationId := m.Conversation.Id + +// attachconv := makeAttachmentDocs(attachments) +// activity := map[string]interface{}{ +// "type": "message", +// "text": strings.ReplaceAll(message, `"`, `\"`), +// } +// if len(attachconv) > 0 { +// activity["attachments"] = attachconv +// } + +// reqdoc, _ := json.Marshal(activity) +// unq, _ := unescapeUnicodeCharactersInJSON(reqdoc) + +// requrl := fmt.Sprintf("%sv3/conversations/%s/activities/%s", serviceUrl, conversationId, activityId) +// return putMessage(requrl, bytes.NewReader([]byte(unq))) +// } + +// func (m *MessageActivity) reply(message string, mention bool, attachments ...AdaptiveCard) (string, error) { +// serviceUrl := m.ServiceUrl +// conversationId := m.Conversation.Id + +// attachconv := makeAttachmentDocs(attachments) + +// var activity map[string]interface{} +// if mention { +// mentionText := fmt.Sprintf(`%s`, m.From.Name) +// activity = map[string]interface{}{ +// "type": "message", +// "text": mentionText + " " + strings.ReplaceAll(message, `"`, `\"`), +// "entities": []interface{}{ +// map[string]interface{}{ +// "mentioned": m.From, +// "text": mentionText, +// "type": "mention", +// }, +// }, +// } +// } else { +// activity = map[string]interface{}{ +// "type": "message", +// "text": strings.ReplaceAll(message, `"`, `\"`), +// } +// } + +// activity["replyToId"] = m.Id +// if len(attachconv) > 0 { +// activity["attachments"] = attachconv +// } + +// reqdoc, _ := json.Marshal(activity) +// unq, _ := unescapeUnicodeCharactersInJSON(reqdoc) + +// requrl := fmt.Sprintf("%sv3/conversations/%s/activities", serviceUrl, conversationId) +// return postMessage(requrl, bytes.NewReader([]byte(unq))) +// } + +// func (m *MessageActivity) ReplyWithMentionf(format string, val ...interface{}) (string, error) { +// return m.ReplyWithMention(fmt.Sprintf(format, val...)) +// } + +// func (m *MessageActivity) ReplyWithMention(message string, attachments ...AdaptiveCard) (string, error) { +// return m.reply(message, true, attachments...) +// } + +// func (m *MessageActivity) Reply(message string, attachments ...AdaptiveCard) (string, error) { +// return m.reply(message, false, attachments...) +// } + +// func (m *MessageActivity) Replyf(format string, val ...interface{}) (string, error) { +// return m.Reply(fmt.Sprintf(format, val...)) +// } + +// func (m *MessageActivity) Update(activityId string, message string, attachments ...AdaptiveCard) (string, error) { +// return m.update(activityId, message, attachments...) +// } + +// func (m *MessageActivity) Updatef(activityId string, format string, val ...interface{}) (string, error) { +// return m.Update(activityId, fmt.Sprintf(format, val...)) +// } + +type MessageWrap struct { + *MessageActivity + Prefix string + Mention bool + Attachments []AdaptiveCard +} + +func (m *MessageActivity) MakeWrap() *MessageWrap { + return &MessageWrap{MessageActivity: m} +} + +func (m *MessageWrap) WithPrefix(prefix string) *MessageWrap { + m.Prefix = prefix + return m +} + +func (m *MessageWrap) WithMention() *MessageWrap { + m.Mention = true + return m +} + +func (m *MessageWrap) WithAttachments(attachments ...AdaptiveCard) *MessageWrap { + m.Attachments = attachments + return m +} + +func (m *MessageWrap) Reply(text string) (string, error) { + serviceUrl := m.ServiceUrl + conversationId := m.Conversation.Id + attachconv := makeAttachmentDocs(m.Attachments) + + var activity map[string]interface{} + if m.Mention { + var mentionText string + if len(m.Prefix) > 0 { + mentionText = fmt.Sprintf(`**[%s]** %s`, m.Prefix, m.From.Name) + } else { + mentionText = fmt.Sprintf(`%s`, m.From.Name) + } + activity = map[string]interface{}{ + "type": "message", + "text": mentionText + " " + strings.ReplaceAll(text, `"`, `\"`), + "entities": []interface{}{ + map[string]interface{}{ + "mentioned": m.From, + "text": mentionText, + "type": "mention", + }, + }, + } + } else { + if len(m.Prefix) > 0 { + text = fmt.Sprintf("**[%s]** %s", m.Prefix, text) + } + + activity = map[string]interface{}{ + "type": "message", + "text": strings.ReplaceAll(text, `"`, `\"`), + } + } + + activity["replyToId"] = m.Id + if len(attachconv) > 0 { + activity["attachments"] = attachconv + } + + activity["from"] = m.Recipient + activity["channelId"] = m.ChannelId + reqdoc, _ := json.Marshal(activity) + unq, _ := unescapeUnicodeCharactersInJSON(reqdoc) + + requrl := fmt.Sprintf("%sv3/conversations/%s/activities", serviceUrl, conversationId) + return postMessage(requrl, bytes.NewReader([]byte(unq))) +} + +func (m *MessageWrap) Replyf(format string, args ...interface{}) (string, error) { + return m.Reply(fmt.Sprintf(format, args...)) +} + +func (m *MessageActivity) Update(reader io.Reader) { + serviceUrl := m.ServiceUrl + conversationId := m.Conversation.Id + activityId := m.ReplyToId + + requrl := fmt.Sprintf("%sv3/conversations/%s/activities/%s", serviceUrl, conversationId, activityId) + putMessage(requrl, reader) +} + +func (m *MessageWrap) Update(activityId string, text string) (string, error) { + serviceUrl := m.ServiceUrl + conversationId := m.Conversation.Id + if len(m.Prefix) > 0 { + text = fmt.Sprintf("**[%s]** %s", m.Prefix, text) + } + attachconv := makeAttachmentDocs(m.Attachments) + activity := map[string]interface{}{ + "type": "message", + "text": strings.ReplaceAll(text, `"`, `\"`), + } + if len(attachconv) > 0 { + activity["attachments"] = attachconv + } + + reqdoc, _ := json.Marshal(activity) + unq, _ := unescapeUnicodeCharactersInJSON(reqdoc) + + requrl := fmt.Sprintf("%sv3/conversations/%s/activities/%s", serviceUrl, conversationId, activityId) + return putMessage(requrl, bytes.NewReader([]byte(unq))) +} + +func (m *MessageWrap) Updatef(activityId string, format string, args ...interface{}) (string, error) { + return m.Update(activityId, fmt.Sprintf(format, args...)) +} + +type MessageReplyCache struct { + sourceActivity *MessageActivity + ReplyWrap *MessageWrap + Replyaid string +} + +func (c *MessageReplyCache) Serialize() string { + bt, err := json.Marshal(c) + if err != nil { + return "" + } + linkfilename := fmt.Sprintf("updatelink_%s.json", time.Now().Format("2006-01-02T15-04-05")) + if os.WriteFile(linkfilename, bt, 0666) == nil { + return linkfilename + } + + return "" +} + +func DeserializeMessageReplyCache(filename string) *MessageReplyCache { + bt, err := os.ReadFile(filename) + if err != nil { + return nil + } + + var out MessageReplyCache + if json.Unmarshal(bt, &out) == nil { + return &out + } + + return nil +} + +func MakeMessageReplyCache(src *MessageActivity) MessageReplyCache { + return MessageReplyCache{ + sourceActivity: src, + } +} + +func (ch *MessageReplyCache) SourceActivity() *MessageActivity { + return ch.sourceActivity +} + +func (ch *MessageReplyCache) Reply(text string) { + if ch.ReplyWrap == nil { + hostname, _ := os.Hostname() + ch.ReplyWrap = ch.sourceActivity.MakeWrap().WithPrefix(hostname) + } + + if len(ch.Replyaid) == 0 { + ch.Replyaid, _ = ch.ReplyWrap.Reply(text) + } else { + ch.ReplyWrap.Update(ch.Replyaid, text) + } +} + +func (ch *MessageReplyCache) Replyf(format string, args ...interface{}) { + ch.Reply(fmt.Sprintf(format, args...)) +} diff --git a/azure/func.go b/azure/func.go new file mode 100644 index 0000000..a04e55b --- /dev/null +++ b/azure/func.go @@ -0,0 +1,199 @@ +package azure + +import ( + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "io" + "math" + "math/big" + "strconv" + + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/golang-jwt/jwt" +) + +type accessToken struct { + sync.Mutex + typetoken string + expireAt time.Time + url string + values url.Values +} + +var jwkCache struct { + headerlock sync.Mutex + typetoken string + expireAt time.Time + pks map[string]*rsa.PublicKey +} + +func microsoftAppId() string { + val := os.Getenv("MICROSOFT_APP_ID") + if len(val) == 0 { + val = "b5367590-5a94-4df3-bca0-ecd4b693ddf0" + } + return val +} + +func microsoftAppPassword() string { + val := os.Getenv("MICROSOFT_APP_PASSWORD") + if len(val) == 0 { + val = "~VG1cf2-~5Fw3Wz9_4.A.XxpZPO8BwJ36y" + } + return val +} + +func getOpenIDConfiguration(x5t string) (*rsa.PublicKey, error) { + // https://docs.microsoft.com/ko-kr/azure/bot-service/rest-api/bot-framework-rest-connector-authentication?view=azure-bot-service-4.0 + jwkCache.headerlock.Lock() + defer jwkCache.headerlock.Unlock() + + if time.Now().After(jwkCache.expireAt) { + resp, err := http.Get("https://login.botframework.com/v1/.well-known/openidconfiguration") + if err != nil { + return nil, err + } + + defer resp.Body.Close() + bt, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var doc map[string]interface{} + if err = json.Unmarshal(bt, &doc); err != nil { + return nil, err + } + + url := doc["jwks_uri"].(string) + resp, err = http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + bt, err = io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if err = json.Unmarshal(bt, &doc); err != nil { + return nil, err + } + + keys := doc["keys"].([]interface{}) + newPks := make(map[string]*rsa.PublicKey) + for _, key := range keys { + keydoc := key.(map[string]interface{}) + x5t := keydoc["x5t"].(string) + eb := make([]byte, 4) + nb, _ := base64.RawURLEncoding.DecodeString(keydoc["n"].(string)) + base64.RawURLEncoding.Decode(eb, []byte(keydoc["e"].(string))) + n := big.NewInt(0).SetBytes(nb) + e := binary.LittleEndian.Uint32(eb) + pk := &rsa.PublicKey{ + N: n, + E: int(e), + } + newPks[x5t] = pk + } + + jwkCache.expireAt = time.Now().Add(24 * time.Hour) + jwkCache.pks = newPks + } + + return jwkCache.pks[x5t], nil +} + +func VerifyJWT(header string) error { + if !strings.HasPrefix(header, "Bearer ") { + return errors.New("invalid token") + } + tokenString := header[7:] + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return getOpenIDConfiguration(token.Header["x5t"].(string)) + }) + if err != nil { + return err + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + if claims["iss"].(string) != "https://api.botframework.com" { + return errors.New("issuer is not valid") + } + if claims["aud"].(string) != microsoftAppId() { + return errors.New("audience is not valid") + } + expireAt := int64(claims["exp"].(float64)) + if math.Abs(float64((expireAt-time.Now().UTC().Unix())/int64(time.Second))) >= 300 { + return errors.New("token expired") + } + } else { + return errors.New("VerifyJWT token claims failed") + } + + return nil +} + +func (at *accessToken) getAuthoizationToken() (string, error) { + at.Lock() + defer at.Unlock() + + if len(at.url) == 0 { + at.url = "https://login.microsoftonline.com/botframework.com/oauth2/v2.0/token" + at.values = url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {microsoftAppId()}, + "scope": {"https://api.botframework.com/.default"}, + "client_secret": {microsoftAppPassword()}, + } + } + + if time.Now().After(at.expireAt) { + resp, err := http.PostForm(at.url, at.values) + if err != nil { + return "", err + } + + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var doc map[string]interface{} + err = json.Unmarshal(body, &doc) + if err != nil { + return "", err + } + + if v, ok := doc["error"]; ok { + if desc, ok := doc["error_description"]; ok { + return "", errors.New(desc.(string)) + } + + return "", errors.New(v.(string)) + } + + tokenType := doc["token_type"].(string) + token := doc["access_token"].(string) + expin := doc["expires_in"] + + var tokenDur int + switch expin := expin.(type) { + case float64: + tokenDur = int(expin) + case string: + tokenDur, _ = strconv.Atoi(expin) + } + + at.typetoken = tokenType + " " + token + at.expireAt = time.Now().Add(time.Duration(tokenDur) * time.Second) + } + + return at.typetoken, nil + +} diff --git a/azure/graph.go b/azure/graph.go new file mode 100644 index 0000000..8846017 --- /dev/null +++ b/azure/graph.go @@ -0,0 +1,356 @@ +package azure + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "time" +) + +type Graph struct { + *accessToken + tenantId string + groupId string +} + +func NewGraphByEnv() *Graph { + tenantId := os.Getenv("GRAPH_TENANT_ID") + groupId := os.Getenv("GRAPH_GROUP_ID") + return NewGraph(tenantId, groupId) +} + +func NewGraph(tenantId string, groupId string) *Graph { + return &Graph{ + accessToken: &accessToken{ + url: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", tenantId), + values: url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {microsoftAppId()}, + "scope": {"https://graph.microsoft.com/.default"}, + "client_secret": {microsoftAppPassword()}, + }, + }, + tenantId: tenantId, + groupId: groupId, + } +} + +func (s Graph) GroupId() string { + return s.groupId +} + +var errGroupIdMissing = errors.New("GRAPH_GROUP_ID is missing") + +func (s Graph) ReadFile(relativePath string) (io.ReadCloser, error) { + if len(s.groupId) == 0 { + return nil, errGroupIdMissing + } + + token, err := s.getAuthoizationToken() + if err != nil { + return nil, err + } + requrl := fmt.Sprintf("https://graph.microsoft.com/v1.0/groups/%s/drive/root:/%s:/content", s.groupId, relativePath) + req, _ := http.NewRequest("GET", requrl, nil) + + req.Header.Add("Authorization", token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, errors.New(resp.Status) + } + + return resp.Body, nil +} + +func (s Graph) GetWebURLOfFile(relativePath string) (string, error) { + if len(s.groupId) == 0 { + return "", errGroupIdMissing + } + + token, err := s.getAuthoizationToken() + if err != nil { + return "", err + } + requrl := fmt.Sprintf("https://graph.microsoft.com/v1.0/groups/%s/drive/root:/%s:", s.groupId, relativePath) + req, _ := http.NewRequest("GET", requrl, nil) + + req.Header.Add("Authorization", token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + + itemdoc, _ := io.ReadAll(resp.Body) + var item map[string]interface{} + json.Unmarshal(itemdoc, &item) + + url, ok := item["webUrl"] + if ok { + return url.(string), nil + } + + return "", nil +} + +func (s Graph) DownloadFile(relativePath string, outputFile string) error { + body, err := s.ReadFile(relativePath) + if err != nil { + return err + } + defer body.Close() + + // Create the file + out, err := os.Create(outputFile) + if err != nil { + return err + } + defer out.Close() + + // Write the body to file + _, err = io.Copy(out, body) + return err +} + +func (s Graph) UploadStream(relativePath string, reader io.Reader) (string, error) { + if len(s.groupId) == 0 { + return "", errGroupIdMissing + } + + token, err := s.getAuthoizationToken() + if err != nil { + return "", err + } + + requrl := fmt.Sprintf("https://graph.microsoft.com/v1.0/groups/%s/drive/root:/%s:/content", s.groupId, relativePath) + req, _ := http.NewRequest("PUT", requrl, reader) + + req.Header.Add("Authorization", token) + req.Header.Add("Content-Type", "text/plain") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + + body, _ := io.ReadAll(resp.Body) + unquot, _ := unescapeUnicodeCharactersInJSON(body) + + var drive map[string]interface{} + err = json.Unmarshal([]byte(unquot), &drive) + if err != nil { + return "", err + } + + if errnode, ok := drive["error"]; ok { + errmsg := errnode.(map[string]interface{})["message"].(string) + return "", errors.New(errmsg) + } + + return drive["webUrl"].(string), nil +} + +func (s Graph) UploadBytes(relativePath string, content []byte) (string, error) { + return s.UploadStream(relativePath, bytes.NewReader(content)) +} + +func (s Graph) GetChannelFilesFolderName(channel string) (string, error) { + if len(s.groupId) == 0 { + return "", errGroupIdMissing + } + + token, err := s.getAuthoizationToken() + if err != nil { + return "", err + } + + requrl := fmt.Sprintf("https://graph.microsoft.com/v1.0/teams/%s/channels/%s/filesFolder", s.groupId, channel) + req, _ := http.NewRequest("GET", requrl, nil) + req.Header.Add("Authorization", token) + req.Header.Add("Accept", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + body, _ := io.ReadAll(resp.Body) + + unquot, _ := unescapeUnicodeCharactersInJSON(body) + + var drive map[string]interface{} + err = json.Unmarshal([]byte(unquot), &drive) + if err != nil { + return "", err + } + + if errnode, ok := drive["error"]; ok { + errmsg := errnode.(map[string]interface{})["message"].(string) + return "", errors.New(errmsg) + } + + return drive["name"].(string), nil +} + +func (s Graph) listChildren(path string) (map[string]interface{}, error) { + if len(s.groupId) == 0 { + return nil, errGroupIdMissing + } + + token, err := s.getAuthoizationToken() + if err != nil { + return nil, err + } + + // groupId := "612620b7-cd90-4b3f-a9bd-d34cddf52517" + var requrl string + if len(path) == 0 { + requrl = fmt.Sprintf("https://graph.microsoft.com/v1.0/groups/%s/drive/root/children", s.groupId) + } else { + requrl = fmt.Sprintf("https://graph.microsoft.com/v1.0/groups/%s/drive/root:/%s:/children", s.groupId, path) + } + req, _ := http.NewRequest("GET", requrl, nil) + req.Header.Add("Authorization", token) + req.Header.Add("Accept", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + body, _ := io.ReadAll(resp.Body) + + unquot, _ := unescapeUnicodeCharactersInJSON(body) + + var drive map[string]interface{} + err = json.Unmarshal([]byte(unquot), &drive) + if err != nil { + return nil, err + } + + if errnode, ok := drive["error"]; ok { + errmsg := errnode.(map[string]interface{})["message"].(string) + return nil, errors.New(errmsg) + } + + return drive, nil +} + +func (s Graph) ListFolders(parent string) ([]string, error) { + drive, err := s.listChildren(parent) + if err != nil { + return nil, err + } + + var output []string + items := drive["value"].([]interface{}) + for _, item := range items { + itemobj := item.(map[string]interface{}) + if _, ok := itemobj["file"]; !ok { + output = append(output, itemobj["name"].(string)) + } + } + return output, nil +} + +type FileMeta struct { + Path string `json:"path"` + LastModified time.Time `json:"lastModified"` +} + +func (s Graph) ListFiles(parent string) ([]FileMeta, error) { + drive, err := s.listChildren(parent) + if err != nil { + return nil, err + } + + var output []FileMeta + items := drive["value"].([]interface{}) + for _, item := range items { + itemobj := item.(map[string]interface{}) + if _, ok := itemobj["file"]; ok { + modTime, _ := time.Parse(time.RFC3339, itemobj["lastModifiedDateTime"].(string)) + + output = append(output, FileMeta{ + Path: path.Join(parent, itemobj["name"].(string)), + LastModified: modTime, + }) + } + } + return output, nil +} + +func (s Graph) GetChannels() (map[string]string, error) { + if len(s.groupId) == 0 { + return nil, errGroupIdMissing + } + + token, err := s.getAuthoizationToken() + if err != nil { + return nil, err + } + + requrl := fmt.Sprintf("https://graph.microsoft.com/v1.0/teams/%s/channels", s.groupId) + req, _ := http.NewRequest("GET", requrl, nil) + req.Header.Add("Authorization", token) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + var raw map[string]interface{} + if err = json.Unmarshal(body, &raw); err != nil { + return nil, err + } + + if errnode, ok := raw["error"]; ok { + return nil, errors.New(errnode.(map[string]interface{})["message"].(string)) + } + + if r, ok := raw["value"]; ok { + valuenode := r.([]interface{}) + output := make(map[string]string) + for _, v := range valuenode { + data := v.(map[string]interface{}) + id := data["id"].(string) + displayname := data["displayName"].(string) + output[id] = displayname + } + return output, nil + } + + return nil, errors.New("error") +} diff --git a/azure/misc.go b/azure/misc.go new file mode 100644 index 0000000..c148582 --- /dev/null +++ b/azure/misc.go @@ -0,0 +1,261 @@ +package azure + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "flag" + "fmt" + "io" + "os" + "os/exec" + "path" + "runtime" + "sort" + "strconv" + "strings" + "syscall" + "time" + + "repositories.action2quare.com/ayo/gocommon/logger" + + "go.mongodb.org/mongo-driver/bson" +) + +func SortVersions(versions []string) []string { + sort.Slice(versions, func(i, j int) bool { + leftnum := 0 + for _, iv := range strings.Split(versions[i], ".") { + n, _ := strconv.Atoi(iv) + leftnum += leftnum<<8 + n + } + + rightnum := 0 + for _, iv := range strings.Split(versions[j], ".") { + n, _ := strconv.Atoi(iv) + rightnum += rightnum<<8 + n + } + + return leftnum < rightnum + }) + return versions +} + +const ( + HoustonStatPoolCount = int(10) +) + +type HoustonStatReport struct { + Region string `json:"region"` + FreeMemory float64 `json:"freeMemory"` + FreePercent float64 `json:"freePercent"` + CPUUsage float64 `json:"cpuUsage"` + PlayingUsers int `json:"users"` + PlayingSessions int `json:"sessions"` +} + +var prefetchptr = flag.Bool("prefetch", false, "") + +func NeedPrefetch() bool { + return *prefetchptr +} + +var ErrUpdateUnnecessary = errors.New("binary is already latest") + +func SelfUpdateUsingScript(replycache *MessageReplyCache, graph *Graph, patchroot string, force bool) (err error) { + // 1. 다운로드 + var linkfilearg string + if replycache != nil { + if replycache.ReplyWrap == nil { + hostname, _ := os.Hostname() + replycache.ReplyWrap = replycache.SourceActivity().MakeWrap().WithPrefix(hostname) + } + linkfile := replycache.Serialize() + linkfilearg = fmt.Sprintf("-updatelink=%s", linkfile) + } + + currentexe, _ := os.Executable() + currentexe = strings.ReplaceAll(currentexe, "\\", "/") + exefile := path.Base(currentexe) + if runtime.GOOS == "windows" && !strings.HasSuffix(exefile, ".exe") { + exefile += ".exe" + } + + newbinary := fmt.Sprintf("%s.%s", exefile, time.Now().Format("2006.01.02-15.04.05")) + downloadurl := path.Join(patchroot, exefile) + + err = graph.DownloadFile(downloadurl, newbinary) + if err != nil { + return + } + + _, err = os.Stat(newbinary) + if os.IsNotExist(err) { + err = errors.New("다운로드 실패 : " + newbinary) + return + } + + if !force { + currentfile, e := os.Open(currentexe) + if e != nil { + return e + } + defer currentfile.Close() + + hash1 := md5.New() + _, err = io.Copy(hash1, currentfile) + if err != nil { + return + } + currentHash := hex.EncodeToString(hash1.Sum(nil)) + + nextfile, e := os.Open(newbinary) + if e != nil { + return e + } + defer nextfile.Close() + hash2 := md5.New() + _, err = io.Copy(hash2, nextfile) + if err != nil { + return + } + nextHash := hex.EncodeToString(hash2.Sum(nil)) + + if currentHash == nextHash { + // 해시가 같으니까 업데이트 할 필요가 없다. + return ErrUpdateUnnecessary + } + } + + var scriptfileName string + if runtime.GOOS == "linux" { + scriptfileName = "selfupdate.sh" + } else if runtime.GOOS == "windows" { + scriptfileName = "selfupdate.bat" + } + + scripturl := path.Join(patchroot, scriptfileName) + if _, err = os.Stat(scriptfileName); os.IsExist(err) { + err = os.Remove(scriptfileName) + } else { + err = nil + } + + if err != nil { + return + } + + // 2. 스크립트 다운로드 + err = graph.DownloadFile(scripturl, scriptfileName) + if err != nil { + return + } + + var nextArgs []string + for _, arg := range os.Args[1:] { + // -updatelink : selfupdate시에 + if !strings.HasPrefix(arg, "-updatelink=") && arg != "-prefetch" { + nextArgs = append(nextArgs, arg) + } + } + pid := strconv.Itoa(os.Getpid()) + args := append([]string{ + pid, + newbinary, + exefile, + }, nextArgs...) + + if len(linkfilearg) > 0 { + args = append(args, linkfilearg) + } + + // 3. 독립 실행 + if runtime.GOOS == "linux" { + // 실행 가능한 권한 부여 + err = os.Chmod(scriptfileName, 0777) + if err != nil { + return + } + // 실행 + env := os.Environ() + currentpath := path.Dir(currentexe) + argv0 := path.Join(currentpath, scriptfileName) + args = append([]string{"/bin/bash", argv0}, args...) + err = syscall.Exec(args[0], args, env) + if err != nil { + return + } + } else if runtime.GOOS == "windows" { + windowsargs := append([]string{ + "/C", + "start", + scriptfileName, + }, args...) + + cmd := exec.Command("cmd.exe", windowsargs...) + err = cmd.Run() + if err != nil { + return + } + } + + return +} + +var linkupdate = flag.String("updatelink", "", "") + +func ReplyUpdateComplete() { + defer func() { + r := recover() + if r != nil { + logger.Error(r) + } + }() + + if len(*linkupdate) > 0 { + cache := DeserializeMessageReplyCache(*linkupdate) + if cache != nil { + os.Remove(*linkupdate) + if cache.ReplyWrap != nil { + cache.ReplyWrap.Update(cache.Replyaid, "업데이트 완료") + } + } + } +} + +// var objectIDCounter = readRandomUint32() +// var processUnique = processUniqueBytes() + +// func processUniqueBytes() [5]byte { +// var b [5]byte +// _, err := io.ReadFull(rand.Reader, b[:]) +// if err != nil { +// panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err)) +// } + +// return b +// } + +// func readRandomUint32() uint32 { +// var b [4]byte +// _, err := io.ReadFull(rand.Reader, b[:]) +// if err != nil { +// panic(fmt.Errorf("cannot initialize objectid package with crypto.rand.Reader: %v", err)) +// } + +// return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) +// } + +type BsonMarshaler[T any] struct { + val T +} + +func NewBsonMarshaler[T any](val T) *BsonMarshaler[T] { + return &BsonMarshaler[T]{ + val: val, + } +} + +func (m *BsonMarshaler[T]) MarshalBinary() (data []byte, err error) { + return bson.Marshal(m.val) +} diff --git a/azure/monitor.go b/azure/monitor.go new file mode 100644 index 0000000..a793de2 --- /dev/null +++ b/azure/monitor.go @@ -0,0 +1,202 @@ +package azure + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "net/url" +) + +type CounterCurcuit struct { + curcuit []float64 + cursor int +} + +func (cc *CounterCurcuit) Append(val float64) { + cc.curcuit[cc.cursor] = val + cc.cursor++ +} + +func (cc *CounterCurcuit) Stat() (min float64, max float64, sum float64, count int) { + if cc.cursor == 0 { + return 0, 0, 0, 0 + } + + min = math.MaxFloat64 + max = -math.MaxFloat64 + sum = float64(0) + + if cc.cursor < len(cc.curcuit) { + count = cc.cursor + for i := 0; i < cc.cursor; i++ { + if min > cc.curcuit[i] { + min = cc.curcuit[i] + } + if max < cc.curcuit[i] { + max = cc.curcuit[i] + } + + sum += cc.curcuit[i] + } + return + } + + count = len(cc.curcuit) + for _, v := range cc.curcuit { + if min > v { + min = v + } + if max < v { + max = v + } + + sum += v + } + + return +} + +func MakeCounterCurcuit(length int) *CounterCurcuit { + return &CounterCurcuit{ + curcuit: make([]float64, length), + cursor: 0, + } +} + +type Float64Counter struct { + Min float64 `json:"min"` + Max float64 `json:"max"` + Sum float64 `json:"sum"` + Count int `json:"count"` +} + +var DefaultFloat64Counter = Float64Counter{ + Min: math.MaxFloat64, + Max: -math.MaxFloat64, + Sum: 0, + Count: 0, +} + +func (c *Float64Counter) Valid() bool { + return c.Count > 0 +} + +func (c *Float64Counter) Add(val float64) { + if c.Max < val { + c.Max = val + } + + if c.Min > val { + c.Min = val + } + + c.Sum += val + c.Count++ +} + +func (c *Float64Counter) SingleMin() Float64Counter { + return Float64Counter{ + Min: c.Min, + Max: c.Min, + Sum: c.Min, + Count: 1, + } +} + +func (c *Float64Counter) SingleMax() Float64Counter { + return Float64Counter{ + Min: c.Max, + Max: c.Max, + Sum: c.Max, + Count: 1, + } +} + +func (c *Float64Counter) AddInt(val int) { + fval := float64(val) + if c.Max < fval { + c.Max = fval + } + + if c.Min > fval { + c.Min = fval + } + + c.Sum += fval + c.Count++ +} + +type MetricSeries struct { + Float64Counter + DimValues []string `json:"dimValues"` +} + +type MetructBaseData struct { + Metric string `json:"metric"` + Namespace string `json:"namespace"` + DimNames []string `json:"dimNames"` + Series []MetricSeries `json:"series"` +} + +type MetricData struct { + BaseData MetructBaseData `json:"baseData"` +} + +type MetricDocument struct { + Time string `json:"time"` + Data MetricData `json:"data"` +} + +func MakeMetricDocument(timestamp string, namespace string, metric string, dimNames []string) MetricDocument { + output := MetricDocument{Time: timestamp} + output.Data.BaseData.DimNames = dimNames + output.Data.BaseData.Metric = metric + output.Data.BaseData.Namespace = namespace + return output +} + +type Monitor struct { + *accessToken + tenantId string +} + +func NewMonitor(tenantId string) *Monitor { + return &Monitor{ + accessToken: &accessToken{ + url: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/token", tenantId), + values: url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {microsoftAppId()}, + "resource": {"https://monitoring.azure.com/"}, + "client_secret": {microsoftAppPassword()}, + }, + }, + tenantId: tenantId, + } +} + +func (m Monitor) PostMetrics(url string, docs ...MetricDocument) error { + // https://docs.microsoft.com/ko-kr/rest/api/monitor/metrics%20(data%20plane)/create + token, err := m.getAuthoizationToken() + if err != nil { + return err + } + + for _, doc := range docs { + bt, _ := json.Marshal(doc) + req, _ := http.NewRequest("POST", url, bytes.NewReader(bt)) + req.Header.Add("Authorization", token) + req.Header.Add("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + } + return nil +} diff --git a/document/document.go b/document/document.go new file mode 100644 index 0000000..f7ebd56 --- /dev/null +++ b/document/document.go @@ -0,0 +1,22 @@ +package document + +import "errors" + +// Document interface +type Document interface { + ReadBool(path string) (bool, error) + ReadString(path string) (string, error) + Read(path string) (interface{}, error) + + Write(path string, val interface{}) + Serialize() error +} + +// ErrDocumentNotExist : 문서가 존재하지 않을 때 +var ErrDocumentNotExist = errors.New("document does not exist") + +// ErrDocumentPathNotExist : 문서의 경로가 존재하지 않을 때 +var ErrDocumentPathNotExist = errors.New("given document path does not exist") + +// ErrDocumentPathTypeMismatch : 해당 경로의 값 타입이 다름 +var ErrDocumentPathTypeMismatch = errors.New("given document path has different value of type") diff --git a/document/document_filebase.go b/document/document_filebase.go new file mode 100644 index 0000000..89cb9b7 --- /dev/null +++ b/document/document_filebase.go @@ -0,0 +1,179 @@ +package document + +import ( + "encoding/json" + "fmt" + "os" + "path" + "strings" +) + +func getDocumentFilepath(owner string) string { + return fmt.Sprintf("./docs/%s.json", owner) +} + +type fileDocument struct { + owner string + root map[string]interface{} + dirty bool +} + +// LoadFileDocument : 파일로 도큐먼트 로딩 +func LoadFileDocument(owner string) (Document, error) { + bt, err := os.ReadFile(getDocumentFilepath(owner)) + if err != nil && !os.IsNotExist(err) { + return nil, err + } + + if len(bt) == 0 { + return &fileDocument{ + owner: owner, + root: make(map[string]interface{}), + dirty: false, + }, nil + } + + var doc map[string]interface{} + err = json.Unmarshal(bt, &doc) + if err != nil { + return nil, err + } + + return &fileDocument{ + owner: owner, + root: doc, + dirty: false, + }, nil +} + +func findNodeInterface(doc map[string]interface{}, path string) (interface{}, error) { + if doc == nil { + return nil, ErrDocumentNotExist + } + + parent := doc + for { + idx := strings.IndexRune(path, '/') + var nodename string + if idx < 0 { + nodename = path + path = "" + } else { + nodename = path[:idx] + path = path[idx+1:] + } + child, ok := parent[nodename] + + if !ok { + return nil, ErrDocumentPathNotExist + } + + if len(path) == 0 { + return child, nil + } + + if parent, ok = child.(map[string]interface{}); !ok { + return nil, ErrDocumentPathTypeMismatch + } + } +} + +func findEdgeContainer(doc map[string]interface{}, path string) (map[string]interface{}, string) { + parent := doc + for { + idx := strings.IndexRune(path, '/') + var nodename string + if idx < 0 { + nodename = path + path = "" + } else { + nodename = path[:idx] + path = path[idx+1:] + } + + if len(path) == 0 { + return parent, nodename + } + + child, ok := parent[nodename] + if !ok { + child = make(map[string]interface{}) + parent[nodename] = child + } + + parent = child.(map[string]interface{}) + } +} + +// ReadBool : +func (doc *fileDocument) ReadBool(path string) (bool, error) { + val, err := findNodeInterface(doc.root, path) + if err != nil { + return false, err + } + + out, ok := val.(bool) + if !ok { + return false, ErrDocumentPathTypeMismatch + } + + return out, nil +} + +// ReadString : +func (doc *fileDocument) ReadString(path string) (string, error) { + val, err := findNodeInterface(doc.root, path) + if err != nil { + return "", err + } + + out, ok := val.(string) + if !ok { + return "", ErrDocumentPathTypeMismatch + } + + return out, nil +} + +// Read : +func (doc *fileDocument) Read(path string) (interface{}, error) { + return findNodeInterface(doc.root, path) +} + +// WriteBool : +func (doc *fileDocument) Write(path string, val interface{}) { + container, edge := findEdgeContainer(doc.root, path) + container[edge] = val + doc.dirty = true +} + +// Serialize : +func (doc *fileDocument) Serialize() error { + if !doc.dirty { + return nil + } + + bt, err := json.Marshal(doc.root) + if err != nil { + return err + } + + filepath := getDocumentFilepath(doc.owner) + if err := os.WriteFile(filepath, bt, 0644); err != nil { + if _, patherr := err.(*os.PathError); !patherr { + return err + } + + dir := path.Dir(filepath) + if err = os.MkdirAll(dir, 0755); err != nil { + return err + } + + if err = os.WriteFile(filepath, bt, 0644); err != nil { + return err + } + } + doc.dirty = false + + return nil +} diff --git a/flag/flag_helper.go b/flag/flag_helper.go new file mode 100644 index 0000000..e36962c --- /dev/null +++ b/flag/flag_helper.go @@ -0,0 +1,56 @@ +package flag + +import ( + "flag" + "fmt" + "os" + "strings" +) + +var flaged = make(map[string]interface{}) + +func IntNoCase(name string, value int, usage string) *int { + test := fmt.Sprintf("--%s=", name) + for _, arg := range os.Args { + if strings.HasPrefix(strings.ToLower(arg), test) || strings.HasPrefix(strings.ToLower(arg), test[1:]) { + kv := strings.Split(arg, "=") + realname := strings.TrimLeft(kv[0], "-") + oldptr, ok := flaged[realname] + + if !ok { + ptr := flag.Int(realname, value, usage) + flaged[realname] = ptr + return ptr + } + + return oldptr.(*int) + } + } + + defval := new(int) + *defval = value + return defval +} + +func StringNoCase(name string, value string, usage string) *string { + test := fmt.Sprintf("--%s=", name) + for _, arg := range os.Args { + if strings.HasPrefix(strings.ToLower(arg), test) || strings.HasPrefix(strings.ToLower(arg), test[1:]) { + kv := strings.Split(arg, "=") + realname := strings.TrimLeft(kv[0], "-") + oldptr, ok := flaged[realname] + if !ok { + ptr := flag.String(realname, value, usage) + flaged[realname] = ptr + return ptr + } + + return oldptr.(*string) + } + } + + defval := new(string) + *defval = value + + return defval +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..fe1d182 --- /dev/null +++ b/go.mod @@ -0,0 +1,28 @@ +module repositories.action2quare.com/ayo/gocommon + +go 1.19 + +replace repositories.action2quare.com/ayo/gocommon => ./ + +require ( + github.com/go-redis/redis/v8 v8.11.5 + github.com/golang-jwt/jwt v3.2.2+incompatible + github.com/pires/go-proxyproto v0.7.0 + go.mongodb.org/mongo-driver v1.11.6 + golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d +) + +require ( + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/golang/snappy v0.0.1 // indirect + github.com/klauspost/compress v1.13.6 // indirect + github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.1 // indirect + github.com/xdg-go/stringprep v1.0.3 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect + golang.org/x/text v0.3.7 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c8217c5 --- /dev/null +++ b/go.sum @@ -0,0 +1,73 @@ +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2 h1:X2ev0eStA3AbceY54o37/0PQ/UWqKEiiO2dKL5OPaFM= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs= +github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.1 h1:VOMT+81stJgXW3CpHyqHN3AXDYIMsx56mEFrB37Mb/E= +github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= +github.com/xdg-go/stringprep v1.0.3 h1:kdwGpVNwPFtjs98xCGkHjQtGKh86rDcRZN17QEMCOIs= +github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.11.6 h1:XM7G6PjiGAO5betLF13BIa5TlLUUE3uJ/2Ox3Lz1K+o= +go.mongodb.org/mongo-driver v1.11.6/go.mod h1:G9TgswdsWjX4tmDA5zfs2+6AEPpYJwqblyjsfuh8oXY= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d h1:sK3txAijHtOK88l68nt020reeT1ZdKLIYetKl95FzVY= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2 h1:CIJ76btIcR3eFI5EgSo6k1qKw9KJexJuRLI9G7Hp5wE= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/locker_redis.go b/locker_redis.go new file mode 100644 index 0000000..4613791 --- /dev/null +++ b/locker_redis.go @@ -0,0 +1,43 @@ +package common + +import ( + "context" + "errors" + "time" + + "repositories.action2quare.com/ayo/gocommon/logger" + + "github.com/go-redis/redis/v8" +) + +var txSetArgs redis.SetArgs = redis.SetArgs{ + TTL: time.Second * 30, + Mode: "NX", +} + +type LockerWithRedis struct { + key string +} + +var ErrTransactionLocked = errors.New("transaction is already locked") +var ErrTransactionHSet = errors.New("transaction set is failed by unkwoun reason") + +func (locker *LockerWithRedis) Lock(rc *redis.Client, key string) error { + resultStr, err := rc.SetArgs(context.Background(), key, true, txSetArgs).Result() + if err != nil && err != redis.Nil { + return err + } + + if len(resultStr) <= 0 { + // 이미 있네? 락이 걸려있다 + logger.Println(ErrTransactionLocked, key) + return ErrTransactionLocked + } + + locker.key = key + return nil +} + +func (locker *LockerWithRedis) Unlock(rc *redis.Client) { + rc.Del(context.Background(), locker.key).Result() +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 0000000..216e299 --- /dev/null +++ b/logger/logger.go @@ -0,0 +1,106 @@ +package logger + +import ( + "fmt" + "io" + "log" + "os" + "path" + "runtime/debug" + "strings" + "time" +) + +var stdlogger *log.Logger +var errlogger *log.Logger + +func init() { + binpath, _ := os.Executable() + binname := path.Base(strings.ReplaceAll(binpath, "\\", "/")) + + logpath := os.Getenv("AYO_LOGGER_FILE_PATH") + + if len(logpath) == 0 { + stdlogger = log.New(os.Stdout, "", log.LstdFlags) + errlogger = log.New(os.Stderr, "", log.LstdFlags) + } else { + ext := path.Ext(binname) + if len(ext) > 0 { + binname = binname[:len(binname)-len(ext)] + } + + if _, err := os.Stat(logpath); os.IsNotExist(err) { + os.Mkdir("logs", 0777) + } + + now := time.Now() + logFile, err := os.Create(fmt.Sprintf("%s/%s_%s.log", logpath, binname, now.Format("2006-01-02T15-04-05"))) + if err != nil { + panic(err) + } + stdlogger = log.New(io.MultiWriter(os.Stdout, logFile), "", log.LstdFlags) + if fi, err := os.Stat(binpath); err == nil { + stdlogger.Println(fi.Name(), fi.ModTime()) + } + + logFile, err = os.Create(fmt.Sprintf("%s/%s_%s.err", logpath, binname, now.Format("2006-01-02T15-04-05"))) + if err != nil { + panic(err) + } + errlogger = log.New(io.MultiWriter(os.Stderr, logFile), "", log.LstdFlags) + if fi, err := os.Stat(binpath); err == nil { + errlogger.Println(fi.Name(), fi.ModTime()) + } + } + stdlogger.Println(binname) + errlogger.Println(binname) +} + +func Println(v ...interface{}) { + stdlogger.Output(2, fmt.Sprintln(v...)) +} + +func Printf(format string, v ...interface{}) { + stdlogger.Output(2, fmt.Sprintf(format, v...)) +} + +func Error(v ...interface{}) { + errlogger.Output(2, fmt.Sprintln(v...)) + errlogger.Output(2, string(debug.Stack())) +} + +func Errorf(format string, v ...interface{}) { + errlogger.Output(2, fmt.Sprintf(format, v...)) + errlogger.Output(2, string(debug.Stack())) +} + +func Fatal(v ...interface{}) { + errlogger.Output(2, fmt.Sprint(v...)) + errlogger.Output(2, string(debug.Stack())) + os.Exit(1) +} +func Fatalln(v ...interface{}) { + errlogger.Output(2, fmt.Sprintln(v...)) + errlogger.Output(2, string(debug.Stack())) + os.Exit(1) +} +func Panic(v ...interface{}) { + s := fmt.Sprint(v...) + errlogger.Output(2, s) + errlogger.Output(2, string(debug.Stack())) + panic(s) +} + +func Panicf(format string, v ...interface{}) { + s := fmt.Sprintf(format, v...) + errlogger.Output(2, s) + errlogger.Output(2, string(debug.Stack())) + panic(s) +} + +func Panicln(v ...interface{}) { + s := fmt.Sprintln(v...) + errlogger.Output(2, s) + errlogger.Output(2, string(debug.Stack())) + panic(s) +} diff --git a/misc.go b/misc.go new file mode 100644 index 0000000..cd76c0e --- /dev/null +++ b/misc.go @@ -0,0 +1,104 @@ +package common + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + "math/rand" + "reflect" + "strconv" + "strings" + "sync/atomic" + "time" +) + +var sequenceStart = rand.Uint32() + +func MakeHttpHandlerPattern(n ...string) string { + r := strings.ReplaceAll(strings.Join(n, "/"), "//", "/") + if len(r) == 0 { + return "/" + } + + if len(r) > 0 && r[0] != '/' { + r = "/" + r + } + + return r +} + +func MakeLocalUniqueId() string { + var b [10]byte + + now := time.Now() + binary.BigEndian.PutUint32(b[6:10], uint32(now.Unix())) + binary.BigEndian.PutUint32(b[3:7], uint32(now.Nanosecond())) + binary.BigEndian.PutUint32(b[0:4], atomic.AddUint32(&sequenceStart, 1)) + + u := binary.LittleEndian.Uint64(b[2:]) + a := strconv.FormatUint(u, 36) + return a[1:] +} + +func SerializeInterface(w io.Writer, val interface{}) (err error) { + if val == nil { + return nil + } + + value := reflect.ValueOf(val) + encoder := json.NewEncoder(w) + + switch value.Kind() { + case reflect.String: + _, err = w.Write([]byte(value.String())) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + _, err = w.Write([]byte(fmt.Sprintf("%d", value.Int()))) + + case reflect.Float32, reflect.Float64: + _, err = w.Write([]byte(fmt.Sprintf("%f", value.Float()))) + + case reflect.Bool: + if value.Bool() { + _, err = w.Write([]byte("true")) + } else { + _, err = w.Write([]byte("false")) + } + + case reflect.Slice: + switch value.Type().Elem().Kind() { + case reflect.Uint8: + _, err = w.Write(value.Bytes()) + + default: + var conv []interface{} + for i := 0; i < value.Len(); i++ { + conv = append(conv, value.Index(i).Interface()) + } + if len(conv) == 0 { + _, err = w.Write([]byte("[]")) + } else { + err = encoder.Encode(conv) + } + } + + case reflect.Interface, reflect.Struct, reflect.Map: + if value.Kind() == reflect.Struct { + err = encoder.Encode(value.Interface()) + } else if !value.IsNil() { + err = encoder.Encode(value.Interface()) + } + + case reflect.Ptr: + if !value.IsNil() { + if wro, ok := value.Interface().(io.WriterTo); ok { + _, err = wro.WriteTo(w) + } else { + err = json.NewEncoder(w).Encode(value.Interface()) + } + } + } + + return +} diff --git a/mongo.go b/mongo.go new file mode 100644 index 0000000..e7d1510 --- /dev/null +++ b/mongo.go @@ -0,0 +1,485 @@ +package common + +import ( + "context" + "encoding/json" + "errors" + "flag" + "os" + "time" + + "repositories.action2quare.com/ayo/gocommon/logger" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readpref" +) + +type MongoClient struct { + db *mongo.Database + c *mongo.Client +} + +type ConnectionInfo struct { + Url string + Database string +} + +var Devflag = flag.Bool("dev", false, "") + +type CollectionName string + +func ParseObjectID(hexstr string) (out primitive.ObjectID) { + out, _ = primitive.ObjectIDFromHex(hexstr) + return +} + +func NewMongoConnectionInfo(url string, dbname string) *ConnectionInfo { + if len(dbname) == 0 { + panic("dbname is empty") + } + + if *Devflag { + hostname, _ := os.Hostname() + dbname = hostname + "-" + dbname + } + + return &ConnectionInfo{ + Url: url, + Database: dbname, + } +} + +func (ci *ConnectionInfo) SetURL(url string) *ConnectionInfo { + ci.Url = url + return ci +} + +func (ci *ConnectionInfo) SetDatabase(dbname string) *ConnectionInfo { + ci.Database = dbname + return ci +} + +func NewMongoClient(ctx context.Context, url string, dbname string) (MongoClient, error) { + return newMongoClient(ctx, NewMongoConnectionInfo(url, dbname)) +} + +func newMongoClient(ctx context.Context, ci *ConnectionInfo) (MongoClient, error) { + if len(ci.Url) == 0 { + return MongoClient{}, errors.New("mongo connection string is empty") + } + + secondaryPref := readpref.SecondaryPreferred() + //client, err := mongo.NewClient(options.Client().ApplyURI(ci.Url).SetReadPreference(secondaryPref)) + client, err := mongo.NewClient(options.Client().ApplyURI(ci.Url).SetReadPreference(secondaryPref).SetServerMonitor(&event.ServerMonitor{ + ServerOpening: func(evt *event.ServerOpeningEvent) { + logger.Println("mongodb ServerOpening :", *evt) + }, + ServerClosed: func(evt *event.ServerClosedEvent) { + logger.Println("mongodb ServerClosed :", *evt) + }, + TopologyOpening: func(evt *event.TopologyOpeningEvent) { + logger.Println("mongodb TopologyOpening :", *evt) + }, + TopologyClosed: func(evt *event.TopologyClosedEvent) { + logger.Println("mongodb TopologyClosed :", *evt) + }, + })) + if err != nil { + return MongoClient{}, err + } + + err = client.Connect(ctx) + if err != nil { + return MongoClient{}, err + } + + go func() { + for { + if err := client.Ping(ctx, nil); err != nil { + logger.Error("mongo client ping err :", err) + } + + select { + case <-time.After(10 * time.Second): + continue + + case <-ctx.Done(): + return + } + } + }() + + mdb := client.Database(ci.Database, nil) + return MongoClient{c: client, db: mdb}, nil +} + +func (mc MongoClient) Connected() bool { + return mc.db != nil && mc.c != nil +} + +func (mc MongoClient) Close() { + if mc.c != nil { + mc.c.Disconnect(context.Background()) + } +} + +func (mc MongoClient) Watch(coll CollectionName, pipeline mongo.Pipeline, opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + if len(opts) == 0 { + opts = []*options.ChangeStreamOptions{options.ChangeStream().SetFullDocument(options.UpdateLookup).SetMaxAwaitTime(0)} + } + return mc.Collection(coll).Watch(context.Background(), pipeline, opts...) +} + +func (mc MongoClient) Collection(collname CollectionName) *mongo.Collection { + return mc.db.Collection(string(collname)) +} + +func (mc MongoClient) All(coll CollectionName, opts ...*options.FindOptions) ([]bson.M, error) { + cursor, err := mc.Collection(coll).Find(context.Background(), bson.D{}, opts...) + if err != nil { + return nil, err + } + defer cursor.Close(context.Background()) + + var all []bson.M + err = cursor.All(context.Background(), &all) + if err != nil { + return nil, err + } + + return all, nil +} + +func (mc MongoClient) FindOneAndDelete(coll CollectionName, filter bson.M, opts ...*options.FindOneAndDeleteOptions) (bson.M, error) { + result := mc.Collection(coll).FindOneAndDelete(context.Background(), filter, opts...) + err := result.Err() + if err != nil { + if err == mongo.ErrNoDocuments { + return nil, nil + } + return nil, err + } + + tmp := make(map[string]interface{}) + err = result.Decode(&tmp) + if err != nil { + return nil, err + } + + return bson.M(tmp), nil +} + +func (mc MongoClient) Delete(coll CollectionName, filter bson.M, opts ...*options.DeleteOptions) (bool, error) { + r, err := mc.Collection(coll).DeleteOne(context.Background(), filter, opts...) + if err != nil { + return false, err + } + + return r.DeletedCount > 0, nil +} + +func (mc MongoClient) UnsetField(coll CollectionName, filter bson.M, doc bson.M) error { + _, err := mc.Collection(coll).UpdateOne(context.Background(), filter, bson.M{ + "$unset": doc, + }) + return err +} + +func (mc MongoClient) DeleteMany(coll CollectionName, filters bson.D, opts ...*options.DeleteOptions) (int, error) { + if len(filters) == 0 { + // 큰일난다 + return 0, nil + } + + result, err := mc.Collection(coll).DeleteMany(context.Background(), filters, opts...) + if err != nil { + return 0, err + } + + return int(result.DeletedCount), nil +} + +type CommandInsertMany[T any] struct { + MongoClient + Collection CollectionName + Documents []T +} + +func (c *CommandInsertMany[T]) Exec(opts ...*options.InsertManyOptions) (int, error) { + conv := make([]any, len(c.Documents)) + for i, v := range c.Documents { + conv[i] = v + } + return c.InsertMany(c.Collection, conv, opts...) +} + +func (mc MongoClient) InsertMany(coll CollectionName, documents []interface{}, opts ...*options.InsertManyOptions) (int, error) { + result, err := mc.Collection(coll).InsertMany(context.Background(), documents, opts...) + if err != nil { + return 0, err + } + + return len(result.InsertedIDs), nil +} + +func (mc MongoClient) UpdateMany(coll CollectionName, filter bson.M, doc bson.M, opts ...*options.UpdateOptions) (count int, err error) { + result, e := mc.Collection(coll).UpdateMany(context.Background(), filter, doc, opts...) + + if e != nil { + return 0, e + } + + err = nil + count = int(result.UpsertedCount + result.ModifiedCount) + return +} + +type Marshaler interface { + MarshalBSON() ([]byte, error) +} + +type JsonDefaultMashaller struct { + doc *bson.M +} + +func (m *JsonDefaultMashaller) MarshalBSON() ([]byte, error) { + return json.Marshal(m.doc) +} + +func (mc MongoClient) Update(coll CollectionName, filter bson.M, doc interface{}, opts ...*options.UpdateOptions) (worked bool, newid interface{}, err error) { + result, e := mc.Collection(coll).UpdateOne(context.Background(), filter, doc, opts...) + + if e != nil { + return false, "", e + } + + err = nil + worked = result.MatchedCount > 0 || result.UpsertedCount > 0 || result.ModifiedCount > 0 + newid = result.UpsertedID + return +} + +func (mc MongoClient) UpsertOne(coll CollectionName, filter bson.M, doc interface{}) (worked bool, newid interface{}, err error) { + return mc.Update(coll, filter, bson.M{ + "$set": doc, + }, options.Update().SetUpsert(true)) + + // return mc.Update(coll, filter, &JsonDefaultMashaller{doc: &bson.M{ + // "$set": doc, + // }}, options.Update().SetUpsert(true)) +} + +func (mc MongoClient) FindOneAs(coll CollectionName, filter bson.M, out interface{}, opts ...*options.FindOneOptions) error { + err := mc.Collection(coll).FindOne(context.Background(), filter, opts...).Decode(out) + if err == mongo.ErrNoDocuments { + err = nil + } + return err +} + +func (mc MongoClient) FindOne(coll CollectionName, filter bson.M, opts ...*options.FindOneOptions) (doc bson.M, err error) { + result := mc.Collection(coll).FindOne(context.Background(), filter, opts...) + tmp := make(map[string]interface{}) + err = result.Decode(&tmp) + if err == nil { + doc = bson.M(tmp) + } else if err == mongo.ErrNoDocuments { + err = nil + } + + return +} + +func (mc MongoClient) FindOneAndUpdateAs(coll CollectionName, filter bson.M, doc bson.M, out interface{}, opts ...*options.FindOneAndUpdateOptions) error { + result := mc.Collection(coll).FindOneAndUpdate(context.Background(), filter, doc, opts...) + err := result.Decode(out) + if err == nil { + return nil + } + + if err == mongo.ErrNoDocuments { + return nil + } + + return err +} + +func (mc MongoClient) FindOneAndUpdate(coll CollectionName, filter bson.M, doc bson.M, opts ...*options.FindOneAndUpdateOptions) (olddoc bson.M, err error) { + result := mc.Collection(coll).FindOneAndUpdate(context.Background(), filter, doc, opts...) + tmp := make(map[string]interface{}) + err = result.Decode(&tmp) + if err == nil { + olddoc = bson.M(tmp) + } else if err == mongo.ErrNoDocuments { + err = nil + } + + return +} + +func (mc MongoClient) Exists(coll CollectionName, filter bson.M) (bool, error) { + cnt, err := mc.Collection(coll).CountDocuments(context.Background(), filter, options.Count().SetLimit(1)) + if err != nil { + return false, err + } + return cnt > 0, nil +} + +func (mc MongoClient) SearchText(coll CollectionName, text string, opts ...*options.FindOptions) ([]bson.M, error) { + cursor, err := mc.Collection(coll).Find(context.Background(), bson.M{"$text": bson.M{"$search": text}}, opts...) + if err != nil { + return nil, err + } + defer cursor.Close(context.Background()) + + var output []bson.M + err = cursor.All(context.Background(), &output) + if err != nil { + return nil, err + } + + return output, nil +} + +func (mc MongoClient) FindAll(coll CollectionName, filter bson.M, opts ...*options.FindOptions) ([]bson.M, error) { + cursor, err := mc.Collection(coll).Find(context.Background(), filter, opts...) + if err != nil { + return nil, err + } + defer cursor.Close(context.Background()) + + var output []bson.M + err = cursor.All(context.Background(), &output) + if err != nil { + return nil, err + } + + return output, nil +} + +func (mc MongoClient) FindAllAs(coll CollectionName, filter bson.M, output interface{}, opts ...*options.FindOptions) error { + cursor, err := mc.Collection(coll).Find(context.Background(), filter, opts...) + if err != nil { + return err + } + defer cursor.Close(context.Background()) + + err = cursor.All(context.Background(), output) + if err != nil { + return err + } + return nil +} + +func (mc MongoClient) MakeExpireIndex(coll CollectionName, expireSeconds int32) error { + matchcoll := mc.Collection(coll) + indices, err := matchcoll.Indexes().List(context.Background(), options.ListIndexes().SetMaxTime(time.Second)) + if err != nil { + return err + } + + allindices := make([]interface{}, 0) + err = indices.All(context.Background(), &allindices) + if err != nil { + return err + } + + tsfound := false + var tsname string + var exp int32 + +IndexSearchLabel: + for _, index := range allindices { + d := index.(bson.D) + key := d.Map()["key"].(bson.D) + for _, kd := range key { + if kd.Key == "_ts" { + tsfound = true + + if v, ok := d.Map()["name"]; ok { + tsname = v.(string) + } + if v, ok := d.Map()["expireAfterSeconds"]; ok { + exp = v.(int32) + } + break IndexSearchLabel + } + } + } + + if tsfound { + if exp == expireSeconds { + return nil + } + _, err = matchcoll.Indexes().DropOne(context.Background(), tsname) + if err != nil { + return err + } + } + + mod := mongo.IndexModel{ + Keys: primitive.M{"_ts": 1}, + Options: options.Index().SetExpireAfterSeconds(expireSeconds), + } + + _, err = matchcoll.Indexes().CreateOne(context.Background(), mod) + return err +} + +func (mc MongoClient) makeIndicesWithOption(coll CollectionName, indices map[string]bson.D, option *options.IndexOptions) error { + collection := mc.Collection(coll) + cursor, err := collection.Indexes().List(context.Background(), options.ListIndexes().SetMaxTime(time.Second)) + if err != nil { + return err + } + defer cursor.Close(context.Background()) + + found := make(map[string]bool) + for k := range indices { + found[k] = false + } + + for cursor.TryNext(context.Background()) { + rawval := cursor.Current + name := rawval.Lookup("name").StringValue() + if _, ok := indices[name]; ok { + found[name] = true + } + } + + for name, exist := range found { + if !exist { + v := indices[name] + var mod mongo.IndexModel + if len(v) == 1 { + mod = mongo.IndexModel{ + Keys: primitive.M{v[0].Key: v[0].Value}, + Options: options.MergeIndexOptions(options.Index().SetName(name), option), + } + } else { + mod = mongo.IndexModel{ + Keys: indices[name], + Options: options.MergeIndexOptions(options.Index().SetName(name), option), + } + } + + _, err = collection.Indexes().CreateOne(context.Background(), mod) + if err != nil { + return err + } + } + } + return nil +} + +func (mc MongoClient) MakeUniqueIndices(coll CollectionName, indices map[string]bson.D) error { + return mc.makeIndicesWithOption(coll, indices, options.Index().SetUnique(true)) +} + +func (mc MongoClient) MakeIndices(coll CollectionName, indices map[string]bson.D) error { + return mc.makeIndicesWithOption(coll, indices, options.Index()) +} diff --git a/redis.go b/redis.go new file mode 100644 index 0000000..b4974c4 --- /dev/null +++ b/redis.go @@ -0,0 +1,70 @@ +package common + +import ( + "context" + "net/url" + "os" + "strconv" + + "github.com/go-redis/redis/v8" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +type Authinfo struct { + Accid primitive.ObjectID `bson:"_id,omitempty" json:"_id,omitempty"` + ServiceCode string `bson:"code" json:"code"` + Platform string + Uid string + Token string + Sk primitive.ObjectID + RefreshToken string `bson:"refresh_token,omitempty" json:"refresh_token,omitempty"` + Expired primitive.DateTime `bson:"_ts" json:"_ts"` +} + +func newRedisClient(uri string, dbidxoffset int) *redis.Client { + option, err := redis.ParseURL(uri) + if err != nil { + return nil + } + option.DB += dbidxoffset + + return redis.NewClient(option) +} + +func NewRedisClient(uri string, dbidx int) (*redis.Client, error) { + if !*Devflag { + return newRedisClient(uri, dbidx), nil + } + + rdb := newRedisClient(uri, 0) + devUrl, _ := url.Parse(uri) + hostname, _ := os.Hostname() + myidx, _ := rdb.HGet(context.Background(), "private_db", hostname).Result() + if len(myidx) > 0 { + devUrl.Path = "/" + myidx + return newRedisClient(devUrl.String(), dbidx), nil + } + + alldbs, err := rdb.HGetAll(context.Background(), "private_db").Result() + if err != nil { + rdb.Close() + return nil, err + } + + maxidx := 0 + for _, prvdb := range alldbs { + actualidx, _ := strconv.Atoi(prvdb) + if maxidx < actualidx { + maxidx = actualidx + } + } + + newidx := maxidx + 1 + _, err = rdb.HSet(context.Background(), "private_db", hostname, newidx).Result() + if err != nil { + return nil, err + } + + devUrl.Path = "/" + strconv.Itoa(newidx) + return newRedisClient(devUrl.String(), dbidx), nil +} diff --git a/reflect_config.go b/reflect_config.go new file mode 100644 index 0000000..887be23 --- /dev/null +++ b/reflect_config.go @@ -0,0 +1,75 @@ +package common + +import ( + "encoding/json" + "flag" + "os" + "time" +) + +var configfileflag = flag.String("config", "", "") + +func configFilePath() string { + configfilepath := "./config_template.json" + if configfileflag != nil && len(*configfileflag) > 0 { + configfilepath = *configfileflag + } + + return configfilepath +} + +func ConfigModTime() time.Time { + fi, err := os.Stat(configFilePath()) + if err == nil { + return fi.ModTime() + } + return time.Time{} +} + +func MonitorConfig[T any](onChanged func(newconf *T)) error { + fi, err := os.Stat(configFilePath()) + if err != nil { + return err + } + + go func(modTime time.Time, filepath string) { + for { + if fi, err := os.Stat(filepath); err == nil { + if modTime != fi.ModTime() { + var next T + if err := LoadConfig(&next); err == nil { + fi, _ := os.Stat(filepath) + modTime = fi.ModTime() + + onChanged(&next) + } + } + } + time.Sleep(time.Second) + } + }(fi.ModTime(), configFilePath()) + + return nil +} + +func LoadConfig[T any](outptr *T) error { + configfilepath := configFilePath() + content, err := os.ReadFile(configfilepath) + if os.IsNotExist(err) { + return os.WriteFile(configfilepath, []byte("{}"), 0666) + } + + return json.Unmarshal(content, outptr) +} + +type StorageAddr struct { + Mongo string + Redis struct { + URL string + Offset map[string]int + } +} + +type RegionStorageConfig struct { + RegionStorage map[string]StorageAddr `json:"region_storage"` +} diff --git a/s3/func.go b/s3/func.go new file mode 100644 index 0000000..04ad7f3 --- /dev/null +++ b/s3/func.go @@ -0,0 +1,392 @@ +package s3 + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/xml" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path" + "sort" + "strconv" + "strings" + "time" +) + +const ( + HMAC_ALGORITHM = "HmacSHA256" + AWS_ALGORITHM = "AWS4-HMAC-SHA256" + SERVICE_NAME = "s3" + REQUEST_TYPE = "aws4_request" + UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD" + REGION_NAME = "kr-standard" + ENDPOINT = "https://kr.object.ncloudstorage.com" +) + +type sortedHeader map[string]string + +var errAccessKeyIsMissing = errors.New("NCLOUD_ACCESS_KEY is missing") +var errSecretKeyIsMissing = errors.New("NCLOUD_SECRET_KEY is missing") +var errRegionIsMissing = errors.New("NCLOUD_REGION is missing") + +func sortVersions(versions []string) []string { + sort.Slice(versions, func(i, j int) bool { + leftnum := 0 + for _, iv := range strings.Split(versions[i], ".") { + n, _ := strconv.Atoi(iv) + leftnum += leftnum<<8 + n + } + + rightnum := 0 + for _, iv := range strings.Split(versions[j], ".") { + n, _ := strconv.Atoi(iv) + rightnum += rightnum<<8 + n + } + + return leftnum < rightnum + }) + return versions +} + +func sign(data string, key []byte) []byte { + mac := hmac.New(sha256.New, key) + mac.Write([]byte(data)) + return mac.Sum(nil) +} + +func hash(text string) string { + s256 := sha256.New() + s256.Write([]byte(text)) + return hex.EncodeToString(s256.Sum(nil)) +} + +func getStandardizedQueryParameters(query url.Values) string { + return query.Encode() +} + +func getSignedHeaders(header http.Header) string { + keys := make([]string, 0, len(header)) + for k := range header { + keys = append(keys, strings.ToLower(k)) + } + sort.Strings(keys) + return strings.Join(keys, ";") + ";" +} + +func getStandardizedHeaders(header http.Header) string { + keys := make([]string, 0, len(header)) + for k := range header { + keys = append(keys, k) + } + sort.Strings(keys) + + standardHeaders := make([]string, 0, len(header)) + for _, k := range keys { + standardHeaders = append(standardHeaders, fmt.Sprintf("%s:%s", strings.ToLower(k), header.Get(k))) + } + return strings.Join(standardHeaders, "\n") + "\n" +} + +func getCanonicalRequest(req *http.Request, standardizedQueryParam string, standardHeaders string, signedHeader string) string { + return strings.Join([]string{ + req.Method, + req.URL.Path, + standardizedQueryParam, + standardHeaders, + signedHeader, + UNSIGNED_PAYLOAD, + }, "\n") +} + +func getScope(datestamp string, regionName string) string { + return strings.Join([]string{ + datestamp, + regionName, // "kr-standard" + SERVICE_NAME, + REQUEST_TYPE, + }, "/") +} + +func getStringToSign(timestamp string, scope string, canonicalReq string) string { + return strings.Join([]string{ + AWS_ALGORITHM, // AWS_ALGORITHM + timestamp, + scope, + hash(canonicalReq), + }, "\n") +} + +func getSignature(secretKey string, datestamp string, regionName string, stringToSign string) string { + kSecret := []byte("AWS4" + secretKey) + kDate := sign(datestamp, kSecret) + kRegion := sign(regionName, kDate) + kService := sign(SERVICE_NAME, kRegion) + signingKey := sign(REQUEST_TYPE, kService) + + return hex.EncodeToString(sign(stringToSign, signingKey)) +} + +func getAuthorization(accessKey string, scope string, signedHeader string, signature string) string { + signingCredentials := accessKey + "/" + scope + credential := "Credential=" + signingCredentials + signerHeaders := "SignedHeaders=" + signedHeader + signatureHeader := "Signature=" + signature + + return fmt.Sprintf("%s %s, %s, %s", AWS_ALGORITHM, credential, signerHeaders, signatureHeader) +} + +func (s S3) addAuthorizationHeader(req *http.Request) { + req.Header.Add("host", req.Host) + + now := time.Now().UTC() + datestamp := now.Format("20060102") + timestamp := now.Format("20060102T150405Z") + req.Header.Add("x-amz-date", timestamp) + req.Header.Add("x-amz-content-sha256", UNSIGNED_PAYLOAD) + + standardizedQueryParameters := getStandardizedQueryParameters(req.URL.Query()) + signedHeaders := getSignedHeaders(req.Header) + standardizedHeaders := getStandardizedHeaders(req.Header) + canonicalRequest := getCanonicalRequest(req, standardizedQueryParameters, standardizedHeaders, signedHeaders) + scope := getScope(datestamp, s.regionName) + stringToSign := getStringToSign(timestamp, scope, canonicalRequest) + signature := getSignature(s.secretKey, datestamp, s.regionName, stringToSign) + authorization := getAuthorization(s.accessKey, scope, signedHeaders, signature) + req.Header.Add("Authorization", authorization) +} + +type S3 struct { + accessKey string + secretKey string + regionName string +} + +func NewNCloud() (S3, error) { + accessKey := os.Getenv("NCLOUD_ACCESS_KEY") + if len(accessKey) == 0 { + return S3{}, errAccessKeyIsMissing + } + secretKey := os.Getenv("NCLOUD_SECRET_KEY") + if len(secretKey) == 0 { + return S3{}, errSecretKeyIsMissing + } + region := os.Getenv("NCLOUD_REGION") + if len(region) == 0 { + return S3{}, errRegionIsMissing + } + return New(accessKey, secretKey, region), nil +} + +func New(accessKey string, secretKey string, regionName string) S3 { + return S3{ + accessKey: accessKey, + secretKey: secretKey, + regionName: regionName, + } +} + +func (s S3) MakeGetObjectRequest(objectURL string) (*http.Request, error) { + req, err := http.NewRequest("GET", objectURL, nil) + if err != nil { + return nil, err + } + + s.addAuthorizationHeader(req) + + return req, nil +} + +func (s S3) makeGetItemsRequest(prefixURL string, delimiter string) (*http.Request, error) { + u, err := url.Parse(prefixURL) + if err != nil { + return nil, err + } + + endpoint := u.Host + relpath := strings.TrimLeft(u.Path, "/") + ns := strings.SplitN(relpath, "/", 2) + bucket := ns[0] + prefix := "" + if len(ns) > 1 { + prefix = ns[1] + } + + var completeurl string + if len(delimiter) > 0 { + completeurl = fmt.Sprintf("%s://%s/%s?prefix=%s&delimiter=%s", u.Scheme, endpoint, bucket, prefix, delimiter) + } else { + completeurl = fmt.Sprintf("%s://%s/%s?prefix=%s", u.Scheme, endpoint, bucket, prefix) + } + + req, err := http.NewRequest("GET", completeurl, nil) + if err != nil { + return nil, err + } + s.addAuthorizationHeader(req) + + return req, nil +} + +type FileMeta struct { + Key string + LastModified time.Time +} + +type listBucketResult struct { + Name string + Prefix string + Marker string + MaxKeys int + Delimiter string + IsTruncated bool + CommonPrefixes []struct { + Prefix string + } + Contents []FileMeta +} + +func (s S3) ListFiles(prefixURL string) ([]FileMeta, error) { + req, err := s.makeGetItemsRequest(prefixURL, "") + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var result listBucketResult + err = xml.Unmarshal(body, &result) + + if err != nil { + return nil, err + } + + var out []FileMeta + for _, c := range result.Contents { + if !strings.HasSuffix(c.Key, "/") { + c.Key = prefixURL + "/" + path.Base(c.Key) + out = append(out, c) + } + } + return out, nil +} + +func (s S3) ListFolders(prefixURL string) ([]string, error) { + req, err := s.makeGetItemsRequest(prefixURL, "/") + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + var result listBucketResult + err = xml.Unmarshal(body, &result) + if err != nil { + return nil, err + } + + output := make([]string, 0, len(result.CommonPrefixes)) + for _, prefix := range result.CommonPrefixes { + output = append(output, strings.TrimRight(prefix.Prefix, "/")) + } + + return sortVersions(output), nil +} + +func (s S3) ReadFile(url string) (io.ReadCloser, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + s.addAuthorizationHeader(req) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + return resp.Body, nil +} + +func (s S3) DownloadFile(url string, outputFile string) error { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return err + } + s.addAuthorizationHeader(req) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + // Create the file + out, err := os.Create(outputFile) + if err != nil { + return err + } + defer out.Close() + + // Write the body to file + _, err = io.Copy(out, resp.Body) + return err +} + +func (s S3) UploadFile(url string, content []byte, publicRead bool) error { + req, err := http.NewRequest("PUT", url, bytes.NewReader(content)) + if err != nil { + return err + } + s.addAuthorizationHeader(req) + if publicRead { + req.Header.Add("x-amz-acl", "public-read") + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } + return io.EOF +} + +// https://api.ncloud-docs.com/docs/storage-objectstorage-putobjectacl +func (s S3) SetObjectACL(url string, acl string) error { + url += "?acl=" + acl + req, err := http.NewRequest("PUT", url, nil) + if err != nil { + return err + } + s.addAuthorizationHeader(req) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } + return io.EOF +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..7b84032 --- /dev/null +++ b/server.go @@ -0,0 +1,535 @@ +package common + +import ( + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "math" + "net" + "net/http" + "net/url" + "os" + "os/signal" + "reflect" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + + "repositories.action2quare.com/ayo/gocommon/logger" + + "github.com/pires/go-proxyproto" + "go.mongodb.org/mongo-driver/bson/primitive" +) + +const ( + // HTTPStatusReloginRequired : http status를 이걸 받으면 클라이언트는 로그아웃하고 로그인 화면으로 돌아가야 한다. + HTTPStatusReloginRequired = 599 + // HTTPStatusReloginRequiredDupID : + HTTPStatusReloginRequiredDupID = 598 + // HTTPStatusPlayerBlocked : + HTTPStatusPlayerBlocked = 597 +) + +type ShutdownFlag int32 + +const ( + ShutdownFlagRunning = ShutdownFlag(0) + ShutdownFlagTerminating = ShutdownFlag(1) + ShutdownFlagRestarting = ShutdownFlag(2) + ShutdownFlagIdle = ShutdownFlag(3) +) + +type RpcReturnTypeInterface interface { + Value() any + Error() error + Serialize(http.ResponseWriter) error +} + +type functionCallContext struct { + Method string + Args []interface{} +} + +// Server : +type Server struct { + httpserver *http.Server + interrupt chan os.Signal +} + +var PrefixPtr = flag.String("prefix", "", "'") +var portptr = flag.Int("port", 80, "") +var tls = flag.String("tls", "", "") +var NoSessionFlag = flag.Bool("nosession", false, "nosession=[true|false]") +var healthcheckcounter = int64(0) + +func healthCheckHandler(w http.ResponseWriter, r *http.Request) { + defer func() { + io.Copy(io.Discard, r.Body) + r.Body.Close() + }() + + // 한번이라도 들어오면 lb에 붙어있다는 뜻 + if t := atomic.AddInt64(&healthcheckcounter, 1); t < 0 { + w.WriteHeader(http.StatusServiceUnavailable) + } +} + +func welcomeHandler(w http.ResponseWriter, r *http.Request) { + defer func() { + io.Copy(io.Discard, r.Body) + r.Body.Close() + }() + + w.Write([]byte("welcome")) +} + +// NewHTTPServer : +func NewHTTPServerWithPort(serveMux *http.ServeMux, port int) *Server { + if len(*tls) > 0 && port == 80 { + port = 443 + } + addr := fmt.Sprintf(":%d", port) + serveMux.HandleFunc(MakeHttpHandlerPattern("welcome"), welcomeHandler) + serveMux.HandleFunc(MakeHttpHandlerPattern("lb_health_chceck"), healthCheckHandler) + + server := &Server{ + httpserver: &http.Server{Addr: addr, Handler: serveMux}, + interrupt: make(chan os.Signal, 1), + } + server.httpserver.SetKeepAlivesEnabled(true) + + signal.Notify(server.interrupt, os.Interrupt, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + go func() { + c := <-server.interrupt + logger.Println("interrupt!!!!!!!! :", c.String()) + server.shutdown() + }() + + return server +} + +func NewHTTPServer(serveMux *http.ServeMux) *Server { + // 시작시 자동으로 enable됨 + if len(*tls) > 0 && *portptr == 80 { + *portptr = 443 + } + return NewHTTPServerWithPort(serveMux, *portptr) +} + +// Shutdown : +func (server *Server) shutdown() { + logger.Println("server.shutdown()") + + signal.Stop(server.interrupt) + + if atomic.LoadInt64(&healthcheckcounter) > 0 { + atomic.StoreInt64(&healthcheckcounter, math.MinInt64) + for atomic.LoadInt64(&healthcheckcounter)-math.MinInt64 > 10 { + time.Sleep(100 * time.Millisecond) + } + } + + if server.httpserver != nil { + server.httpserver.Shutdown(context.Background()) + server.httpserver = nil + } +} + +// Start : +func (server *Server) Start() error { + if server.httpserver != nil { + ln, r := net.Listen("tcp", server.httpserver.Addr) + if r != nil { + return r + } + proxyListener := &proxyproto.Listener{ + Listener: ln, + ReadHeaderTimeout: 10 * time.Second, + } + defer proxyListener.Close() + + var err error + if len(*tls) > 0 { + crtfile := *tls + ".crt" + keyfile := *tls + ".key" + logger.Println("tls enabled :", crtfile, keyfile) + err = server.httpserver.ServeTLS(proxyListener, crtfile, keyfile) + } else { + err = server.httpserver.Serve(proxyListener) + } + + logger.Error("server.httpserver.Serve err :", err) + if err != http.ErrServerClosed { + return err + } + } + return nil +} + +// ConvertInterface : +func ConvertInterface(from interface{}, toType reflect.Type) reflect.Value { + fromrv := reflect.ValueOf(from) + fromtype := reflect.TypeOf(from) + + if fromtype == toType { + return fromrv + } + + switch fromtype.Kind() { + case reflect.Map: + frommap := from.(map[string]interface{}) + if frommap == nil { + // map[string]interface{}가 아니고 뭐란말인가 + return reflect.Zero(toType) + } + + needAddr := false + if toType.Kind() == reflect.Pointer { + toType = toType.Elem() + needAddr = true + } + + out := reflect.New(toType) + switch toType.Kind() { + case reflect.Struct: + for i := 0; i < toType.NumField(); i++ { + field := toType.FieldByIndex([]int{i}) + if field.Anonymous { + fieldval := out.Elem().FieldByIndex([]int{i}) + fieldval.Set(ConvertInterface(from, field.Type)) + } else if v, ok := frommap[field.Name]; ok { + fieldval := out.Elem().FieldByIndex([]int{i}) + fieldval.Set(ConvertInterface(v, fieldval.Type())) + } + } + + case reflect.Map: + out.Elem().Set(reflect.MakeMap(toType)) + + if toType.Key().Kind() != reflect.String { + for k, v := range frommap { + tok := reflect.ValueOf(k).Convert(toType.Key()) + tov := ConvertInterface(v, toType.Elem()) + out.Elem().SetMapIndex(tok, tov) + } + } else { + for k, v := range frommap { + tov := ConvertInterface(v, toType.Elem()) + out.Elem().SetMapIndex(reflect.ValueOf(k), tov) + } + } + default: + logger.Println("ConvertInterface failed : map ->", toType.Kind().String()) + } + + if needAddr { + return out + } + + return out.Elem() + + case reflect.Slice: + convslice := reflect.MakeSlice(toType, 0, fromrv.Len()) + for i := 0; i < fromrv.Len(); i++ { + convslice = reflect.Append(convslice, ConvertInterface(fromrv.Index(i).Interface(), toType.Elem())) + } + + return convslice + + case reflect.Bool: + val, _ := strconv.ParseBool(from.(string)) + return reflect.ValueOf(val) + } + + return fromrv.Convert(toType) +} + +// ErrUnmarshalRequestFailed : +var ErrUnmarshalRequestFailed = errors.New("unmarshal failed at rpc handler") +var ErrWithStatusCode = errors.New("custom error with status code") +var ErrSecondReturnShouldBeErrorType = errors.New("second return type should be error") + +type methodCacheType struct { + sync.Mutex + toIndex map[reflect.Type]map[string]int +} + +var methodCache = methodCacheType{ + toIndex: make(map[reflect.Type]map[string]int), +} + +func (mc *methodCacheType) method(receiver any, method string) (reflect.Value, bool) { + mc.Lock() + defer mc.Unlock() + + t := reflect.TypeOf(receiver) + table, ok := mc.toIndex[t] + if !ok { + table = make(map[string]int) + mc.toIndex[t] = table + } + + idx, ok := table[method] + if !ok { + m, ok := t.MethodByName(method) + if !ok { + return reflect.Value{}, false + } + + table[method] = m.Index + idx = m.Index + } + + return reflect.ValueOf(receiver).Method(idx), true +} + +// CallMethodInternal : +func CallMethodInternal(receiver interface{}, context functionCallContext) (result RpcReturnTypeInterface, err error) { + fn, ok := methodCache.method(receiver, context.Method) + if !ok { + return nil, errors.New(fmt.Sprint("method is missing :", receiver, context.Method)) + } + if len(context.Args) != fn.Type().NumIn() { + return nil, errors.New(fmt.Sprint("argument is not matching :", context.Method)) + } + + var argv []reflect.Value + for i := 0; i < len(context.Args); i++ { + argv = append(argv, ConvertInterface(context.Args[i], fn.Type().In(i))) + } + returnVals := fn.Call(argv) + + if len(returnVals) == 0 { + return nil, ErrSecondReturnShouldBeErrorType + } + + result = nil + err = nil + if result, ok = returnVals[0].Interface().(RpcReturnTypeInterface); !ok { + return nil, ErrSecondReturnShouldBeErrorType + } + + if len(returnVals) > 1 { + if err, ok = returnVals[1].Interface().(error); !ok { + return nil, ErrSecondReturnShouldBeErrorType + } + } + + return +} + +// CallMethod : +func CallMethod(receiver interface{}, context []byte) (RpcReturnTypeInterface, error) { + defer func() { + r := recover() + if r != nil { + logger.Error(r) + } + }() + + var meta functionCallContext + if err := json.Unmarshal(context, &meta); err != nil { + return nil, ErrUnmarshalRequestFailed + } + + return CallMethodInternal(receiver, meta) +} + +func ReadIntegerFormValue(r url.Values, key string) (int64, bool) { + strval := r.Get(key) + if len(strval) == 0 { + return 0, false + } + temp, err := strconv.ParseInt(strval, 10, 0) + if err != nil { + logger.Error("common.ReadFloatFormValue failed :", key, err) + return 0, false + } + + return temp, true + +} + +func ReadFloatFormValue(r url.Values, key string) (float64, bool) { + var inst float64 + strval := r.Get(key) + if len(strval) == 0 { + return inst, false + } + temp, err := strconv.ParseFloat(strval, 64) + if err != nil { + logger.Error("common.ReadFloatFormValue failed :", key, err) + return inst, false + } + + return temp, true +} + +func ReadObjectIDFormValue(r url.Values, key string) (primitive.ObjectID, bool) { + strval := r.Get(key) + if len(strval) == 0 { + return primitive.NilObjectID, false + } + id, err := primitive.ObjectIDFromHex(strval) + if err != nil { + logger.Error("common.ReadObjectIDFormValue failed :", key, err) + return primitive.NilObjectID, false + } + return id, true +} + +func ReadBoolFormValue(r url.Values, key string) (bool, bool) { + strval := r.Get(key) + if len(strval) == 0 { + return false, false + } + ret, err := strconv.ParseBool(strval) + if err != nil { + logger.Error("common.ReadBoolFormValue failed :", key, err) + return false, false + } + return ret, true +} + +func ReadStringFormValue(r url.Values, key string) (string, bool) { + strval := r.Get(key) + return strval, len(strval) > 0 +} + +func DotStringToTimestamp(tv string) primitive.Timestamp { + if len(tv) == 0 { + return primitive.Timestamp{T: 0, I: 0} + } + + ti := strings.Split(tv, ".") + t, _ := strconv.ParseUint(ti[0], 10, 0) + if len(ti) > 1 { + i, _ := strconv.ParseUint(ti[1], 10, 0) + return primitive.Timestamp{T: uint32(t), I: uint32(i)} + } + + return primitive.Timestamp{T: uint32(t), I: 0} +} + +type RpcReturnSimple struct { + value any +} + +func (rt *RpcReturnSimple) Value() any { + return rt.value +} + +func (rt *RpcReturnSimple) Error() error { + return nil +} + +func (rt *RpcReturnSimple) Serialize(w http.ResponseWriter) error { + err := SerializeInterface(w, rt.value) + if err == nil { + w.Write([]byte{0, 0}) + } + return err +} + +type RpcReturnError struct { + err error + code int + h map[string]any +} + +func (rt *RpcReturnError) Value() any { + return nil +} + +var errDefaultError = errors.New("unknown error") + +func (rt *RpcReturnError) Error() error { + if rt.err != nil { + return rt.err + } + + if rt.code != 0 { + if rt.h == nil { + return fmt.Errorf("http status code %d error", rt.code) + } + return fmt.Errorf("http status code %d error with header %v", rt.code, rt.h) + } + + return errDefaultError +} + +func (rt *RpcReturnError) Serialize(w http.ResponseWriter) error { + if rt.h != nil { + bt, _ := json.Marshal(rt.h) + w.Header().Add("As-X-Err", string(bt)) + } + + w.WriteHeader(rt.code) + if rt.err != nil { + logger.Error(rt.err) + } + + if rt.code >= http.StatusInternalServerError { + logger.Println("rpc return error :", rt.code, rt.h) + } + + return nil +} + +func (rt *RpcReturnError) WithCode(code int) *RpcReturnError { + rt.code = code + return rt +} + +func (rt *RpcReturnError) WithError(err error) *RpcReturnError { + rt.err = err + return rt +} + +func (rt *RpcReturnError) WithHeader(k string, v any) *RpcReturnError { + if rt.h == nil { + rt.h = make(map[string]any) + } + rt.h[k] = v + return rt +} + +func (rt *RpcReturnError) WithHeaders(h map[string]any) *RpcReturnError { + if rt.h == nil { + rt.h = h + } else { + for k, v := range h { + rt.h[k] = v + } + } + return rt +} + +// MakeRPCReturn : +func MakeRPCReturn(value interface{}) *RpcReturnSimple { + return &RpcReturnSimple{ + value: value, + } +} + +func MakeRPCError() *RpcReturnError { + pc, _, _, ok := runtime.Caller(1) + if ok { + frames := runtime.CallersFrames([]uintptr{pc}) + frame, _ := frames.Next() + logger.Printf("rpc error. func=%s, file=%s, line=%d", frame.Function, frame.File, frame.Line) + } + + return &RpcReturnError{ + err: nil, + code: http.StatusInternalServerError, + h: nil, + } +} diff --git a/xboxlive/xboxlive.go b/xboxlive/xboxlive.go new file mode 100644 index 0000000..8d5d5fc --- /dev/null +++ b/xboxlive/xboxlive.go @@ -0,0 +1,328 @@ +package xboxlive + +import ( + "bytes" + "compress/flate" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "errors" + "io" + "log" + "net/http" + "os" + "strings" + "sync" + + b64 "encoding/base64" + "encoding/binary" + "encoding/json" + "encoding/pem" + + "github.com/golang-jwt/jwt" + "golang.org/x/crypto/pkcs12" + //"software.sslmate.com/src/go-pkcs12" +) + +type JWT_Header struct { + Typ string `json:"typ"` + Alg string `json:"alg"` + X5t string `json:"x5t"` + X5u string `json:"x5u"` +} + +type JWT_XDI_data struct { + Dty string `json:"dty"` // Device Type : ex) XboxOne +} + +type JWT_XUI_data struct { + Ptx string `json:"ptx"` // 파트너 Xbox 사용자 ID (ptx) - PXUID (ptx), publisher별로 고유한 ID : ex) 293812B467D21D3295ADA06B121981F805CC38F0 + Gtg string `json:"gtg"` // 게이머 태그 +} + +type JWT_XBoxLiveBody struct { + XDI JWT_XDI_data `json:"xdi"` + XUI []JWT_XUI_data `json:"xui"` + Sbx string `json:"sbx"` // SandBoxID : ex) BLHLQG.99 +} + +var cachedCert map[string]map[string]string +var cachedCertLock = new(sync.RWMutex) + +func getcachedCert(x5u string, x5t string) string { + cachedCertLock.Lock() + defer cachedCertLock.Unlock() + + if cachedCert == nil { + cachedCert = make(map[string]map[string]string) + } + + var certKey string + certKey = "" + + if CachedCertURI, existCachedCertURI := cachedCert[x5u]; existCachedCertURI { + if CachedCert, existCachedCert := CachedCertURI[x5t]; existCachedCert { + certKey = CachedCert + } + } + return certKey +} + +func setcachedCert(x5u string, x5t string, certKey string) { + cachedCertLock.Lock() + defer cachedCertLock.Unlock() + if cachedCert[x5u] == nil { + cachedCert[x5u] = make(map[string]string) + } + + cachedCert[x5u][x5t] = certKey +} + +func JWT_DownloadXSTSSigningCert(x5u string, x5t string) string { + certKey := getcachedCert(x5u, x5t) + + // -- 캐싱된 자료가 없으면 웹에서 받아 온다. + if certKey == "" { + resp, err := http.Get(x5u) // GET 호출 + if err != nil { + panic(err) + } + defer func() { + io.Copy(io.Discard, resp.Body) + resp.Body.Close() + }() + + // 결과 출력 + data, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + + var parseddata map[string]string + err = json.Unmarshal([]byte(data), &parseddata) + if err != nil { + panic(err) + } + + if downloadedkey, exist := parseddata[x5t]; exist { + // downloadedkey = strings.Replace(downloadedkey, "-----BEGIN CERTIFICATE-----\n", "", -1) + // downloadedkey = strings.Replace(downloadedkey, "\n-----END CERTIFICATE-----\n", "", -1) + certKey = downloadedkey + } else { + panic("JWT_DownloadXSTSSigningCert : Key not found : " + x5t) + } + } + + setcachedCert(x5u, x5t, certKey) + return certKey +} + +func jwt_Decrypt_forXBoxLive(jwt_token string) (JWT_Header, JWT_XBoxLiveBody) { + parts := strings.Split(jwt_token, ".") + jwt_header, err := b64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + panic(err) + } + + JWT_Header_obj := JWT_Header{} + json.Unmarshal([]byte(string(jwt_header)), &JWT_Header_obj) + + if JWT_Header_obj.Typ != "JWT" { + panic("JWT Decrypt Error : typ is not JWT") + } + + if JWT_Header_obj.Alg != "RS256" { + panic("JWT Decrypt Error : alg is not RS256") + } + + var publicKey string + if len(JWT_Header_obj.X5u) >= len("https://xsts.auth.xboxlive.com") && JWT_Header_obj.X5u[:len("https://xsts.auth.xboxlive.com")] == "https://xsts.auth.xboxlive.com" { + publicKey = JWT_DownloadXSTSSigningCert(JWT_Header_obj.X5u, JWT_Header_obj.X5t) + } else { + panic("JWT Decrypt Error : Invalid x5u host that is not trusted" + JWT_Header_obj.X5u) + } + + block, _ := pem.Decode([]byte(publicKey)) + var cert *x509.Certificate + cert, _ = x509.ParseCertificate(block.Bytes) + rsaPublicKey := cert.PublicKey.(*rsa.PublicKey) + + err = verifyJWT_forXBoxLive(jwt_token, rsaPublicKey) + if err != nil { + panic(err) + } + jwt_body, err := b64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + panic(err) + } + JWT_XBoxLiveBody_obj := JWT_XBoxLiveBody{} + json.Unmarshal([]byte(string(jwt_body)), &JWT_XBoxLiveBody_obj) + + return JWT_Header_obj, JWT_XBoxLiveBody_obj + +} + +func verifyJWT_forXBoxLive(decompressed string, rsaPublicKey *rsa.PublicKey) error { + + token, err := jwt.Parse(decompressed, func(token *jwt.Token) (interface{}, error) { + return rsaPublicKey, nil + }) + + if err != nil { + if err := err.(*jwt.ValidationError); err != nil { + if err.Errors == jwt.ValidationErrorExpired { + return nil + } + } + + return err + } + + if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { + if claims["iss"].(string) != "xsts.auth.xboxlive.com" { + return errors.New("issuer is not valid") + } + if claims["aud"].(string) != "rp://actionsquaredev.com/" { + return errors.New("audience is not valid") + } + return nil + } + + return errors.New("token is not valid") +} + +func splitSecretKey(data []byte) ([]byte, []byte) { + if len(data) < 2 { + panic(" SplitSecretKey : secretkey is too small.. ") + } + + if len(data)%2 != 0 { + panic(" SplitSecretKey : data error ") + } + + midpoint := len(data) / 2 + + firstHalf := data[0:midpoint] + secondHalf := data[midpoint : midpoint+midpoint] + return firstHalf, secondHalf +} + +func pkcs7UnPadding(origData []byte) []byte { + length := len(origData) + unpadding := int(origData[length-1]) + return origData[:(length - unpadding)] +} + +func aesCBCDncrypt(plaintext []byte, key []byte, iv []byte) []byte { + block, _ := aes.NewCipher(key) + ciphertext := make([]byte, len(plaintext)) + mode := cipher.NewCBCDecrypter(block, iv) + mode.CryptBlocks(ciphertext, plaintext) + ciphertext = pkcs7UnPadding(ciphertext) + return ciphertext +} + +func deflate(inflated []byte) string { + + byteReader := bytes.NewReader(inflated) + + wBuf := new(strings.Builder) + + zr := flate.NewReader(byteReader) + if _, err := io.Copy(wBuf, zr); err != nil { + log.Fatal(err) + } + + return wBuf.String() +} + +var errHashMismatch = errors.New("authentication tag does not match with the computed hash") + +func verifyAuthenticationTag(aad []byte, iv []byte, cipherText []byte, hmacKey []byte, authTag []byte) error { + + aadBitLength := make([]byte, 8) + binary.BigEndian.PutUint64(aadBitLength, uint64(len(aad)*8)) + + dataToSign := append(append(append(aad, iv...), cipherText...), aadBitLength...) + + h := hmac.New(sha256.New, []byte(hmacKey)) + h.Write([]byte(dataToSign)) + hash := h.Sum(nil) + + computedAuthTag, _ := splitSecretKey(hash) + + // Check if the auth tag is equal + // The authentication tag is the first half of the hmac result + if !bytes.Equal(authTag, computedAuthTag) { + return errHashMismatch + } + + return nil +} + +var privateKeydata []byte + +func privateKeyFile() string { + return os.Getenv("XBOXLIVE_PTX_FILE_NAME") +} +func privateKeyFilePass() string { + return os.Getenv("XBOXLIVE_PTX_FILE_PASSWORD") +} + +func Init() { + if len(privateKeyFile()) == 0 || len(privateKeyFilePass()) == 0 { + return + } + + var err error + privateKeydata, err = os.ReadFile(privateKeyFile()) + if err != nil { + panic("Error reading private key file") + } +} + +// 실제 체크 함수 +func AuthCheck(token string) (ptx string, err error) { + parts := strings.Split(token, ".") + + encryptedData, _ := b64.RawURLEncoding.DecodeString(parts[1]) + + privateKey, _, e := pkcs12.Decode(privateKeydata, privateKeyFilePass()) + if e != nil { + return "", e + } + + if e := privateKey.(*rsa.PrivateKey).Validate(); e != nil { + return "", e + } + + hash := sha1.New() + random := rand.Reader + decryptedData, decryptErr := rsa.DecryptOAEP(hash, random, privateKey.(*rsa.PrivateKey), encryptedData, nil) + if decryptErr != nil { + return "", decryptErr + } + + hmacKey, aesKey := splitSecretKey(decryptedData) + + iv, _ := b64.RawURLEncoding.DecodeString(parts[2]) + encryptedContent, _ := b64.RawURLEncoding.DecodeString(parts[3]) + + // Decrypt the payload using the AES + IV + decrypted := aesCBCDncrypt(encryptedContent, aesKey, iv) + decompressed := deflate(decrypted) + + _, body := jwt_Decrypt_forXBoxLive(decompressed) + + authTag, _ := b64.RawURLEncoding.DecodeString(parts[4]) + authData := []byte(parts[0]) + err = verifyAuthenticationTag(authData, iv, encryptedContent, hmacKey, authTag) + + return body.XUI[0].Ptx, err + +}